1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
17 #define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
18
19 #include <functional>
20
21 #include <utility>
22 #include <vector>
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/control_flow.h"
26 #include "tensorflow/core/framework/device_base.h"
27 #include "tensorflow/core/framework/kernel_def_builder.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
30 #include "tensorflow/core/framework/rendezvous.h"
31 #include "tensorflow/core/framework/selective_registration.h"
32 #include "tensorflow/core/framework/session_state.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
36 #include "tensorflow/core/framework/tracking_allocator.h"
37 #include "tensorflow/core/framework/types.h"
38 #include "tensorflow/core/framework/types.pb.h"
39 #include "tensorflow/core/framework/unique_tensor_references.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/lib/gtl/array_slice.h"
43 #include "tensorflow/core/lib/gtl/manual_constructor.h"
44 #include "tensorflow/core/platform/env.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/macros.h"
47 #include "tensorflow/core/platform/mutex.h"
48 #include "tensorflow/core/platform/thread_annotations.h"
49 #include "tensorflow/core/platform/types.h"
50
51 namespace Eigen {
52 struct ThreadPoolDevice;
53 struct GpuDevice;
54 struct SyclDevice;
55 } // end namespace Eigen
56
57 namespace tensorflow {
58
59 namespace checkpoint {
60 class TensorSliceReaderCacheWrapper;
61 } // namespace checkpoint
62
63 class AsyncOpKernel;
64 class CallFrameInterface;
65 class FunctionLibraryRuntime;
66 class OpKernelConstruction; // declared below
67 class OpKernelContext; // declared below
68 class OpRegistryInterface;
69 class ResourceMgr;
70 class ScopedStepContainer;
71 class StepStatsCollector;
72
class OpKernel {
 public:
  // OpKernel won't be instantiated by the scheduler, so you may perform
  // expensive initialization in the descendant's constructor.
  explicit OpKernel(OpKernelConstruction* context);

  // Specialized constructor that enables the descendant to provide a different
  // `NodeDef` value. For example, this constructor can be used to provide a
  // stripped-down `NodeDef` that does not contain the full set of attrs (such
  // as tensor values) if the descendant stores them in a different form.
  explicit OpKernel(OpKernelConstruction* context,
                    std::unique_ptr<const NodeDef> node_def);

  virtual ~OpKernel();

  // An OpKernel's computation can be either synchronous or
  // asynchronous. All OpKernel Compute() methods must be thread-safe as they
  // may be called concurrently (e.g. by multiple executions of the same graph
  // concurrently).
  //
  // Most OpKernels should compute synchronously. They should
  // subclass OpKernel and override the Compute() method and have it
  // return after completing the supplied work.
  //
  // A few special kernels might need to be asynchronous to bound the
  // number of threads (e.g., network receive operations). These
  // kernels must subclass AsyncOpKernel and override
  // AsyncOpKernel::ComputeAsync().
  //
  // In both cases, implementations of Compute() and ComputeAsync()
  // get inputs and write outputs through the given OpKernelContext
  // and returns a status via context->SetStatus(). They must be
  // thread-safe.

  // Synchronous compute.
  //
  // "context" is guaranteed to be alive until Compute() returns.
  virtual void Compute(OpKernelContext* context) = 0;

  // Returns nullptr iff this op kernel is synchronous.
  virtual AsyncOpKernel* AsAsync() { return nullptr; }

  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution for example
  // to "inline" inexpensive kernels.
  virtual bool IsExpensive() { return expensive_; }

  // Accessors.
  const NodeDef& def() const { return *def_; }
  const string& name() const;              // Same as def().name()
  const string& type_string() const;       // Same as def().op()
  const string& requested_device() const;  // Same as def().device()
  bool is_internal() const { return is_internal_; }

  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeVector& input_types() const { return input_types_; }
  const MemoryTypeVector& input_memory_types() const {
    return input_memory_types_;
  }
  const string& requested_input(int i) const;  // Same as def().input(i)

  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int o) const { return output_types_[o]; }
  const DataTypeVector& output_types() const { return output_types_; }
  const MemoryTypeVector& output_memory_types() const {
    return output_memory_types_;
  }

  // Looks up the index range of the named input/output argument
  // (resolved via the name->range maps built at construction).
  Status InputRange(StringPiece input_name, int* start, int* stop) const;
  Status OutputRange(StringPiece output_name, int* start, int* stop) const;

  // We allow legacy scalars within Google up until GraphDef version 6.
  // TODO(irving): Remove when we can drop support for GraphDef version 5.
  bool allow_legacy_scalars() const {
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
    return graph_def_version_ < 6;
#else
    return false;
#endif
  }

  // Allow either scalars or (if allowing legacy scalars) shape (1,).
  bool IsLegacyScalar(const TensorShape& shape) const {
    return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
                                 shape.dim_size(0) == 1);
  }

  // Allow rank 1 or (if allowing legacy scalars) rank 0.
  bool IsLegacyVector(const TensorShape& shape) const {
    return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
  }

  // Turn a shape Tensor into a TensorShape
  // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
  Status MakeShape(const Tensor& shape, TensorShape* out) const;

 private:
  const std::unique_ptr<const NodeDef> def_;
  const DataTypeVector input_types_;
  const MemoryTypeVector input_memory_types_;
  const DataTypeVector output_types_;
  const MemoryTypeVector output_memory_types_;
  const int graph_def_version_;
  const bool is_internal_;  // True if this is an internal operation
  NameRangeMap input_name_map_;   // arg name -> input index range
  NameRangeMap output_name_map_;  // arg name -> output index range
  bool expensive_;  // default value returned by IsExpensive()

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};
184
// Base class for kernels that complete asynchronously (e.g. ops that
// would otherwise block a compute thread, such as network receives).
class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

  // Asynchronous compute.
  //
  // Implementations of ComputeAsync() must run "done" to signal the
  // completion of the computation. "context" is guaranteed to be
  // alive until the "done" callback starts.
  typedef std::function<void()> DoneCallback;
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;

  // An AsyncOpKernel is its own asynchronous form, so this never
  // returns nullptr (cf. OpKernel::AsAsync()).
  AsyncOpKernel* AsAsync() final { return this; }

  // Synchronous entry point required by OpKernel; defined out of line.
  void Compute(OpKernelContext* context) final;

  // Async kernels are always treated as expensive by the runtime.
  bool IsExpensive() override { return true; }
};
203
// Wraps a tensor that is held by an Op across calls to Compute(). For
// memory safety when using asynchronous devices like GPUs, the system
// must be notified when a Tensor is used inside an Op execution. The
// wrapper ensures that all uses of the Tensor are tracked, because in
// order to retrieve the Tensor the caller must use AccessTensor which
// notifies the context.
class PersistentTensor {
 public:
  // Creates an empty (uninitialized) persistent tensor.
  PersistentTensor() {}
  // Wraps an existing tensor; prefer allocate_persistent() for new storage.
  explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}

  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelConstruction* context);
  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelContext* context);

  // The check for initialization does not need to access the
  // underlying tensor buffer.
  bool IsInitialized() const { return tensor_.IsInitialized(); }

  // Number of elements in the wrapped tensor.
  int64 NumElements() const { return tensor_.NumElements(); }

  // Number of bytes allocated for the wrapped tensor's buffer.
  int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }

 private:
  Tensor tensor_;
};
231
// Context passed to an OpKernel's constructor: exposes the NodeDef and its
// attrs (GetAttr/HasAttr), the declared input/output signature,
// construction-time allocation, and error reporting via SetStatus().
class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,
                       const OpDef* op_def, FunctionLibraryRuntime* flib,
                       const DataTypeSlice& input_types,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  // Environment of the device this kernel is being constructed for.
  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  // The construction status so far; dereferences the Status* supplied
  // to the constructor, so the caller observes the same object.
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value. If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device. See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  // NOTE(review): the raw pointers below are taken as-is from the
  // constructor and appear to be non-owning; confirm lifetimes against
  // op_kernel.cc before relying on them past construction.
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow op_def_ across from OpKernel, but not from subclasses.
  // TODO(irving): Remove protos from this header entirely.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};
364
// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
// tensorflow::gtl::iterator_range to make the below container classes
// unnecessary.
//
// Minimal forward-style iterator over an indexable list (OpInputList,
// OpMutableInputList, OpOutputList). Two iterators may only be compared
// when they refer to the same underlying list (DCHECK-enforced).
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  typedef OpArgIterator<ListType, ElementType> ME;

  // Positions the iterator at index `i` of `list`. `list` is not owned
  // and must outlive the iterator.
  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}

  // Comparison operators are const: they inspect but never modify either
  // operand (the original non-const versions could not be called through
  // const iterators or const references).
  bool operator==(const ME& rhs) const {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }

  bool operator!=(const ME& rhs) const {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }

  // Advances to the next element.
  void operator++() { ++i_; }

  // Dereferences to the current element via the list's operator[].
  ElementType& operator*() { return (*list_)[i_]; }

 private:
  const ListType* const list_;  // not owned
  int i_;                       // current index into *list_
};
388
389 // Utility class for representing a list of immutable input tensors
390 // that are passed to the op as a single named argument.
391 class OpInputList {
392 public:
393 typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
OpInputList()394 OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext * ctx,int start,int stop)395 OpInputList(OpKernelContext* ctx, int start, int stop)
396 : ctx_(ctx), start_(start), stop_(stop) {}
397 OpInputList& operator=(const OpInputList& other) = default;
398 const Tensor& operator[](int i) const;
size()399 int size() const { return stop_ - start_; }
begin()400 Iterator begin() const { return Iterator(this, 0); }
end()401 Iterator end() const { return Iterator(this, size()); }
402
403 private:
404 OpKernelContext* ctx_; // not owned
405 int start_;
406 int stop_;
407 };
408
409 // Utility class for representing a list of mutable ("ref") input tensors
410 // that are passed to the op as a single named argument.
411 class OpMutableInputList {
412 public:
413 typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
OpMutableInputList(OpKernelContext * ctx,int start,int stop)414 OpMutableInputList(OpKernelContext* ctx, int start, int stop)
415 : ctx_(ctx), start_(start), stop_(stop) {}
OpMutableInputList()416 OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
417 OpMutableInputList& operator=(const OpMutableInputList& other) = default;
418 Tensor at(int i, bool lock_held);
419 mutex* ref_mutex(int i);
size()420 int size() const { return stop_ - start_; }
begin()421 Iterator begin() const { return Iterator(this, 0); }
end()422 Iterator end() const { return Iterator(this, size()); }
423
424 private:
425 OpKernelContext* ctx_; // not owned
426 int start_;
427 int stop_;
428 };
429
430 // Utility class for representing a list of output tensors that are
431 // grouped as a single named output.
432 class OpOutputList {
433 public:
434 typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
OpOutputList()435 OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpOutputList(OpKernelContext * ctx,int start,int stop)436 OpOutputList(OpKernelContext* ctx, int start, int stop)
437 : ctx_(ctx), start_(start), stop_(stop) {}
438 OpOutputList& operator=(const OpOutputList& other) = default;
439 Tensor* operator[](int i);
440 bool required(int i) const;
441 DataType expected_output_dtype(int i) const;
442 Status allocate(int i, const TensorShape& shape, Tensor** output);
443 void set(int i, const Tensor& tensor);
444 void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
size()445 int size() const { return stop_ - start_; }
begin()446 Iterator begin() const { return Iterator(this, 0); }
end()447 Iterator end() const { return Iterator(this, size()); }
448
449 private:
450 OpKernelContext* ctx_; // not owned
451 int start_;
452 int stop_;
453 };
454
// Holds a tensor or tensor reference. For tensor references, we need
// a mutex to prevent concurrent access to the tensor.
struct TensorValue {
  // Empty value: neither a tensor nor a ref.
  TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
  // Non-ref tensor. Implicit conversion is intentional (see NOLINT).
  TensorValue(Tensor* t)  // NOLINT(runtime/explicit)
      : mutex_if_ref(nullptr), tensor(t) {}
  // Ref tensor guarded by `mu`.
  TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
  Tensor* operator->() const { return tensor; }
  bool is_ref() const { return mutex_if_ref != nullptr; }

  // Neither pointer is owned by this struct.
  mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
  Tensor* tensor;
};
468
469 class OpKernelContext {
470 public:
471 // The first element of a WrappedAllocator is a "base" Allocator and
472 // the second element is that Allocator wrapped by a
473 // TrackingAllocator
474 typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
475
476 // TODO(zhifengc): Do some cleanup of Params.
477 // The Params struct is passed in to initialize an OpKernelContext,
478 // and must outlive the OpKernelContext.
  struct Params {
    // eigen_gpu_device is the only member owned by Params (see below).
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64 step_id = 0;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    // Lazily creates eigen_gpu_device from this->device (see comment
    // above). Requires `device` to be set.
    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    // Execution-bookkeeping flags consumed by OpKernelContext; exact
    // semantics are defined by the OpKernelContext implementation.
    bool track_allocations = false;
    bool log_memory = false;
    bool record_tensor_accesses = false;

    // Array indexed by output number for this node
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container..
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    Rendezvous* rendezvous = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    bool is_input_dead = false;

    // Per-input allocator attributes, parallel to *inputs.
    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
        nullptr;

    // Device contexts.
    const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
        nullptr;
    DeviceContext* op_device_context = nullptr;

    // Control-flow op supports.
    FrameAndIter frame_iter;

    // Function call supports.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollector* stats_collector = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
  };
569
570 // params must outlive the OpKernelContext.
571 explicit OpKernelContext(Params* params);
572 OpKernelContext(Params* params, int noutputs);
573 ~OpKernelContext();
574
env()575 Env* env() const { return params_->device->env(); }
576
step_id()577 int64 step_id() const { return params_->step_id; }
578
op_kernel()579 const OpKernel& op_kernel() const { return *params_->op_kernel; }
580
581 // Input/output signature.
582
num_inputs()583 int num_inputs() const { return params_->inputs->size(); }
584 DataType input_dtype(int index) const;
585 Status input_dtype(StringPiece name, DataType* dtype) const;
586 MemoryType input_memory_type(int index) const;
587
num_outputs()588 int num_outputs() const { return outputs_.size(); }
589 DataType expected_output_dtype(int index) const;
590 MemoryType output_memory_type(int index) const;
591
592 // Input
593
594 // Returns an immutable input tensor. May only be used for non-Ref
595 // inputs. For Ref inputs use mutable_input below.
596 // REQUIRES: !IsRefType(input_dtype(index))
597 // TODO(mrry): Convert this to return Status.
598 const Tensor& input(int index);
599
600 // Returns the named immutable input tensor in "tensor", as defined
601 // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
602 // use mutable_input below.
603 // REQUIRES: !IsRefType(input_dtype(index))
604 // REQUIRES: the named input must not be a list.
605 Status input(StringPiece name, const Tensor** tensor);
606
607 // Returns the named list-valued immutable input in "list", as
608 // defined in the OpDef. If the named output is not list-valued,
609 // returns a one-element list. May only be used for non-Ref
610 // inputs. For Ref inputs use mutable_input below.
611 // REQUIRES: !IsRefType(input_dtype(index))
612 Status input_list(StringPiece name, OpInputList* list);
613
614 // For mutable inputs, use the following together to make sure there
615 // is no concurrent access to mutable_input(), e.g.:
616 // {
617 // Tensor& t = context->mutable_input(index);
618 // mutex_lock lock(*context->input_ref_mutex(index));
619 // // modify the values in t
620 // }
621 // REQUIRES: IsRefType(input_dtype(index))
622 Status input_ref_mutex(StringPiece name, mutex** out_mutex);
623
624 // Returns a mutable input tensor. Must be used to access Ref
625 // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
626 // modify the values stored in the Tensor buffer, and modifications
627 // will be visible to other Ops reading the same ref tensor. If
628 // !lock_held the input mutex will be acquired before returning the
629 // Tensor.
630 // TODO(mrry): Convert this to return Status.
631 Tensor mutable_input(int index, bool lock_held);
632
633 // Returns the named mutable input tensor in "tensor", as defined in
634 // the OpDef. Must be used to access Ref inputs. The values stored
635 // in the Tensor buffer may be modified, and modifications will be
636 // visible to other Ops reading the same ref tensor. If !lock_held
637 // the input mutex will be acquired before returning the Tensor.
638 // REQUIRES: the named input must not be a list.
639 // REQUIRES: the named input must be a ref tensor.
640 Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
641
642 // Returns the named list-valued mutable input in "list", as defined
643 // in the OpDef. If the named input is not list-valued, returns a
644 // one-element list. Must be used to access Ref inputs. The values
645 // stored in the Tensor buffer may be modified, and modifications
646 // will be visible to other Ops reading the same ref tensor.
647 // REQUIRES: the named input must be a ref tensor.
648 Status mutable_input_list(StringPiece name, OpMutableInputList* list);
649
650 // Replace the corresponding Ref Input to use the storage buffer
651 // used by tensor. If !lock_held the input mutex will be acquired
652 // before returning the Tensor.
653 // REQUIRES: IsRefType(input_dtype(index)).
654 void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
655
656 // Replace the corresponding named Ref Input to use the storage
657 // buffer used by tensor. If !lock_held the input mutex will be
658 // acquired before returning the Tensor.
659 // REQUIRES: IsRefType(input_dtype(index)).
660 Status replace_ref_input(StringPiece name, const Tensor& tensor,
661 bool lock_held);
662
663 // Deletes the Tensor object used as the Ref Input at
664 // input_index. This is not usually necessary and should be used
665 // with caution. If !lock_held the input mutex will be acquired
666 // before returning the Tensor.
667 // REQUIRES: IsRefType(input_dtype(input_index)).
668 void delete_ref_input(int input_index, bool lock_held);
669
670 // Return true if there is input at the given index. An operator has no
671 // input at index if its tensor is null. This is primarily used by the
672 // merge operator.
673 // TODO(mrry): Convert this to return Status.
674 bool has_input(int index) const;
675
676 // Returns true if all inputs are the same shape, otherwise sets the
677 // status to a non-OK value and returns false.
678 // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
679 bool ValidateInputsAreSameShape(OpKernel* op);
680
681 // Input to output forwarding.
682
683 // Set the output Ref Tensor at output_index to be an alias of the
684 // input Ref Tensor at input_index.
685 // REQUIRES: IsRefType(input_dtype(input_index)).
686 // REQUIRES: IsRefType(output_dtype(output_index)).
687 void forward_ref_input_to_ref_output(int input_index, int output_index);
688
689 // Returns true when an alias to input[input_index], reshaped to output_shape,
690 // which is safe to use for in-place computation was written to *output.
691 // Returns false if input[input_index] has a refcount greater than one, or if
692 // its type does not match the expected output type of output[output_index],
693 // or the number of elements in input[input_index] does not equal the number
694 // of elements in output_shape.
695 bool forward_input_to_output_with_shape(int input_index, int output_index,
696 const TensorShape& output_shape,
697 Tensor** output) TF_MUST_USE_RESULT;
698 Status forward_input_to_output_with_shape(StringPiece input_name,
699 StringPiece output_name,
700 const TensorShape& output_shape,
701 Tensor** output) TF_MUST_USE_RESULT;
702
703 // Returns a pointer to a Tensor aliasing the underlying buffer backing
704 // input[input_index] iff
705 // * input[input_index] is not a ref,
706 // * the data type, shape, memory type, and allocator attributes of
707 // input[input_index] are compatible with those given in dtype, shape,
708 // memory_type, and attr,
709 // * refcount on the underlying buffer is one.
710 // Otherwise returns nullptr.
711 // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
712 // forwarding is only safe if there are no reads via __ldg() after writes
713 // to the same address.
714 std::unique_ptr<Tensor> forward_input(
715 int input_index, DataType dtype, const TensorShape& shape,
716 MemoryType memory_type,
717 const AllocatorAttributes& attr) TF_MUST_USE_RESULT;
718
719 // Tries to forward one of the inputs given in input_indices to
720 // output[output_index]. If none of the given inputs can be forwarded, calls
721 // allocate_output() to allocate a new output buffer.
722 Status forward_input_or_allocate_output(
723 gtl::ArraySlice<int> candidate_input_indices, int output_index,
724 const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
725 Status forward_input_or_allocate_output(
726 gtl::ArraySlice<StringPiece> candidate_input_names,
727 StringPiece output_name, const TensorShape& output_shape,
728 Tensor** output) TF_MUST_USE_RESULT;
729
730 // Tries to reuse one of the inputs given in input_indices as a temporary.
731 // If none of the given inputs can be forwarded, calls
732 // allocate_temp() to allocate a new temporary buffer.
733 Status forward_input_or_allocate_temp(
734 gtl::ArraySlice<int> candidate_input_indices, DataType type,
735 const TensorShape& shape, const AllocatorAttributes& allocator_attr,
736 Tensor* out_temp) TF_MUST_USE_RESULT;
737
forward_input_or_allocate_temp(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)738 Status forward_input_or_allocate_temp(
739 gtl::ArraySlice<int> candidate_input_indices, DataType type,
740 const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
741 return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
742 AllocatorAttributes(), out_temp);
743 }
744
745 // Output
746
747 // Returns the named list-valued output in "list", as defined in the OpDef.
748 // If the named output is not list-valued, returns a one-element list.
749 Status output_list(StringPiece name, OpOutputList* list);
750
751 // If output_required(index) returns true, the OpKernel's Compute() method
752 // should call allocate_output(index, ...), set_output(index, ...),
753 // set_output_ref(index, ...), or set the status to a non-ok value.
754 // If it returns false, it may output, but is not required to do so.
755 // TODO(mrry): Convert this to return Status, and implement a string
756 // name version.
  // Currently always reports that every output is required; per-output
  // requirements are not implemented yet (see TODOs above).
  bool output_required(int index) const {
    return true;  // TODO(josh11b): implement
  }
760
761 // Allocation of tensors during kernel execution inside the Compute
762 // method:
763 //
764 // There are three methods to allocate Tensors when an Op kernel
765 // executes.
766 //
767 // 1) allocate_persistent. This is only needed for Tensors that will
768 // be stored by the Op between invocations, and it *must* be used
769 // for those Tensors. The call returns a PersistentTensor, and that
770 // is the only object the Op is allowed to hold on to between
771 // invocations. When the Tensor is needed in a subsequent
772 // invocation, it can be retrieved from the PersistentTensor using
773 // the AccessTensor method. This ensures that the system is made
774 // aware of any use of the tensor's allocated memory, which is
775 // needed for correctness on asynchronous devices such as GPUs.
776 //
777 // 2) allocate_output. This should be used to allocate any tensor
778 // that is going to be used as an output from the Op at the end of
779 // the current execution. The caller indicates which output the
780 // Tensor will be assigned to, and the call returns the
781 // newly-allocated Tensor. The Tensor can subsequently be assigned
782 // to during kernel execution, and will be used as the designated
783 // output when the kernel execution completes.
784 //
785 // 3) allocate_temp. This should be used to allocate any scratch
786 // storage that is needed while the kernel is executing, and will
787 // not be retained by the Op.
788 //
789 // In some cases a Tensor needs to be used as an output even though
790 // it was previously allocated elsewhere. The Tensor may have been
791 // passed as an input, or stored in a PersistentTensor during a
792 // previous kernel execution, or allocated earlier in the kernel
793 // execution at a time when it was not known which output it would
794 // be assigned to. In this case the kernel can use set_output or
795 // set_output_ref to indicate that the tensor should be used as the
796 // designated output. It is legal to use any previously-allocated
797 // Tensor as an argument to set_output or set_output_ref, including
798 // Tensors allocated via allocate_temp. There may be a performance
799 // penalty to using a Tensor that was not allocated using
800 // allocate_output. This is because allocate_output uses the
801 // AllocatorAttributes stored in output_attr_array for the
802 // designated output. In some cases, using the wrong attributes may
803 // cause an extra copy of the Tensor's buffer.
804
805 // Allocates output for the specified output index with shape.
806 // OpKernelContext retains ownership of the returned pointer. See
807 // comment above.
808 //
809 // If memory allocation fails, returns an error status.
810 //
811 // REQUIRES: !IsRefType(expected_output_dtype(index))
812 Status allocate_output(int index, const TensorShape& shape,
813 Tensor** tensor) TF_MUST_USE_RESULT;
814 Status allocate_output(StringPiece name, const TensorShape& shape,
815 Tensor** tensor) TF_MUST_USE_RESULT;
816 // The following methods use the supplied attributes instead of
817 // those in output_attr_array. The caller is responsible for
818 // ensuring that the attributes are "compatible" with the
819 // output_attr_array, e.g. the tensor is allocated on the correct
820 // device. See comment above.
821 Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
822 AllocatorAttributes attr) TF_MUST_USE_RESULT;
823 Status allocate_output(StringPiece name, const TensorShape& shape,
824 Tensor** tensor,
825 AllocatorAttributes attr) TF_MUST_USE_RESULT;
826
827 // Allocates a temporary Tensor of the specified type and
828 // shape. Devices such as GPUs that enqueue Ops for lazy execution
829 // may retain references to the temporary tensors after the Op's
830 // Compute method has run. See comment above.
831 Status allocate_temp(DataType type, const TensorShape& shape,
832 Tensor* out_temp, AllocatorAttributes allocator_attr,
833 const AllocationAttributes& allocation_attr);
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp,AllocatorAttributes allocator_attr)834 Status allocate_temp(DataType type, const TensorShape& shape,
835 Tensor* out_temp, AllocatorAttributes allocator_attr) {
836 return allocate_temp(type, shape, out_temp, allocator_attr,
837 AllocationAttributes());
838 }
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp)839 Status allocate_temp(DataType type, const TensorShape& shape,
840 Tensor* out_temp) {
841 return allocate_temp(type, shape, out_temp, AllocatorAttributes());
842 }
843
844 // Allocates a Tensor of the specified type and shape which the Op
845 // plans to maintain as persistent state. out_persistent holds the
846 // PersistentTensor which is the object the caller should store. For
847 // convenience, if out_tensor is non-null then it will be filled in
848 // with a Tensor* pointing to the newly-allocated tensor which the
849 // caller can use instead of calling
850 // out_persistent->AccessTensor. The caller does not own out_tensor
851 // and should not keep a copy of it. See comment above.
852 Status allocate_persistent(DataType type, const TensorShape& shape,
853 PersistentTensor* out_persistent,
854 Tensor** out_tensor, AllocatorAttributes attr);
allocate_persistent(DataType type,const TensorShape & shape,PersistentTensor * out_persistent,Tensor ** out_tensor)855 Status allocate_persistent(DataType type, const TensorShape& shape,
856 PersistentTensor* out_persistent,
857 Tensor** out_tensor) {
858 return allocate_persistent(type, shape, out_persistent, out_tensor,
859 AllocatorAttributes());
860 }
861
862 // Copies a tensor (allocated by the caller) to the specified output
863 // index. REQUIRES: !IsRefType(expected_output_dtype(index))
864 // REQUIRES: 'tensor' must have the same MemoryType as
865 // output_memory_types[index]. See comment above.
866 Status set_output(StringPiece name, const Tensor& tensor);
867
868 // To output a reference. Caller retains ownership of mu and tensor_for_ref,
869 // and they must outlive all uses within the step. See comment above.
870 // REQUIRES: IsRefType(expected_output_dtype(index))
871 Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
872
873 // Returns nullptr if allocate_output() or set_output() have not been called.
874 Status mutable_output(StringPiece name, Tensor** tensor);
875
876 // Transfers ownership of an output tensor to the caller.
877 // NOTE: For non-reference outputs, the caller takes responsibility
878 // for deletion. For reference outputs, the caller does NOT take
879 // responsibility for deletion.
880 Status release_output(StringPiece name, TensorValue* value);
881
882 // Records device specific state about how the input tensors were
883 // computed.
884 //
885 // If using the templated function, the type must be a subclass
886 // of DeviceContext.
887 //
888 // Get the DeviceContext used for the index input. Returns nullptr
889 // if no DeviceContext was provided.
890 template <typename T>
891 T* input_device_context(int index);
892 DeviceContext* input_device_context(int index);
893
894 // Return the DeviceContext that should be used for this Op.
895 //
896 // If using the templated function, the type must be a subclass
897 // of DeviceContext.
898 //
899 // Returns nullptr if the device did not provide one.
900 template <typename T>
901 T* op_device_context();
op_device_context()902 DeviceContext* op_device_context() {
903 DeviceContext* ret = params_->op_device_context;
904 if (ret == nullptr) {
905 auto* dev_info = device()->tensorflow_gpu_device_info();
906 if (dev_info) ret = dev_info->default_context;
907 }
908 return ret;
909 }
910
input_alloc_attr(int index)911 AllocatorAttributes input_alloc_attr(int index) const {
912 if (params_->input_alloc_attrs == nullptr) {
913 return AllocatorAttributes();
914 } else {
915 DCHECK_GE(index, 0);
916 DCHECK_LT(index, params_->input_alloc_attrs->size());
917 return (*params_->input_alloc_attrs)[index];
918 }
919 }
920
  // Returns the allocator attributes configured for output `index` — the
  // same attributes allocate_output(index, ...) uses (see comment above on
  // output_attr_array).
  AllocatorAttributes output_alloc_attr(int index) const {
    return params_->output_attr_array[index];
  }
924
wrapped_allocators()925 gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const {
926 mutex_lock lock(mu_);
927 gtl::InlinedVector<WrappedAllocator, 4> retrieved = wrapped_allocators_;
928 return retrieved;
929 }
930
  // Communication.
  //
  // An op kernel communicates with outside environment through
  // Rendezvous Send() and Recv().
  Rendezvous* rendezvous() const { return params_->rendezvous; }

  // An op kernel can access the session state it belongs to.
  SessionState* session_state() const { return params_->session_state; }

  // An op kernel can access the tensor store of the run it belongs to.
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel invoke functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  // Closure runner supplied in params_, if any (used to schedule work).
  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }

  // Step-stats collector supplied in params_, if any.
  StepStatsCollector* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  // Shared cache wrapper for tensor-slice (checkpoint) readers.
  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }
968
  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
#ifdef TENSORFLOW_USE_SYCL
  const Eigen::SyclDevice& eigen_sycl_device() const {
    return *device()->eigen_sycl_device();
  }
#endif
  // Typed Eigen-device accessor; specializations are defined elsewhere.
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;

  // Error handling.

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }
  bool* is_output_dead() { return &is_output_dead_; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Retrieve list of referenced tensors in out_vector. Once this is
  // called, it is not legal to reference any more tensors. Should
  // not be called from Op kernels.
  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);

  // Per-step container for use by white-listed internal ops.
  ScopedStepContainer* step_container() const {
    return params_->step_container;
  }
1029
  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.
  //
  // The following functions all have versions that return Status
  // to capture error conditions, and are strongly preferred.
  Tensor* mutable_output(int index);
  void set_output(int index, const Tensor& tensor);
  mutex* input_ref_mutex(int index);
  void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
  TensorValue release_output(int index);

  // True iff allocation tracking was requested for this step
  // (params_->track_allocations).
  bool track_allocations() const { return params_->track_allocations; }

  // Records temp memory allocation. Tensor object is recorded to identify the
  // case where temp memory is used as output memory.
  void record_temp_memory_allocation(int64 size, const Tensor& t)
      LOCKS_EXCLUDED(stats_mu_);

  // Returns recorded size of temporary memory.
  int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);

  // Records persistent memory allocation; size can be negative, indicating
  // deallocation.
  void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
      LOCKS_EXCLUDED(stats_mu_);

  // Returns recorded size and ids of persistent memory.
  int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);

  std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_);

  // Resets counters for temp and persistent memory and recorded ids.
  void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_);

  bool input_is_ref(int index) const;
1072
1073 private:
1074 Allocator* get_allocator(AllocatorAttributes attr);
1075
1076 // Internal method to add a tensor's buffer to the list of buffers
1077 // referenced during the execution of the Op, so that GPUs may
1078 // accurately track the memory that may not be reused until the Op
1079 // execution completes.
1080 void record_tensor_reference(const Tensor& tensor);
1081 void really_record_tensor_reference(const Tensor& tensor);
1082
1083 // Internal common method used when allocating tensor memory
allocate_tensor(DataType type,const TensorShape & shape,Tensor * out_tensor,AllocatorAttributes allocator_attr)1084 Status allocate_tensor(DataType type, const TensorShape& shape,
1085 Tensor* out_tensor,
1086 AllocatorAttributes allocator_attr) {
1087 return allocate_tensor(type, shape, out_tensor, allocator_attr,
1088 AllocationAttributes());
1089 }
1090
1091 Status allocate_tensor(DataType type, const TensorShape& shape,
1092 Tensor* out_tensor, AllocatorAttributes allocator_attr,
1093 const AllocationAttributes& allocation_attr);
1094
1095 // This is called by PersistentTensor::AccessTensor whenever the
1096 // wrapped tensor is retrieved, to ensure the runtime knows that the
1097 // Tensor is being accessed within an Op. This is necessary for
1098 // memory safety of devices like GPUs that queue Ops for
1099 // asynchronous execution after the Compute() method completes.
1100 friend class PersistentTensor;
1101 void NotifyUseOfPersistentTensor(const Tensor& tensor);
1102
1103 Status status_;
1104 Params* params_; // not owned
1105 mutable mutex mu_; // mutable so const accessors can acquire the lock
1106 gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
1107 gtl::InlinedVector<TensorValue, 4> outputs_;
1108
1109 // Constructed only if <params->record_tensor_accesses>.
1110 ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
1111
1112 bool is_output_dead_ = false;
1113
1114 // The following data members are only used when allocation tracking is
1115 // enabled.
1116 mutable mutex stats_mu_;
1117 int64 temp_memory_allocated_ GUARDED_BY(stats_mu_);
1118 int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_);
1119 std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>>
1120 temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_);
1121 std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_
1122 GUARDED_BY(stats_mu_);
1123
1124 TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
1125 };
1126
1127 // Register your OpKernel by specifying the Op's name, the device the
1128 // kernel runs on, any type attr constraints for this kernel, any
1129 // host-memory args, and the class to instantiate. Examples:
1130 //
1131 // // A kernel that supports all types.
1132 // REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
1133 //
1134 // // The following are equivalent ways of specifying that the kernel only
1135 // // works if the "T" type attr is set to DT_FLOAT.
1136 // REGISTER_KERNEL_BUILDER(
1137 // Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1138 // SubOp<float>);
1139 // // (You would then repeat this for every type supported by "Sub".)
1140 //
1141 // // This form allows you to specify a list of types as the constraint.
1142 // REGISTER_KERNEL_BUILDER(Name("Sub")
1143 // .Device(DEVICE_CPU)
1144 // .TypeConstraint("T", {DT_FLOAT}),
1145 // SubOp<float>);
1146 //
1147 // // A kernel that expects one of the input tensors in host memory.
1148 // REGISTER_KERNEL_BUILDER(
1149 // Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
1150 //
1151 // See kernel_def_builder for details.
1152
1153 // Instantiate an OpKernel that has been registered. Returns nullptr
1154 // if no operation for that type of device / input signature combination
1155 // (and a NOT_FOUND *status), or there is an error in construction (and
1156 // an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership
1157 // of the returned pointer.
1158 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
1159 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1160 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
1161 DeviceBase* device,
1162 Allocator* allocator,
1163 const NodeDef& def,
1164 int graph_def_version, Status* status);
1165 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1166 Allocator* allocator, FunctionLibraryRuntime* flib,
1167 const NodeDef& def, int graph_def_version,
1168 OpKernel** kernel);
1169
1170 // Returns into 'device_types' the subset of prioritized_types that this
1171 // binary has registered for the given NodeDef.
1172 //
1173 // REQUIRES: * 'device_types' is not nullptr.
1174 // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1175 Status SupportedDeviceTypesForNode(
1176 const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1177 DeviceTypeVector* device_types);
1178
1179 // Returns a message with a description of the kernels registered for op
1180 // `op_name`.
1181 string KernelsRegisteredForOp(StringPiece op_name);
1182
1183 // Call once after Op registration has completed.
1184 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
1185
1186 // -----------------------------------------------------------------------------
1187 // OpKernel registration implementation follows, please ignore.
1188
1189 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
1190 namespace register_kernel {
1191
// Builder used as REGISTER_KERNEL_BUILDER(Name("op_name")...); under
// selective registration, kernels for unused ops are renamed so that their
// registration is skipped.
class Name : public KernelDefBuilder {
 public:
  // With selective registration, kernels whose implementation class is not used
  // by any kernel are disabled with the SHOULD_REGISTER_OP_KERNEL call in
  // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
  // implementation class with a used kernel would get through that mechanism.
  //
  // This mechanism stops that registration by changing the name of the kernel
  // for the unused op to one that is ignored by
  // OpKernelRegistrar::InitInternal. Note that this method alone is
  // not sufficient - the compiler can't evaluate the entire KernelDefBuilder at
  // compilation time, so this method doesn't actually reduce code size.
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};
1207
1208 namespace system {
1209
// Like register_kernel::Name above, but never disabled by selective
// registration.
class Name : public KernelDefBuilder {
 public:
  // For system kernels, we ignore selective registration and
  // unconditionally register the kernel.
  explicit Name(const char* op) : KernelDefBuilder(op) {}
};
1216
1217 } // namespace system
1218
1219 } // namespace register_kernel
1220
1221 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
1222 REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
1223
1224 #define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
1225 REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
1226
1227 #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
1228 constexpr bool should_register_##ctr##__flag = \
1229 SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \
1230 static ::tensorflow::kernel_factory::OpKernelRegistrar \
1231 registrar__body__##ctr##__object( \
1232 should_register_##ctr##__flag \
1233 ? ::tensorflow::register_kernel::kernel_builder.Build() \
1234 : nullptr, \
1235 #__VA_ARGS__, \
1236 [](::tensorflow::OpKernelConstruction* context) \
1237 -> ::tensorflow::OpKernel* { \
1238 return new __VA_ARGS__(context); \
1239 });
1240
1241 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
1242 // `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
1243 // unconditionally even when selective registration is used.
1244 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
1245 REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \
1246 __VA_ARGS__)
1247
1248 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
1249 REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
1250
1251 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
1252 static ::tensorflow::kernel_factory::OpKernelRegistrar \
1253 registrar__body__##ctr##__object( \
1254 ::tensorflow::register_kernel::system::kernel_builder.Build(), \
1255 #__VA_ARGS__, \
1256 [](::tensorflow::OpKernelConstruction* context) \
1257 -> ::tensorflow::OpKernel* { \
1258 return new __VA_ARGS__(context); \
1259 });
1260
1261 void* GlobalKernelRegistry();
1262
1263 // If node_def has a corresponding kernel registered on device_type,
1264 // returns OK and fill in the kernel def and kernel_class_name. <def> and
1265 // <kernel_class_name> may be null.
1266 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1267 const KernelDef** def, string* kernel_class_name);
1268
1269 // Writes a list of all registered kernels to LOG(INFO), to help users debug
1270 // missing kernel errors.
1271 void LogAllRegisteredKernels();
1272
1273 namespace kernel_factory {
1274
1275 class OpKernelRegistrar {
1276 public:
1277 typedef OpKernel* (*Factory)(OpKernelConstruction*);
1278
OpKernelRegistrar(const KernelDef * kernel_def,StringPiece kernel_class_name,Factory factory)1279 OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1280 Factory factory) {
1281 // Perform the check in the header to allow compile-time optimization
1282 // to a no-op, allowing the linker to remove the kernel symbols.
1283 if (kernel_def != nullptr) {
1284 InitInternal(kernel_def, kernel_class_name, factory);
1285 }
1286 }
1287
1288 private:
1289 void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
1290 Factory factory);
1291 };
1292
1293 } // namespace kernel_factory
1294
1295 // -----------------------------------------------------------------------------
1296 // Template and inline method implementations, please ignore
1297
// Looks up attr `attr_name` on this kernel's NodeDef via GetNodeAttr(),
// storing the result in *value; returns GetNodeAttr's status.
template <class T>
Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(def(), attr_name, value);
}
1302
input_dtype(int index)1303 inline DataType OpKernelContext::input_dtype(int index) const {
1304 DCHECK_GE(index, 0);
1305 DCHECK_LT(index, num_inputs());
1306 const TensorValue& value((*params_->inputs)[index]);
1307 if (value.is_ref()) {
1308 return MakeRefType(value->dtype());
1309 } else {
1310 return value->dtype();
1311 }
1312 }
1313
// Returns the MemoryType registered for input `index` on this kernel.
inline MemoryType OpKernelContext::input_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return op_kernel().input_memory_types()[index];
}

// Returns the declared dtype of output `index`.
inline DataType OpKernelContext::expected_output_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return params_->op_kernel->output_type(index);
}

// Returns the MemoryType registered for output `index` on this kernel.
inline MemoryType OpKernelContext::output_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return op_kernel().output_memory_types()[index];
}
1331
input_is_ref(int index)1332 inline bool OpKernelContext::input_is_ref(int index) const {
1333 const TensorValue& value((*params_->inputs)[index]);
1334 return value.is_ref();
1335 }
1336
// Records that this op accessed `tensor`'s buffer, but only when the
// executor requested access tracking (params_->record_tensor_accesses).
inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
  // The device's requirement and the per-context flag must agree.
  DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
            params_->record_tensor_accesses);
  if (params_->record_tensor_accesses) {
    really_record_tensor_reference(tensor);
  }
}

// Freezes the recorded tensor-reference set and returns it in *out_vector.
// A no-op (out_vector untouched) when access tracking is disabled.
inline void OpKernelContext::retrieve_accessed_tensors(
    TensorReferenceVector* out_vector) {
  if (params_->record_tensor_accesses) {
    mutex_lock l(mu_);
    referenced_tensors_->FreezeAndReturnReferences(out_vector);
  }
}
1352
// no input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return (*params_->inputs)[index].tensor != nullptr;
}

// Returns the mutex guarding ref input `index`.
// REQUIRES: input_is_ref(index) (DCHECK'd).
inline mutex* OpKernelContext::input_ref_mutex(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  return (*params_->inputs)[index].mutex_if_ref;
}
1366
NotifyUseOfPersistentTensor(const Tensor & t)1367 inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
1368 if (t.IsInitialized()) {
1369 record_tensor_reference(t);
1370 }
1371 }
1372
// Returns the tensor bound to output `index`, or nullptr if neither
// allocate_output() nor set_output() has been called for it yet.
inline Tensor* OpKernelContext::mutable_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  // No need to record_tensor_reference since the output must already
  // have been set by a call that did so.
  return outputs_[index].tensor;
}
1380
release_output(int index)1381 inline TensorValue OpKernelContext::release_output(int index) {
1382 DCHECK_GE(index, 0);
1383 DCHECK_LT(index, num_outputs());
1384 TensorValue value = outputs_[index];
1385 outputs_[index] = TensorValue();
1386 return value;
1387 }
1388
forward_input_or_allocate_output(gtl::ArraySlice<int> candidate_input_indices,int output_index,const TensorShape & output_shape,Tensor ** output)1389 inline Status OpKernelContext::forward_input_or_allocate_output(
1390 gtl::ArraySlice<int> candidate_input_indices, int output_index,
1391 const TensorShape& output_shape, Tensor** output) {
1392 for (int input_index : candidate_input_indices) {
1393 if (forward_input_to_output_with_shape(input_index, output_index,
1394 output_shape, output)) {
1395 return Status::OK();
1396 }
1397 }
1398 return allocate_output(output_index, output_shape, output);
1399 }
1400
forward_input_or_allocate_output(gtl::ArraySlice<StringPiece> candidate_input_names,StringPiece output_name,const TensorShape & output_shape,Tensor ** output)1401 inline Status OpKernelContext::forward_input_or_allocate_output(
1402 gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
1403 const TensorShape& output_shape, Tensor** output) {
1404 for (const StringPiece& input_name : candidate_input_names) {
1405 if (forward_input_to_output_with_shape(input_name, output_name,
1406 output_shape, output)
1407 .ok()) {
1408 return Status::OK();
1409 }
1410 }
1411 return allocate_output(output_name, output_shape, output);
1412 }
1413
// Typed variant of op_device_context(); T must derive from DeviceContext.
template <typename T>
T* OpKernelContext::op_device_context() {
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>(op_device_context());
}

// Typed variant of input_device_context(); T must derive from DeviceContext.
template <typename T>
T* OpKernelContext::input_device_context(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, params_->input_device_contexts->size());
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>((*params_->input_device_contexts)[index]);
}

// Returns the DeviceContext recorded for input `index` (nullptr if none was
// provided — see the comment on the declaration above).
inline DeviceContext* OpKernelContext::input_device_context(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, params_->input_device_contexts->size());
  return (*params_->input_device_contexts)[index];
}
1435
// Bounds-checked access to the i-th input of the list, translated to the
// context-wide input index (start_ + i).
inline const Tensor& OpInputList::operator[](int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input(start_ + i);
}

// Ref-mutex of the i-th mutable input in the list.
inline mutex* OpMutableInputList::ref_mutex(int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input_ref_mutex(start_ + i);
}

// The i-th mutable input in the list; `lock_held` is forwarded to
// OpKernelContext::mutable_input.
inline Tensor OpMutableInputList::at(int i, bool lock_held) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_input(start_ + i, lock_held);
}

// The i-th output of the list (nullptr if it has not been set yet).
inline Tensor* OpOutputList::operator[](int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_output(start_ + i);
}

// Whether the kernel must produce the i-th output of the list (see
// OpKernelContext::output_required).
inline bool OpOutputList::required(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->output_required(start_ + i);
}

// Declared dtype of the i-th output of the list.
inline DataType OpOutputList::expected_output_dtype(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->expected_output_dtype(start_ + i);
}

// Allocates the i-th output of the list with `shape`.
inline Status OpOutputList::allocate(int i, const TensorShape& shape,
                                     Tensor** output) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->allocate_output(start_ + i, shape, output);
}

// Sets the i-th output of the list to `tensor`.
inline void OpOutputList::set(int i, const Tensor& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, tensor);
}
1484
set_ref(int i,mutex * mu,Tensor * tensor_for_ref)1485 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
1486 DCHECK_GE(i, 0);
1487 DCHECK_LT(i, stop_ - start_);
1488 ctx_->set_output_ref(i, mu, tensor_for_ref);
1489 }
1490
// Convenience macros for asserting and handling exceptional conditions,
// analogous to the CHECK* macros provided by logging.h. On failure each
// macro records the error on the context (CtxFailure /
// CtxFailureWithWarning) and executes a bare `return;`, so these may
// only be used inside functions with a void return type — typically
// OpKernel::Compute or helpers called from it. The *_ASYNC variants
// additionally invoke CALLBACK (the kernel's done-callback) before
// returning, which AsyncOpKernels require on every exit path.
//
// Example use:
//  void Compute(OpKernelContext* context) {
//    OP_REQUIRES(context, context->num_inputs() == 2,
//                errors::InvalidArgument("FooOp requires 2 arguments"));
//    ...
//    Status status = SomeUncertainMethod();
//    OP_REQUIRES_OK(context, status);
//    ...
//  }

// If EXP is false, records STATUS on CTX and returns. EXP is wrapped in
// TF_PREDICT_TRUE: the failure branch is expected to be cold.
#define OP_REQUIRES(CTX, EXP, STATUS) \
  do { \
    if (!TF_PREDICT_TRUE(EXP)) { \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      return; \
    } \
  } while (0)

// If the Status expression is not OK, records it on CTX (with a logged
// warning) and returns. Variadic so that Status expressions containing
// unparenthesized commas still expand correctly.
#define OP_REQUIRES_OK(CTX, ...) \
  do { \
    ::tensorflow::Status _s(__VA_ARGS__); \
    if (!TF_PREDICT_TRUE(_s.ok())) { \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return; \
    } \
  } while (0)

// Async variant of OP_REQUIRES: on failure, also invokes CALLBACK
// before returning so the async kernel's completion is still signaled.
#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
  do { \
    if (!TF_PREDICT_TRUE(EXP)) { \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      (CALLBACK)(); \
      return; \
    } \
  } while (0)

// Async variant of OP_REQUIRES_OK: on a non-OK Status, also invokes
// CALLBACK before returning. NOTE: unlike OP_REQUIRES_OK this is not
// variadic, so a STATUS expression containing top-level commas must be
// parenthesized by the caller.
#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \
  do { \
    ::tensorflow::Status _s(STATUS); \
    if (!TF_PREDICT_TRUE(_s.ok())) { \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      (CALLBACK)(); \
      return; \
    } \
  } while (0)
1539
1540 } // namespace tensorflow
1541
1542 #endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
1543