1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Common kernel registrations for XLA devices.
17 
18 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
19 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/kernels/cast_op.h"
24 #include "tensorflow/core/kernels/constant_op.h"
25 #include "tensorflow/core/kernels/control_flow_ops.h"
26 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
27 #include "tensorflow/core/kernels/data/iterator_ops.h"
28 #include "tensorflow/core/kernels/data/optional_ops.h"
29 #include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
30 #include "tensorflow/core/kernels/fifo_queue.h"
31 #include "tensorflow/core/kernels/function_ops.h"
32 #include "tensorflow/core/kernels/host_constant_op.h"
33 #include "tensorflow/core/kernels/identity_n_op.h"
34 #include "tensorflow/core/kernels/identity_op.h"
35 #include "tensorflow/core/kernels/no_op.h"
36 #include "tensorflow/core/kernels/queue_op.h"
37 #include "tensorflow/core/kernels/resource_variable_ops.h"
38 #include "tensorflow/core/kernels/sendrecv_ops.h"
39 #include "tensorflow/core/kernels/shape_ops.h"
40 #include "tensorflow/core/kernels/stack.h"
41 #include "tensorflow/core/kernels/variable_ops.h"
42 
43 namespace tensorflow {
44 
45 // Dummy OpKernel, used for kernels assigned to an XLA device that should be
46 // compiled. Should never be called at runtime since such ops should be
47 // rewritten to a XlaLaunch op. If it is called, it means the placer placed an
48 // operator on an XLA device but the compiler did not compile it.
49 class XlaDeviceDummyOp : public OpKernel {
50  public:
51   explicit XlaDeviceDummyOp(OpKernelConstruction* ctx);
52   void Compute(OpKernelContext* ctx) override;
53 };
54 
55 class XlaAssignVariableOp : public OpKernel {
56  public:
57   explicit XlaAssignVariableOp(OpKernelConstruction* c);
58   void Compute(OpKernelContext* context) override;
59 
60  private:
61   DataType dtype_;
62 };
63 
// Registers KERNEL under the op name "XlaLaunch" for DEVICE, passing
// "constants" and "resources" to HostMemory(). TYPES is accepted but is not
// referenced by the expansion. The expansion already ends with a semicolon.
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER( \
      Name("XlaLaunch") \
          .Device(DEVICE) \
          .HostMemory("constants") \
          .HostMemory("resources"), \
      KERNEL);
70 
// Registers KERNEL under the op name "_XlaCompile" for DEVICE, passing
// "constants", "key", "compilation_successful" and "resources" to
// HostMemory(). TYPES is accepted but is not referenced by the expansion.
// The expansion already ends with a semicolon.
#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER( \
      Name("_XlaCompile") \
          .Device(DEVICE) \
          .HostMemory("constants") \
          .HostMemory("key") \
          .HostMemory("compilation_successful") \
          .HostMemory("resources"), \
      KERNEL);
79 
// Registers KERNEL under the op name "_XlaRun" for DEVICE. No HostMemory()
// or TypeConstraint() calls are applied; TYPES is accepted but is not
// referenced by the expansion. The expansion already ends with a semicolon.
#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER( \
      Name("_XlaRun").Device(DEVICE), \
      KERNEL);
82 
// Registers the common set of kernels for DEVICE. TYPES is forwarded to
// TypeConstraint() for the registrations that are constrained on it ("Const",
// "Identity", "Shape", "ShapeN", "Size", "Rank", the TYPES-constrained
// kArgOp/kRetOp entries, "StackPushV2" and "StackPopV2").
//
// Every registration in the expansion, including the last one, ends with a
// semicolon. All kernel classes and constants named below (e.g. kArgOp,
// data::PrefetchDatasetOp) must be visible where the macro is expanded.
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
  /* Send/Recv kernels and their "_Host" variants. */ \
  REGISTER_KERNEL_BUILDER( \
      Name("_Send").Device(DEVICE), \
      SendOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("_Recv").Device(DEVICE), \
      RecvOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("_HostSend").Device(DEVICE).HostMemory("tensor"), \
      SendOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), \
      RecvOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("_HostCast").Device(DEVICE).HostMemory("x").HostMemory("y"), \
      CpuCastOp); \
  \
  /* No-op, constants, identity and placeholders. */ \
  REGISTER_KERNEL_BUILDER( \
      Name("NoOp").Device(DEVICE), \
      NoOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
      ConstantOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("HostConst").Device(DEVICE).HostMemory("output"), \
      _HostConstantOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), \
      IdentityOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Identity").Device(DEVICE).TypeConstraint("T", DT_STRING), \
      IdentityOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Identity").Device(DEVICE).TypeConstraint<Variant>("T"), \
      IdentityOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Identity") \
          .Device(DEVICE) \
          .TypeConstraint<ResourceHandle>("T") \
          .HostMemory("input") \
          .HostMemory("output"), \
      IdentityOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("IdentityN").Device(DEVICE), \
      IdentityNOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Placeholder").Device(DEVICE), \
      PlaceholderOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("PlaceholderV2").Device(DEVICE), \
      PlaceholderOp); \
  \
  /* Resource variable handles, reads and destruction. */ \
  REGISTER_KERNEL_BUILDER( \
      Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \
      ResourceHandleOp<Var>); \
  REGISTER_KERNEL_BUILDER( \
      Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \
      ResourceHandlesOp<Var>); \
  REGISTER_KERNEL_BUILDER( \
      Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
      ReadVariableOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \
      ReadVariablesOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \
      DestroyResourceOp); \
  \
  /* Shape, size and rank kernels (int32 and int64 "out_type" variants). */ \
  REGISTER_KERNEL_BUILDER( \
      Name("Shape") \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint<int32>("out_type") \
          .TypeConstraint("T", TYPES), \
      ShapeOp<int32>); \
  REGISTER_KERNEL_BUILDER( \
      Name("Shape") \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint<int64>("out_type") \
          .TypeConstraint("T", TYPES), \
      ShapeOp<int64>); \
  REGISTER_KERNEL_BUILDER( \
      Name("ShapeN") \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint<int32>("out_type") \
          .TypeConstraint("T", TYPES), \
      ShapeNOp<int32>); \
  REGISTER_KERNEL_BUILDER( \
      Name("ShapeN") \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint<int64>("out_type") \
          .TypeConstraint("T", TYPES), \
      ShapeNOp<int64>); \
  REGISTER_KERNEL_BUILDER( \
      Name("Size") \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint<int32>("out_type") \
          .TypeConstraint("T", TYPES), \
      SizeOp<int32>); \
  REGISTER_KERNEL_BUILDER( \
      Name("Size") \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint<int64>("out_type") \
          .TypeConstraint("T", TYPES), \
      SizeOp<int64>); \
  REGISTER_KERNEL_BUILDER( \
      Name("Rank") \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint("T", TYPES), \
      RankOp); \
  \
  /* Variable assignment uses the XLA-specific kernel declared above. */ \
  REGISTER_KERNEL_BUILDER( \
      Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \
      XlaAssignVariableOp); \
  \
  /* Control flow. */ \
  REGISTER_KERNEL_BUILDER( \
      Name("ControlTrigger").Device(DEVICE), \
      ControlTriggerOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Switch").Device(DEVICE).HostMemory("pred"), \
      SwitchOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Merge").Device(DEVICE).HostMemory("value_index"), \
      MergeOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Enter").Device(DEVICE), \
      EnterOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("Exit").Device(DEVICE), \
      ExitOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("NextIteration").Device(DEVICE), \
      NextIterationOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("LoopCond") \
          .Device(DEVICE) \
          .HostMemory("input") \
          .HostMemory("output"), \
      LoopCondOp); \
  \
  /* Queues. */ \
  REGISTER_KERNEL_BUILDER( \
      Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), \
      EnqueueOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), \
      DequeueOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), \
      QueueCloseOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("QueueSizeV2") \
          .Device(DEVICE) \
          .HostMemory("size") \
          .HostMemory("handle"), \
      QueueSizeOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
      QueueIsClosedOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), \
      FIFOQueueOp); \
  \
  /* Function arguments and return values. */ \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), \
      ArgOp); \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp) \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint<ResourceHandle>("T"), \
      ArgOp); \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), \
      ArgOp); \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), \
      RetvalOp); \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp) \
          .Device(DEVICE) \
          .TypeConstraint<ResourceHandle>("T") \
          .HostMemory("input"), \
      RetvalOp); \
  REGISTER_KERNEL_BUILDER( \
      Name(kDeviceRetOp).Device(DEVICE).TypeConstraint<int32>("T"), \
      RetvalOp); \
  \
  /* Remote function calls. */ \
  REGISTER_KERNEL_BUILDER( \
      Name("RemoteCall").Device(DEVICE).HostMemory("target"), \
      RemoteCallOp); \
  \
  /* Datasets, iterators and optionals (kernels from namespace data). */ \
  REGISTER_KERNEL_BUILDER( \
      Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \
      data::GeneratorDatasetOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("PrefetchDataset") \
          .Device(DEVICE) \
          .HostMemory("buffer_size") \
          .HostMemory("input_dataset") \
          .HostMemory("handle"), \
      data::PrefetchDatasetOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("IteratorV2").Device(DEVICE), \
      data::IteratorHandleOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \
      data::MakeIteratorOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("AnonymousIterator").Device(DEVICE), \
      data::AnonymousIteratorHandleOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("IteratorGetNext").Device(DEVICE), \
      data::IteratorGetNextOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("IteratorGetNextAsOptional").Device(DEVICE), \
      data::IteratorGetNextAsOptionalOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("IteratorGetNextSync").Device(DEVICE), \
      data::IteratorGetNextSyncOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("IteratorToStringHandle") \
          .Device(DEVICE) \
          .HostMemory("string_handle"), \
      data::IteratorToStringHandleOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("IteratorFromStringHandleV2") \
          .Device(DEVICE) \
          .HostMemory("string_handle"), \
      data::IteratorFromStringHandleOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("OptionalNone").Device(DEVICE), \
      data::OptionalNoneOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("OptionalFromValue").Device(DEVICE), \
      data::OptionalFromValueOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"), \
      data::OptionalHasValueOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("OptionalGetValue").Device(DEVICE), \
      data::OptionalGetValueOp); \
  \
  /* String-typed function arguments and return values. */ \
  REGISTER_KERNEL_BUILDER( \
      Name(FunctionLibraryDefinition::kArgOp) \
          .Device(DEVICE) \
          .HostMemory("output") \
          .TypeConstraint<string>("T"), \
      ArgOp); \
  REGISTER_KERNEL_BUILDER( \
      Name(FunctionLibraryDefinition::kRetOp) \
          .Device(DEVICE) \
          .TypeConstraint<string>("T") \
          .HostMemory("input"), \
      RetvalOp); \
  \
  /* Stacks. */ \
  REGISTER_KERNEL_BUILDER( \
      Name("StackV2") \
          .Device(DEVICE) \
          .HostMemory("max_size") \
          .HostMemory("handle"), \
      StackOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("StackPushV2") \
          .Device(DEVICE) \
          .HostMemory("handle") \
          .TypeConstraint("T", TYPES), \
      TemplatedStackPushOp</*allow_swapping=*/false>); \
  REGISTER_KERNEL_BUILDER( \
      Name("StackPopV2") \
          .Device(DEVICE) \
          .HostMemory("handle") \
          .TypeConstraint("elem_type", TYPES), \
      StackPopOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), \
      StackCloseOp);
301 
302 // TODO(b/118881356): currently we do not register the QueueEnqueueMany,
303 // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
304 // and write the tensors they access in order to concatenate them into a batch.
305 // We would need either to call out to an XLA computation to perform the
306 // concatenation, or we would need to refactor those kernels so the splitting
307 // or merging is done in a separate operator that can be compiled.
308 
309 }  // namespace tensorflow
310 
311 #endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
312