1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
16 #define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/dataset.h"
22 #include "tensorflow/core/framework/function.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/gtl/array_slice.h"
26 #include "tensorflow/core/lib/random/random.h"
27 #include "tensorflow/core/platform/macros.h"
28 
29 namespace tensorflow {
30 
31 class Device;
32 class OpKernelContext;
33 class ResourceMgr;
34 
35 namespace data {
36 
37 class CapturedFunction;
38 
39 // An InstantiatedCapturedFunction encapsulates all the runtime support needed
40 // to execute a tensorflow function.
41 //
42 // While CapturedFunction (below) encapsulates the more permanent attributes
43 // of the function i.e. name, captured arguments etc.,
44 // InstantiatedCapturedFunction encapsulates the more runtime aspects i.e.
45 // FunctionLibraryRuntime, function handle etc.
46 //
47 // The `Iterator-`related classes use `InstantiatedCapturedFunction` to execute
48 // functions outside a the normal `OpKernel::Compute()` context.
49 class InstantiatedCapturedFunction {
50  public:
51   ~InstantiatedCapturedFunction();
52 
53   // Runs the "Instantiated Captured function". This method takes ownership of
54   // the tensors in `args`, in order to be able to deallocate them as early as
55   // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
56   // ownership of the `args`.
57   Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
58              std::vector<Tensor>* rets) const;
59 
60   // Synchronously runs the captured function on the given `args`, and stores
61   // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
62   // possible.
63   Status RunWithBorrowedArgs(IteratorContext* ctx,
64                              const std::vector<Tensor>& args,
65                              std::vector<Tensor>* rets) const;
66 
67   // Synchronously runs the captured function on the given `args`, and stores
68   // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
69   // possible. This can be useful for calling a captured
70   // function in cases where an `IteratorContext*` is not available
71   // (such as a destructor).
72   Status RunInstantiated(const std::vector<Tensor>& args,
73                          std::vector<Tensor>* rets);
74 
75   // Asynchronously runs the captured function on the given `args`, stores
76   // the results in `*rets`, and calls the given `done` callback when the
77   // function returns. This method takes ownership of the tensors in `args`,
78   // in order to be able to deallocate them as early as possible.
79   void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
80                 std::vector<Tensor>* rets,
81                 FunctionLibraryRuntime::DoneCallback done,
82                 const string& prefix) const;
83 
84   // Returns a step ID for use when running an `InstantiatedCapturedFunction`.
generate_step_id()85   static int64 generate_step_id() {
86     // Choose a step ID that is guaranteed not to clash with any
87     // Session-generated step ID. DirectSession only generates
88     // non-negative step IDs (contiguous, starting from 0), and
89     // MasterSession generates 56-bit random step IDs whose MSB is
90     // always 0, so a negative random step ID should suffice.
91     return -std::abs(static_cast<int64>(random::New64()));
92   }
93 
94  private:
95   InstantiatedCapturedFunction(
96       FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
97       DataTypeVector ret_types,
98       std::function<void(std::function<void()>)> runner,
99       CapturedFunction* captured_func);
100 
101   friend class CapturedFunction;
102 
103   FunctionLibraryRuntime* const lib_;
104   const FunctionLibraryRuntime::Handle f_handle_;
105   const DataTypeVector ret_types_;
106   std::function<void(std::function<void()>)> captured_runner_;
107   CapturedFunction* const captured_func_;
108 
109   TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
110 };
111 
112 // A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured"
113 // arguments that it closed over in the user program.
114 class CapturedFunction {
115  public:
116   // Creates a new instance using a list of named attributes, fetching captured
117   // inputs from a context argument.
118   static Status Create(const NameAttrList& func, OpKernelContext* ctx,
119                        const string& argument_name,
120                        std::unique_ptr<CapturedFunction>* out_function);
121 
122   // Creates a new instance using a list of named attributes, fetching captured
123   // inputs from a context argument.
124   //
125   // If `use_inter_op_parallelism` is false, the runtime may use an executor
126   // that is optimized for small functions.
127   static Status Create(const NameAttrList& func, OpKernelContext* ctx,
128                        const string& argument_name,
129                        bool use_inter_op_parallelism,
130                        std::unique_ptr<CapturedFunction>* out_function);
131 
132   // Creates a new instance using a list of named attributes, using provided
133   // captured inputs.
134   //
135   // If `use_inter_op_parallelism` is false, the runtime may use an executor
136   // that is optimized for small functions.
137   static Status Create(const NameAttrList& func, OpKernelContext* ctx,
138                        std::vector<Tensor>&& captured_inputs,
139                        bool use_inter_op_parallelism,
140                        std::unique_ptr<CapturedFunction>* out_function);
141 
142   // Instantiates this function for use in the given context, providing an
143   // InstantiatedCapturedFunction that can be used to execute functions.
144   Status Instantiate(IteratorContext* ctx,
145                      std::unique_ptr<InstantiatedCapturedFunction>*
146                          instantiated_captured_function);
147 
148   // Returns the named list of function arguments.
func()149   const NameAttrList& func() { return func_; }
150 
151   // Returns that additional captured inputs that will be passed to the function
captured_inputs()152   const std::vector<Tensor>& captured_inputs() { return captured_inputs_; }
153 
154  private:
155   CapturedFunction(const NameAttrList& func,
156                    std::vector<Tensor> captured_inputs,
157                    bool use_inter_op_parallelism);
158 
159   const NameAttrList func_;
160   const std::vector<Tensor> captured_inputs_;
161   const bool use_inter_op_parallelism_;
162 
163   TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
164 };
165 }  // namespace data
166 
167 // TODO(b/114112161): Remove these aliases when all users have moved over to the
168 // `tensorflow::data` namespace.
169 using data::CapturedFunction;
170 
171 }  // namespace tensorflow
172 
173 #endif  // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
174