1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
18 
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/framework/dataset.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/kernels/ops_util.h"
25 
26 namespace tensorflow {
27 namespace data {
28 
29 class IteratorResource;
30 
31 class IteratorHandleOp : public OpKernel {
32  public:
33   explicit IteratorHandleOp(OpKernelConstruction* ctx);
34 
35   // The resource is deleted from the resource manager only when it is private
36   // to kernel. Ideally the resource should be deleted when it is no longer held
37   // by anyone, but it would break backward compatibility.
38   ~IteratorHandleOp() override;
39 
40   void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_);
41 
42  private:
43   // During the first Compute(), resource is either created or looked up using
44   // shared_name. In the latter case, the resource found should be verified if
45   // it is compatible with this op's configuration. The verification may fail in
46   // cases such as two graphs asking queues of the same shared name to have
47   // inconsistent capacities.
48   Status VerifyResource(IteratorResource* resource);
49 
50   template <typename To, typename From>  // use like this: down_cast<T*>(foo);
down_cast(From * f)51   static inline To down_cast(From* f) {  // so we only accept pointers
52     static_assert(
53         (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
54         "target type not derived from source type");
55 
56     // We skip the assert and hence the dynamic_cast if RTTI is disabled.
57 #if !defined(__GNUC__) || defined(__GXX_RTTI)
58     // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
59     assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
60 #endif  // !defined(__GNUC__) || defined(__GXX_RTTI)
61     return static_cast<To>(f);
62   }
63 
64   FunctionLibraryRuntime* CreatePrivateFLR(
65       OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
66       std::unique_ptr<FunctionLibraryDefinition>* flib_def,
67       std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr);
68 
69   mutex mu_;
70   ContainerInfo cinfo_;  // Written once under mu_ then constant afterwards.
71   IteratorResource* resource_ GUARDED_BY(mu_) = nullptr;
72   DataTypeVector output_dtypes_;
73   std::vector<PartialTensorShape> output_shapes_;
74   const int graph_def_version_;
75   string name_;
76 };
77 
78 // Like IteratorHandleOp, but creates handles which are never shared, and does
79 // not hold a reference to these handles. The latter is important for eager
80 // execution, since OpKernel instances generally live as long as the program
81 // running them.
82 class AnonymousIteratorHandleOp : public OpKernel {
83  public:
84   explicit AnonymousIteratorHandleOp(OpKernelConstruction* context);
85 
86   void Compute(OpKernelContext* context) override;
87 
88  private:
89   // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp
90   // instances.
91   static mutex static_resource_lookup_mutex_;
92   // current_id_ is just a hint for creating unique names. If it turns out
93   // there's a collision (e.g. because another AnonymousIteratorHandleOp
94   // instance is generating handles) we'll just skip that id.
95   static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_);
96   DataTypeVector output_dtypes_;
97   std::vector<PartialTensorShape> output_shapes_;
98   const int graph_def_version_;
99 };
100 
101 class MakeIteratorOp : public OpKernel {
102  public:
MakeIteratorOp(OpKernelConstruction * ctx)103   explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
104 
105   void Compute(OpKernelContext* ctx) override;
106 };
107 
108 class IteratorGetNextOp : public AsyncOpKernel {
109  public:
IteratorGetNextOp(OpKernelConstruction * ctx)110   explicit IteratorGetNextOp(OpKernelConstruction* ctx)
111       : AsyncOpKernel(ctx),
112         background_worker_(ctx->env(), "tf_data_iterator_get_next") {}
113 
114   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
115 
116  private:
117   BackgroundWorker background_worker_;
118 };
119 
120 class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
121  public:
IteratorGetNextAsOptionalOp(OpKernelConstruction * ctx)122   explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
123       : AsyncOpKernel(ctx),
124         background_worker_(ctx->env(),
125                            "tf_data_iterator_get_next_as_optional") {
126     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
127     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
128   }
129 
130   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
131 
132  private:
133   BackgroundWorker background_worker_;
134   DataTypeVector output_types_;
135   std::vector<PartialTensorShape> output_shapes_;
136 };
137 
138 class IteratorGetNextSyncOp : public OpKernel {
139  public:
IteratorGetNextSyncOp(OpKernelConstruction * ctx)140   explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
141 
142   void Compute(OpKernelContext* ctx) override;
143 };
144 
145 class IteratorToStringHandleOp : public OpKernel {
146  public:
IteratorToStringHandleOp(OpKernelConstruction * ctx)147   explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
148       : OpKernel(ctx) {}
149 
150   void Compute(OpKernelContext* ctx) override;
151 };
152 
153 class IteratorFromStringHandleOp : public OpKernel {
154  public:
155   explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx);
156 
157   void Compute(OpKernelContext* ctx) override;
158 
159  private:
160   DataTypeVector output_dtypes_;
161   std::vector<PartialTensorShape> output_shapes_;
162 };
163 
164 }  // namespace data
165 }  // namespace tensorflow
166 
167 #endif  // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
168