/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
17 
18 #include <string>
19 
20 #include "tensorflow/core/lib/gtl/flatset.h"
21 #include "tensorflow/core/platform/types.h"
22 
23 namespace tensorflow {
24 
25 // TensorFlow runtime (both eager and graph) will aim to colocate ops with
26 // their resource inputs so that the ops can access the resource state. In some
27 // cases, such as tf.data ops, this is not desirable as the ops themselves might
28 // not have a kernel registered for the device on which the resource is placed
29 // and instead use a mechanism, such as a multi-device function, to access the
30 // resource state.
31 //
32 // This registry can be used to register and list ops that should be exempt from
33 // the input colocation described above.
34 //
35 // Example usage:
36 //   REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset");
37 class InputColocationExemptionRegistry {
38  public:
39   // Returns a pointer to a global InputColocationExemptionRegistry object.
40   static InputColocationExemptionRegistry* Global();
41 
42   // Returns the set of ops exempt from the input colocation constraints.
Get()43   const gtl::FlatSet<string>& Get() { return ops_; }
44 
45   // Registers an op to be excluded from the input colocation constraints.
46   void Register(const string& op);
47 
48  private:
49   gtl::FlatSet<string> ops_;
50 };
51 
52 namespace input_colocation_exemption_registration {
53 
54 class InputColocationExemptionRegistration {
55  public:
InputColocationExemptionRegistration(const string & op)56   explicit InputColocationExemptionRegistration(const string& op) {
57     InputColocationExemptionRegistry::Global()->Register(op);
58   }
59 };
60 
61 }  // namespace input_colocation_exemption_registration
62 
// Exempts the op named by `op` from input colocation. It defines a `static`
// InputColocationExemptionRegistration object whose constructor passes `op`
// to InputColocationExemptionRegistry::Global()->Register().
//
// The registration type is fully qualified (`::tensorflow::...`) so the macro
// can be expanded from any namespace, not only from within `tensorflow`.
#define REGISTER_INPUT_COLOCATION_EXEMPTION(op) \
  REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(__COUNTER__, op)

// Extra level of indirection so that __COUNTER__ is expanded to a number before
// it reaches the `##` token paste in REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ
// (operands of `##` are not macro-expanded). This gives each expansion a
// distinct variable name.
#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(ctr, op) \
  REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op)

#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op)        \
  static ::tensorflow::input_colocation_exemption_registration:: \
      InputColocationExemptionRegistration                       \
          input_colocation_exemption_registration_fn_##ctr(op)
73 
74 }  // namespace tensorflow
75 
76 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
77