1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
18 
19 #include <vector>
20 
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/core/platform/platform.h"
24 // clang-format on
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/types/optional.h"
28 #include "absl/types/variant.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/attr_value_util.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/selective_registration.h"
36 #include "tensorflow/core/framework/types.h"
37 #include "tensorflow/core/lib/gtl/flatmap.h"
38 #include "tensorflow/core/lib/hash/hash.h"
39 #include "tensorflow/core/lib/random/random.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/protobuf.h"
44 #include "tensorflow/core/protobuf/config.pb.h"
45 #if !defined(IS_MOBILE_PLATFORM)
46 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
47 #endif  // IS_MOBILE_PLATFORM
48 
49 namespace tensorflow {
50 
51 class CancellationManager;
52 class CollectiveExecutor;
53 class DeviceSet;
54 class Graph;
55 class GraphDef;
56 class OpKernel;
57 class ProcessFunctionLibraryRuntime;
58 class ResourceMgr;
59 class Rendezvous;
60 class ScopedStepContainer;
61 class StepStatsCollectorInterface;
62 class Node;
63 
64 // FunctionDefHelper::Create is a convenient helper to construct a
65 // FunctionDef proto.
66 // E.g.,
67 //   FunctionDef my_func = FunctionDefHelper::Create(
68 //     "my_func_name",
69 //     {"x:T", "y:T" /* one string per argument */},
70 //     {"z:T" /* one string per return value */},
71 //     {"T: {float, double}" /* one string per attribute  */},
72 //     {
73 //        {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
74 //        /* one entry per function node */
75 //     },
76 //     /* Mapping between function returns and function node outputs. */
77 //     {{"z", "o:z"}});
78 //
79 // For the old Function::Node approach, use FunctionDefHelper::Define()
80 // E.g.,
81 //   FunctionDef my_func = FunctionDefHelper::Define(
82 //     "my_func_name",
83 //     {"x:T", "y:T" /* one string per argument */},
84 //     {"z:T" /* one string per return value */},
85 //     {"T: {float, double}" /* one string per attribute  */},
86 //     {
87 //        {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
88 //        /* one entry per function node */
89 //     });
90 class FunctionDefHelper {
91  public:
92   // AttrValueWrapper has copy constructors for the type T so that
93   // it's easy to construct a simple AttrValue proto.
94   //
95   // If T is a string type (const char*, string, or StringPiece), and
  // it starts with "$", we construct an AttrValue of "placeholder".
97   //
98   // E.g.,
  //   std::pair<string, AttrValueWrapper> x = {"T", "$T"}
100   // is a named attr value placeholder.
101   struct AttrValueWrapper {
102     AttrValue proto;
103 
    AttrValueWrapper() {}
105 
106     template <typename T>
    AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
108       SetAttrValue(val, &proto);
109     }
110 
111    private:
112     void InitFromString(StringPiece val);
113   };
114 
115   // Constructs an AttrValue.func given the "name" and "attrs".
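  // E.g., a sketch of a function-valued attr entry for a Node below (the attr
  // name "body" and function name "Body" are illustrative):
  //   {"body", FunctionDefHelper::FunctionRef("Body", {{"T", "$T"}})}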
116   static AttrValueWrapper FunctionRef(
117       const std::string& name,
118       gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
  static AttrValueWrapper FunctionRef(const std::string& name) {
120     return FunctionRef(name, {});
121   }
122 
123   // Node is used to construct FunctionDef.Node using initialization
124   // lists. E.g.,
125   //  Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
126   struct Node {
127     // When constructing a NodeDef, the first entry in ret is used as
128     // the node name, the remaining values are ignored.
129     std::vector<string> ret;
130     std::string op;
131     std::vector<string> arg;
132     std::vector<std::pair<string, AttrValueWrapper>> attr;
133     std::vector<string> dep;
134     std::string device;
135 
136     NodeDef ToNodeDef() const;
137   };
138 
139   // Creates a FunctionDef from the given parameters. Node inputs must use
140   // function encoding (node_name:output_name[:output_index]).
141   // - `ret_def` holds a mapping from the function output names from `out_def`
142   //   to the node outputs from `node_def`.
143   // - `control_ret_def` holds a mapping from the function control
144   //   output names to the nodes from `node_def`.
145   static FunctionDef Create(
146       const std::string& function_name, gtl::ArraySlice<string> in_def,
147       gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
148       gtl::ArraySlice<Node> node_def,
149       gtl::ArraySlice<std::pair<string, string>> ret_def,
150       gtl::ArraySlice<std::pair<string, string>> control_ret_def);
151 
152   // Creates a FunctionDef from the given parameters. Node inputs must use
153   // function encoding (node_name:output_name[:output_index]).
154   // - `ret_def` holds a mapping from the function output names from `out_def`
155   //   to the node outputs from `node_def`.
156   static FunctionDef Create(const std::string& function_name,
157                             gtl::ArraySlice<string> in_def,
158                             gtl::ArraySlice<string> out_def,
159                             gtl::ArraySlice<string> attr_def,
160                             gtl::ArraySlice<Node> node_def,
161                             gtl::ArraySlice<std::pair<string, string>> ret_def);
162 
163   // TODO(josh11b): Get rid of these and transition to the one above.
164   static FunctionDef Define(const std::string& function_name,
165                             gtl::ArraySlice<string> arg_def,
166                             gtl::ArraySlice<string> ret_def,
167                             gtl::ArraySlice<string> attr_def,
168                             gtl::ArraySlice<Node> node_def);
169 
170   // Defines an anonymous function. I.e., its name is not relevant.
171   static FunctionDef Define(gtl::ArraySlice<string> arg_def,
172                             gtl::ArraySlice<string> ret_def,
173                             gtl::ArraySlice<string> attr_def,
174                             gtl::ArraySlice<Node> node_def);
175 
  // Helpers to construct a constant scalar or a 1-D constant tensor.
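  // E.g., a sketch of building Const nodes for use in a FunctionDef (the node
  // names "two" and "dims" are illustrative):
  //   Node c1 = FunctionDefHelper::Const("two", 2.0f);            // scalar
  //   Node c2 = FunctionDefHelper::Const<int32>("dims", {0, 1});  // 1-D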
177   template <typename T>
  static Node Const(const std::string& name, const T& val) {
179     Node n = {{name}, "Const"};
180     const DataType dtype = DataTypeToEnum<T>::value;
181     n.attr.push_back({"dtype", dtype});
182     Tensor t(dtype, TensorShape({}));
183     t.scalar<T>()() = val;
184     n.attr.push_back({"value", t});
185     return n;
186   }
187 
188   template <typename T>
  static Node Const(const std::string& name, gtl::ArraySlice<T> vals) {
190     Node n = {{name}, "Const"};
191     const DataType dtype = DataTypeToEnum<T>::value;
192     n.attr.push_back({"dtype", dtype});
193     int64 num = vals.size();
194     Tensor t(dtype, TensorShape({num}));
195     for (size_t i = 0; i < vals.size(); ++i) {
196       t.flat<T>()(i) = vals[i];
197     }
198     n.attr.push_back({"value", t});
199     return n;
200   }
201 };
202 
203 template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
205   InitFromString(val);
206 }
207 
208 template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
210     const std::string& val) {
211   InitFromString(val);
212 }
213 
214 template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
216   InitFromString(val);
217 }
218 
219 // Instantiate a function.
220 //
221 // "fdef" encodes a TF function with some attrs in fdef.signature.attr
222 // containing placeholders.  InstantiateFunction binds these
223 // placeholders and produces an instantiated function encoded in
224 // "result.gdef". The value to substitute a placeholder is given by
225 // "attr_values", which is a map from a placeholder name to an attr
226 // value.
227 //
228 // InstantiateFunction calls "get_function" to find signatures of other
229 // functions and primitive ops.
230 
231 // GetFunctionSignature(func name, opdef) returns OK if the func name is found
232 // and opdef is filled with a pointer to the corresponding signature
// (an OpDef proto). Otherwise, returns an error.
234 typedef std::function<Status(const string&, const OpDef**)>
235     GetFunctionSignature;
236 
237 struct InstantiationResult {
238   DataTypeVector arg_types;
239   DataTypeVector ret_types;
240   std::vector<NodeDef> nodes;
241 };
242 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
243                            GetFunctionSignature get_function,
244                            InstantiationResult* result);
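// A minimal sketch of binding the "T" placeholder to float when instantiating
// `fdef`; resolving signatures through a FunctionLibraryDefinition `lib_def`
// is an assumption here, any GetFunctionSignature callable works:
//   AttrValueMap attr_values;
//   SetAttrValue(DT_FLOAT, &attr_values["T"]);
//   GetFunctionSignature get_func = [&lib_def](const string& name,
//                                              const OpDef** sig) {
//     return lib_def.LookUpOpDef(name, sig);
//   };
//   InstantiationResult result;
//   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, AttrSlice(&attr_values),
//                                           get_func, &result));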
245 
246 // Returns a debug string for a function definition.
247 //
248 // The returned text is multiple-line. It is intended to be
249 // human-readable rather than being friendly to parsers. It is _NOT_
250 // intended to be the canonical string representation of "func_def".
251 // Particularly, it may not include all information presented in
252 // "func_def" (e.g., comments, description of the function arguments,
253 // etc.)
254 std::string DebugString(const FunctionDef& func_def);
255 std::string DebugString(const GraphDef& instantiated_func_def);
256 std::string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
257 
258 // Returns a debug string for a top level graph (the main program and
259 // its supporting functions defined in its library).
260 std::string DebugStringWhole(const GraphDef& gdef);
261 
262 // Returns true if f1 == f2. Compares all fields, including descriptions. Order
263 // of NodeDefs doesn't matter.
264 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
265 
// Returns a hash of `fdef` that is consistent with the FunctionDefsEqual
// method: if two fdefs compare equal, their hash values will be the same.
269 uint64 FunctionDefHash(const FunctionDef& fdef);
270 
271 class CallFrameInterface {
272  public:
  virtual ~CallFrameInterface() {}
274 
275   virtual size_t num_args() const = 0;
276   virtual size_t num_retvals() const = 0;
277 
278   virtual Status GetArg(int index, const Tensor** val) = 0;
279 
280   // Optimized implementation of `GetArg()` that allows the caller to take
281   // ownership of the tensor. This method may only be called once per
282   // value of `index` and `CallFrameInterface` instance.
283   //
284   // REQUIRES: `this->CanConsumeArg(index) == true`.
  virtual void ConsumeArg(int index, Tensor* val) {
286     LOG(ERROR) << "This `CallFrameInterface` implementation does not support "
287                   "consuming arguments.";
288   }
  virtual bool CanConsumeArg(int index) const { return false; }
290 
291   virtual Status SetRetval(int index, const Tensor& val) = 0;
292 };
293 
294 // Represents a function call frame. I.e., the data structure used to
295 // pass arguments to a function and retrieve its results.
296 //
297 // Runtime must arrange accesses to one FunctionCallFrame s.t.
298 //   1. SetArgs() happens before any GetArg();
299 //   2. GetRetvals happens after all SetRetval();
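// A caller-side sketch for a function taking and returning one float (the
// runtime runs the instantiated function against `frame` in between):
//   FunctionCallFrame frame({DT_FLOAT}, {DT_FLOAT});
//   Tensor arg(DT_FLOAT, TensorShape({}));
//   arg.scalar<float>()() = 3.0f;
//   TF_RETURN_IF_ERROR(frame.SetArgs({arg}));
//   ...  // run the function
//   std::vector<Tensor> rets;
//   TF_RETURN_IF_ERROR(frame.GetRetvals(&rets));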
300 class FunctionCallFrame : public CallFrameInterface {
301  public:
302   FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
303   ~FunctionCallFrame() override;
304 
305   // Caller methods.
306   Status SetArgs(gtl::ArraySlice<Tensor> args);
307   Status GetRetvals(std::vector<Tensor>* rets) const;
308 
309   // Moves the return values from the frame to rets. If allow_dead_tensors is
310   // false it will fail if any of the retvals do not have a value.
311   Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors);
312 
  size_t num_args() const override { return arg_types_.size(); }
  size_t num_retvals() const override { return ret_types_.size(); }
315 
316   // Callee methods.
317   Status GetArg(int index, const Tensor** val) override;
318   Status SetRetval(int index, const Tensor& val) override;
319 
320  private:
321   DataTypeVector arg_types_;
322   DataTypeVector ret_types_;
323   gtl::InlinedVector<Tensor, 4> args_;
324   struct Retval {
325     bool has_val = false;
326     Tensor val;
327   };
328   gtl::InlinedVector<Retval, 4> rets_;
329 
330   TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
331 };
332 
333 // Language agnostic stack traces.
334 class AbstractStackTrace {
335  public:
336   struct TracePrintingOptions {
337     // Show inline the contents of each stack line.
338     bool show_line_contents = false;
339 
340     // Drop the common largest prefix of all filenames in stack frames.
341     bool filter_common_prefix = false;
342 
343     // Do not show internal frames.
344     bool drop_internal_frames = false;
345   };
346 
  virtual ~AbstractStackTrace() {}
348 
349   // The returned span is alive as long as the AbstractStackTrace is alive.
350   virtual absl::Span<StackFrame const> ToFrames() const = 0;
351 
352   // Returns the last stack frame from user code, attempting to ignore the
353   // framework code. Returns an empty frame if no such stack frame was found.
354   virtual StackFrame LastUserFrame() const = 0;
355   virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
356 };
357 
358 using StackTracesMap =
359     std::unordered_map<std::string,
360                        std::shared_ptr<tensorflow::AbstractStackTrace>>;
361 
362 // Helper to maintain a map between function names in a given
363 // FunctionDefLibrary and function definitions.
364 //
365 // This class is thread-safe.
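// A minimal usage sketch (assuming `my_func` is a FunctionDef named
// "my_func_name", e.g. built with FunctionDefHelper::Create above):
//   FunctionDefLibrary lib_proto;
//   *lib_proto.add_function() = my_func;
//   FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib_proto);
//   const FunctionDef* found = lib_def.Find("my_func_name");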
366 class FunctionLibraryDefinition : public OpRegistryInterface {
367  public:
368   // Ops created for function arguments bear the name given by `kArgOp`; those
369   // created for return values bear the name given by `kRetOp`.
370   static constexpr const char* const kArgOp = "_Arg";
371   static constexpr const char* const kDeviceArgOp = "_DeviceArg";
372   static constexpr const char* const kRetOp = "_Retval";
373   static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
374   static constexpr const char* const kIntsOnDeviceAttr =
375       "experimental_ints_on_device";
376   static constexpr const char* const kSharedRendezvousAttr =
377       "shared_rendezvous";
378 
379   static constexpr const char* const kGradientOp = "SymbolicGradient";
380   static constexpr const char* const kFuncAttr = "f";
381 
382   // Note: This constructor grabs `lib_def`'s lock in shared mode.
383   FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
384   FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
385                             const FunctionDefLibrary& lib_def);
386   ~FunctionLibraryDefinition() override;
387 
388   FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
389       delete;
390 
  // Returns true if the library contains `func`, false otherwise.
392   bool Contains(const std::string& func) const;
393 
394   // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
395   // returns its definition proto.
396   //
397   // NB: This function returns a borrowed pointer, which can be invalidated by a
398   // subsequent call to `ReplaceFunction()` with the given name.
399   const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_);
400 
401   // Adds function definition 'fdef' to this function library.
402   // Returns status 'ok' on success, or error otherwise. This is a no-op if
403   // 'fdef' already exists in this function library.
404   // If 'fdef' is successfully added to the library, it will be accessible
405   // from 'LookUp' and included in the proto returned by 'ToProto'.
406   // This operation is atomic.
407   //
408   // Associates `graph` with a function `func_name`. Lifetime assumption:
409   // `graph` has to outlive all instantiated graphs.
410   Status AddFunctionDef(const FunctionDef& fdef,
411                         const StackTracesMap& stack_traces = {})
412       TF_LOCKS_EXCLUDED(mu_);
413 
414   // Adds gradient definition 'grad' to this function library.
415   // This is a no-op if 'grad' already exists in this function library.
416   // If 'grad' is successfully added, it will be accessible via 'FindGradient'
417   // and included in the proto returned by 'ToProto'.
418   // This operation is atomic.
419   Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);
420 
421   // Replaces the function corresponding to `func` with `fdef`. Returns
422   // a non-OK status if "func" was not found in the library, OK otherwise.
  // Please be careful when replacing a function: make sure all previous
  // pointers returned by `Find()` are no longer in use.
425   Status ReplaceFunction(const std::string& func, const FunctionDef& fdef)
426       TF_LOCKS_EXCLUDED(mu_);
427 
428   // Replaces the gradient corresponding to `grad.function_name()`. Returns
429   // a non-OK status if "grad.function_name()" was not found in the library, OK
430   // otherwise.
431   Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);
432 
433   // Removes the function corresponding to 'func'. Returns a non-OK status if
434   // 'func' was not found in the library, OK otherwise.
  // Please be careful when removing a function: make sure there are no other
  // nodes using the function, and all previous pointers returned by `Find()`
  // are no longer in use.
438   Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);
439 
440   // Removes all the functions and gradient functions.
441   void Clear() TF_LOCKS_EXCLUDED(mu_);
442 
443   // Adds the functions and gradients in 'other' to this function library.
444   // Duplicate functions and gradients are ignored.
445   // This operation is atomic.
446   Status AddLibrary(const FunctionLibraryDefinition& other)
447       TF_LOCKS_EXCLUDED(mu_);
448 
449   // Adds the functions and gradients in 'lib_def' to this function library.
450   // Duplicate functions and gradients are ignored.
451   // This operation is atomic.
452   Status AddLibrary(const FunctionDefLibrary& lib_def) TF_LOCKS_EXCLUDED(mu_);
453 
454   // If the gradient function for 'func' is specified explicitly in
455   // the library, returns the gradient function name.  Otherwise,
456   // returns an empty string.
457   std::string FindGradient(const std::string& func) const
458       TF_LOCKS_EXCLUDED(mu_);
459 
460   // OpRegistryInterface method. Useful for constructing a Graph.
461   //
462   // If "op" is defined in the library, returns its signature.
463   // Otherwise, assume "op" is a primitive op and returns its op
464   // signature and shape inference function.
465   //
466   // NB: This function outputs a borrowed pointer, which can be invalidated by a
467   // subsequent call to `ReplaceFunction()` with the given name.
468   Status LookUp(const std::string& op_type_name,
469                 const OpRegistrationData** op_reg_data) const override
470       TF_LOCKS_EXCLUDED(mu_);
471 
  // Generates a new function name with the specified prefix that is unique
473   // across this library.
474   std::string UniqueFunctionName(StringPiece prefix) const
475       TF_LOCKS_EXCLUDED(mu_);
476 
477   // Given a node def 'ndef', inspects attributes of the callee
478   // function to derive the attribute 'value' for 'attr'. Returns OK
479   // iff the attribute is given by the function's definition.
480   // TODO(irving): Remove; keep only the const Node& version.
481   template <typename T>
482   Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const;
483 
484   // Given a node, inspects attributes of the callee function to derive the
485   // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
486   // function's definition.
487   template <typename T>
488   Status GetAttr(const Node& node, const std::string& attr, T* value) const;
489 
490   // Returns a proto representation of the state of this function library.
491   FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);
492 
  size_t num_functions() const {
494     tf_shared_lock l(mu_);
495     return function_defs_.size();
496   }
497 
498   // Returns all the function names in the FunctionLibraryDefinition.
499   std::vector<string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_);
500 
  const OpRegistryInterface* default_registry() const {
502     return default_registry_;
503   }
504 
505   // Returns a copy of `*this` with only the subset of functions that are
506   // reachable from the nodes of `graph` or `func`.
507   FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
508   FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;
509 
510   // Copies the function named `func` from `other` to this
511   // FunctionLibraryDefinition.
512   // REQUIRES: `this->default_registry() == other.default_registry()`.
513   // Returns OK on success, or error otherwise. This is a no-op if a function
514   // name `func` already exists in this function library, and has the same
515   // implementation as in `other`. If the implementations conflict, an invalid
516   // argument error is returned.
517   Status CopyFunctionDefFrom(const std::string& func,
518                              const FunctionLibraryDefinition& other)
519       TF_LOCKS_EXCLUDED(mu_);
520 
  // Returns the debug stack traces recorded for the given function, or an
  // empty map if the function is not found or has no stack traces.
  const StackTracesMap& GetStackTraces(const std::string& func_name) const {
524     tf_shared_lock l(mu_);
525     std::shared_ptr<FunctionDefAndOpRegistration> entry = FindHelper(func_name);
526     if (entry) {
527       return entry->stack_traces;
528     }
529     static const auto* empty_map = new StackTracesMap;
530     return *empty_map;
531   }
532 
533  private:
534   // Shape inference for functions is handled separately by ShapeRefiner.
535 
536   struct FunctionDefAndOpRegistration {
537     explicit FunctionDefAndOpRegistration(
538         const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {});
539 
540     const FunctionDef fdef;
541     const OpRegistrationData op_registration_data;
542     const StackTracesMap stack_traces;
543   };
544 
545   std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
546       const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
547   std::string FindGradientHelper(const std::string& func) const
548       TF_SHARED_LOCKS_REQUIRED(mu_);
549 
550   Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
551                    bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
552 
553   // Same as AddFunctionDef/AddGradientDef except these methods set
554   // `added` to true if the `fdef`/`grad` were actually added to this.
555   Status AddFunctionDefHelper(const FunctionDef& fdef,
556                               const StackTracesMap& stack_traces, bool* added)
557       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
558   Status AddGradientDefHelper(const GradientDef& grad, bool* added)
559       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
560 
561   // Helper function for GetAttr. Returns the FunctionDef* to get the
562   // attr from.
563   const FunctionDef* GetAttrImpl(const NodeDef& ndef) const
564       TF_LOCKS_EXCLUDED(mu_);
565 
566   // Remove all functions in `funcs` and all gradients of functions in
567   // `funcs_with_grads` from this library.
568   void Remove(const std::vector<string>& funcs,
569               const std::vector<string>& funcs_with_grads)
570       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
571 
572   // Remove `func` from the library. Returns non-OK Status unless `func` is in
573   // the library. This should only be called when there is a guarantee that the
574   // function being removed hasn't been retrieved with `Find`.
575   Status RemoveFunctionHelper(const std::string& func)
576       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
577 
578   // Remove gradient of function `func` from the library. Returns non-OK Status
579   // unless `func` has a gradient.
580   Status RemoveGradient(const std::string& func)
581       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
582 
583   mutable mutex mu_;
584   const OpRegistryInterface* const default_registry_;
585   gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
586       function_defs_ TF_GUARDED_BY(mu_);
587   gtl::FlatMap<string, string> func_grad_ TF_GUARDED_BY(mu_);
588 };
589 
590 // Forward declare. Defined in common_runtime/function.h
591 struct FunctionBody;
592 
593 // Forward declare. Defined in common_runtime/device.h
594 class Device;
595 // Forward declare. Defined in common_runtime/device_mgr.h
596 class DeviceMgr;
597 
598 // Index of an _Arg node.
599 struct FunctionArgIndex {
  explicit FunctionArgIndex(const int index) : index(index) {}
  FunctionArgIndex(const int index, const int sub_index)
602       : index(index), sub_index(sub_index) {}
603 
604   // The value of the attribute "Index" of the _Arg node.
605   int index;
606   // Set only when the _Arg node represents multiple arguments (e.g. an _Arg
607   // node is replicated to multiple devices/subgraphs). Use sub-index to
608   // distinguish arguments with the same index.
609   int sub_index = -1;
610 };
611 
612 class FunctionLibraryRuntime {
613  public:
  virtual ~FunctionLibraryRuntime() {}
615 
616   // Instantiate a function with the given "attrs".
617   //
618   // Returns OK and fills in "handle" if the instantiation succeeds.
619   // Otherwise returns an error and "handle" is undefined.
620   struct InstantiateOptions {
621     // The canonical device name of the device on which the function
622     // should be instantiated. If empty, the function will be
623     // instantiated on the local device.
624     std::string target;
625 
626     // Should the function be instantiated as a multi-device function?
627     bool is_multi_device_function = false;
628 
629     // If true, graph passes will be skipped when instantiating the function
630     // since they have already run on the main function side.
631     bool is_component_function = false;
632 
633     // For multi-device functions, a vector of canonical device names for
634     // function's inputs. The device of resource inputs must be the device
635     // backing the resource, not the CPU device backing the resource handle.
636     // Must have the same length as number of inputs to the function.
637     std::vector<string> input_devices;
638 
639     // For multi-device functions, a vector of canonical device names for
640     // function's outputs.
641     //
642     // (a) If specified (must have the same length as number of outputs):
643     //
644     // Specified devices will be assigned to Retval nodes inserted into the
    // function body graph in place of function outputs. An output device may
    // be specified as an empty string; in that case the Retval device
    // assignment will be inferred later, when the function graph is placed
    // before partitioning (this is required for resource outputs). The Placer
    // will respect colocation constraints.
650     //
651     // (b) If not specified:
652     //
    // The function runtime will infer the Retval device by following input
    // edges until it reaches a node with a device specification. This device
    // specification must identify a unique device; a general specification
    // like "job:foo" matching multiple devices will result in an error.
657     //
658     // IMPORTANT: Resource outputs
659     //
    // Multi-device functions might return resources on devices different from
    // the function call device. If the output device is not specified for a
    // resource output, and the node producing that resource is a function
    // call, the runtime will leave the device specification empty and rely on
    // the Placer to infer the correct device.
665     std::vector<string> output_devices;
666 
667     // If set, it indicates the original output indices of a component function.
668     absl::optional<std::vector<int>> ret_indices = absl::nullopt;
669 
670     // Maps from a CompositeDevice name to a list of underlying physical
671     // devices.
672     absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
673 
674     // This interface is EXPERIMENTAL and subject to change.
675     //
676     // For multi-device functions, a mapping from _Arg node index to type and
677     // shape for input resources.
678     // REQUIRES: if input_resource_dtypes_and_shapes.count(i) > 0 then i-th
679     // argument type must be DT_RESOURCE.
680     std::unordered_map<int, DtypeAndPartialTensorShape>
681         input_resource_dtypes_and_shapes;
682 
683     // This interface is EXPERIMENTAL and subject to change.
684     //
685     // If non-null, the runtime will use `lib_def` to resolve function(s) named
686     // in `function_name` and `attrs`. Otherwise, the runtime will use its
687     // internal library.
688     //
689     // NOTE(mrry): If provided, all functions defined in `lib_def` must be
690     // self-contained, and cannot refer to functions defined in other libraries.
691     const FunctionLibraryDefinition* lib_def = nullptr;
692 
693     // This interface is EXPERIMENTAL and subject to change.
694     //
695     // If non-empty, the runtime will use `state_handle` to identify
    // cached state related to the instantiated function. Two functions
697     // of the same name and attrs, instantiated with the same
698     // `state_handle` will have the same handle and share the same
699     // state (in stateful kernels); and two functions with different
700     // values for `state_handle` will have independent state.
701     std::string state_handle;
702 
703     // This interface is EXPERIMENTAL and subject to change.
704     //
705     // Instantiates the function using an executor of the given type. If empty,
706     // the default TensorFlow executor will be used.
707     std::string executor_type;
708 
709     // If true, the runtime will attempt to create kernels for the function at
710     // instantiation time, rather than on the first run. This can be used to
711     // surface errors earlier.
712     bool create_kernels_eagerly = false;
713 
714     // This interface is EXPERIMENTAL and subject to change.
715     //
716     // Instantiates the function with the provided config_proto.
717     ConfigProto config_proto;
718 
719     // If provided, this optimization function will be invoked before
720     // the placer for multi-device functions.
721     std::function<Status(std::vector<string> /*ret_node_names*/,
722                          std::vector<string> /*keep_node_names*/,
723                          FunctionLibraryDefinition*, const DeviceSet&,
724                          Device* /*cpu_device*/, std::unique_ptr<Graph>*)>
725         optimize_graph_fn;
726 
727     // If set, partitioned functions will be added to `graph_collector`.
728     // `graph_collector` must be alive during the call to Instantiate.
729     GraphCollector* graph_collector = nullptr;
730 
731     // Indicates whether the multi-device function backend should default the
    // placement of ops without a requested device to `target`.
733     bool default_device_to_target = true;
734 
735     // If true, the optimized Graph will be stored so that
736     // `FunctionLibraryRuntime::DebugString(handle)` contains the optimized
737     // Graph. Otherwise, the unoptimized function Graph will be returned.
738     bool include_optimized_graph_in_debug_string = false;
739   };
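  // A sketch of filling in InstantiateOptions for a multi-device function
  // (the device names below are illustrative):
  //   InstantiateOptions opts;
  //   opts.target = "/job:localhost/replica:0/task:0/device:CPU:0";
  //   opts.is_multi_device_function = true;
  //   opts.input_devices = {"/job:localhost/replica:0/task:0/device:GPU:0"};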
740   typedef uint64 Handle;
741   virtual Status Instantiate(const std::string& function_name, AttrSlice attrs,
742                              const InstantiateOptions& options,
743                              Handle* handle) = 0;
  Status Instantiate(const std::string& function_name, AttrSlice attrs,
745                      Handle* handle) {
746     auto opts = absl::make_unique<InstantiateOptions>();
747     return Instantiate(function_name, attrs, *opts, handle);
748   }
749 
750   // Releases state associated with the handle.
751   virtual Status ReleaseHandle(Handle handle) = 0;
752 
753   // Returns the function body for the instantiated function given its
754   // handle 'h'. Returns nullptr if "h" is not found.
755   //
756   // *this keeps the ownership of the returned object, which remains alive
757   // as long as *this.
758   virtual const FunctionBody* GetFunctionBody(Handle h) = 0;
759 
760   // Returns the return types for the function identified by handle `h`.
761   virtual Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0;
762 
763   // Asynchronously invokes the instantiated function identified by
764   // "handle".
765   //
766   // If function execution succeeds, "done" is called with OK and
767   // "*rets" is filled with the function's return values. Otherwise,
768   // "done" is called with an error status.
769   //
770   // Does not take ownership of "rets".
771   // In the cross-process scenario, runner isn't used for making the Async
772   // RPC calls.
773   struct Options {
    Options() {}
    explicit Options(const int64 step_id) : step_id(step_id) {}
776     // Choose a step ID that is guaranteed not to clash with any
777     // Session-generated step ID. DirectSession only generates
778     // non-negative step IDs (contiguous, starting from 0), and
779     // MasterSession generates 56-bit random step IDs whose MSB is
780     // always 0, so a negative random step ID should suffice.
781     const int64 step_id = -std::abs(static_cast<int64>(random::New64()));
782 
783     // op_id of the function running in eager mode. Set when we want to copy
784     // remote outputs lazily. All components of a remote multi-device function
785     // should use the same op_id, in order to correctly map remote output
786     // tensors to the remote TensorHandles in the default device.
787     absl::optional<int64> op_id = absl::nullopt;
788 
789     RendezvousInterface* rendezvous = nullptr;
790     CancellationManager* cancellation_manager = nullptr;
791     CollectiveExecutor* collective_executor = nullptr;
792     ScopedStepContainer* step_container = nullptr;
793     StepStatsCollectorInterface* stats_collector = nullptr;
794 
795     std::function<void(std::function<void()>)>* runner = nullptr;
796 
797     // Parameters for remote function execution.
798     bool remote_execution = false;
799     std::string source_device = "";  // Fully specified device name.
800 
801     // Allocator attributes specifying where the args are / rets should be put.
802     // These should either be {} or match the length of args / retvals. If {},
803     // the default allocator attributes will be assumed for all args / retvals.
804     std::vector<AllocatorAttributes> args_alloc_attrs;
805     std::vector<AllocatorAttributes> rets_alloc_attrs;
806 
807     // If true, we create a new IntraProcessRendezvous, else use the existing
808     // one.
809     bool create_rendezvous = false;
810 
811     // If True, allow returning dead tensors.
812     bool allow_dead_tensors = false;
813 
814     // If True, hint that all kernels should be treated as "inexpensive", and
815     // hence executed on the scheduling thread.
816     bool run_all_kernels_inline = false;
817 
818     // Returns a human readable representation of this.
819     std::string DebugString() const;
820   };
821   typedef std::function<void(const Status&)> DoneCallback;
822   virtual void Run(const Options& opts, Handle handle,
823                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
824                    DoneCallback done) = 0;
825   virtual void Run(const Options& opts, Handle handle,
826                    CallFrameInterface* call_frame, DoneCallback done) = 0;
827 
828   virtual Status RunSync(Options opts, Handle handle,
829                          gtl::ArraySlice<Tensor> args,
830                          std::vector<Tensor>* rets) = 0;
831   virtual Status RunSync(Options opts, Handle handle,
832                          CallFrameInterface* call_frame) = 0;
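  // A typical asynchronous call sequence, sketched ("MyFunc" and `arg` are
  // illustrative; error handling and argument construction are elided):
  //   Handle handle;
  //   TF_RETURN_IF_ERROR(flr->Instantiate("MyFunc", AttrSlice(), &handle));
  //   Options opts;
  //   std::vector<Tensor> rets;
  //   flr->Run(opts, handle, {arg}, &rets,
  //            [](const Status& s) { /* `rets` is valid only if `s` is OK */ });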
833 
834   // Creates a "kernel" for the given NodeProperties "props".
835   //
836   // If succeeds, returns OK and the caller takes the ownership of the
837   // returned "*kernel". Otherwise, returns an error.
838   virtual Status CreateKernel(
839       const std::shared_ptr<const NodeProperties>& props,
840       OpKernel** kernel) = 0;
841 
842   // Returns true iff the function named `function_name` is stateful.
843   //
844   // NOTE(mrry): This method assumes that the runtime is associated with a
845   // default function library, and looks up `function_name` in that library.
846   // It does not support overriding the function library.
847   virtual bool IsStateful(const std::string& function_name) const = 0;
848 
849   // Returns the device on which the function executes.
850   virtual Device* device() = 0;
851   virtual const Device* device() const = 0;
852 
853   // Returns the default runner in which the ops should be launched. If the
854   // device on which the function executes has a private thread pool, return
855   // runner on the device local thread pool.
856   virtual std::function<void(std::function<void()>)>* runner() = 0;
857 
858   // Get the DeviceMgr from which the device was obtained.
859   virtual const DeviceMgr* device_mgr() const = 0;
860 
861   // Returns the function library definition that backs this runtime.
862   //
863   // NOTE(mrry): The returned library definition is the default function library
864   // for this runtime. The caller may override the function library used by the
865   // runtime to instantiate functions, which will not be reflected in the return
866   // value of this function.
867   virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
868       const = 0;
869 
870   // Returns the environment on which the function executes.
871   virtual Env* env() = 0;
872 
873   // Returns the ConfigProto passed to the session used to create the function.
874   virtual const ConfigProto* const config_proto() = 0;
875 
876   // Returns a debug string showing the definition of the function of
877   // 'handle'.
878   virtual std::string DebugString(Handle handle) = 0;
879 
880   // Returns the graph version number.
881   virtual int graph_def_version() const = 0;
882 
883   typedef uint64 LocalHandle;
884 
885   // Creates a copy of ProcessFunctionLibraryRuntime (transferring ownership to
886   // the caller), FunctionLibraryRuntime (owned by the returned
887   // ProcessFunctionLibraryRuntime), FunctionLibraryDefinition (transferring
888   // ownership to the caller). Note that both the ProcessFunctionLibraryRuntime
889   // and FunctionLibraryRuntime borrow a pointer to the
890   // FunctionLibraryDefinition and so the FunctionLibraryDefinition should
891   // outlive both.
892   //
893   // The `skip_flib_def` argument controls whether the method should clone the
894   // FunctionLibraryDefinition (default behavior) or return an empty function
895   // library. The latter is used by tf.data, which manages
896   // FunctionLibraryDefinitions for its functions independently (and passes
897   // these into the FunctionLibraryRuntime through an overlay), to avoid linear
898   // runtime w.r.t. to number of functions in the current function library.
899   virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
900                        std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
901                        FunctionLibraryRuntime** out_flr,
902                        bool skip_flib_def = false) = 0;
903 
904   // Returns the name of the executor class (in the sense of
905   // `ExecutorFactory::GetFactory()`) that will be used based on the given
906   // dynamic `options` and static `attrs`. If none is specified, this method
907   // will return an empty string, which leaves the decision up to the runtime.
908   static std::string ExecutorType(const InstantiateOptions& options,
909                                   AttrSlice attrs);
910 };
911 
// Returns the device of the `arg_index`-th function input. Updates
// `composite_devices` if the input device is a composite device.
914 std::string GetFunctionResourceInputDevice(
915     const Tensor& input, const int arg_index, const FunctionDef& function_def,
916     absl::flat_hash_map<string, std::vector<string>>* composite_devices);
917 
918 // Returns a canonicalized string for the instantiation of the
919 // function of the given "name", attributes "attrs", and "options".
920 //
921 // The returned string is guaranteed to be stable within one address
// space. But it may change as the implementation
923 // evolves. Therefore, it should not be persisted or compared across
924 // address spaces.
925 std::string Canonicalize(
926     const std::string& funcname, AttrSlice attrs,
927     const FunctionLibraryRuntime::InstantiateOptions& options);
928 std::string Canonicalize(const std::string& funcname, AttrSlice attrs);
929 
930 const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
931 const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
932 
933 class CustomKernelCreator {
934  public:
  virtual ~CustomKernelCreator() {}
936 
937   // Given a NodeDef 'node_def' and the function library runtime 'flr',
938   // validate if the class supports creating such a kernel.
939   virtual bool CanCreateKernel(
940       const FunctionLibraryRuntime& flr,
941       const std::shared_ptr<const NodeProperties>& props) const = 0;
942 
943   // Given a supported NodeDef, returns a kernel that computes the node.
944   virtual Status CreateKernel(
945       FunctionLibraryRuntime* flr,
946       const std::shared_ptr<const NodeProperties>& props,
947       std::unique_ptr<OpKernel>* kernel) const = 0;
948 };
949 
950 typedef
951 #if !defined(IS_MOBILE_PLATFORM)
952     absl::variant<Tensor, eager::RemoteTensorHandle*>
953         FunctionArg;
954 #else
955     absl::variant<Tensor>
956         FunctionArg;
957 #endif
958 
959 // Either a local tensor or the shape of a remote tensor.
960 typedef absl::variant<Tensor, TensorShape> FunctionRet;
961 
962 // Used to instantiate and run functions in a distributed system.
963 class DistributedFunctionLibraryRuntime {
964  public:
  virtual ~DistributedFunctionLibraryRuntime() {}
966 
  // Instantiate a function on the remote target specified in `options.target`
  // by sending the name and definition of the function to the remote worker.
  // The local `handle` is filled in with the instantiated function data and
  // can be used for subsequent Run calls on the remote target.
971   virtual void Instantiate(
972       const std::string& function_name,
973       const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
974       const FunctionLibraryRuntime::InstantiateOptions& options,
975       FunctionLibraryRuntime::LocalHandle* handle,
976       FunctionLibraryRuntime::DoneCallback done) = 0;
977 
978   // Run an instantiated remote function (specified by `handle`) with a list of
979   // input Tensors in `args` and get its output Tensors in `rets`. The input
980   // tensor data will be sent with the function execution request, and must be
981   // available on the current caller side.
982   // opts.runner isn't used for execution.
983   virtual void Run(const FunctionLibraryRuntime::Options& opts,
984                    FunctionLibraryRuntime::LocalHandle handle,
985                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
986                    FunctionLibraryRuntime::DoneCallback done) = 0;
987 
988   // Run an instantiated remote function (specified by `handle`) with a list of
989   // input Tensors or RemoteTensorHandles as `args` and get its output Tensors
990   // or TensorShapes in `rets`. When using RemoteTensorHandles as function
991   // inputs or TensorShapes as outputs, the corresponding tensor data will be
992   // resolved on the remote worker, so it is not required to be locally
993   // available on the caller side. Using RemoteTensorHandle inputs is not
994   // supported in TensorFlow v1 runtime.
995   virtual void Run(const FunctionLibraryRuntime::Options& opts,
996                    FunctionLibraryRuntime::LocalHandle handle,
997                    gtl::ArraySlice<FunctionArg> args,
998                    std::vector<FunctionRet>* rets,
999                    FunctionLibraryRuntime::DoneCallback done) = 0;
1000 
1001   // Clean up a previously instantiated function on remote worker.
1002   virtual void CleanUp(uint64 step_id,
1003                        FunctionLibraryRuntime::LocalHandle handle,
1004                        FunctionLibraryRuntime::DoneCallback done) = 0;
1005 
1006   // DeviceMgr with *all* available devices (i.e., local and remote).
1007   virtual DeviceMgr* remote_device_mgr() const = 0;
1008 };
1009 
1010 // Extracts the actual type from "attr_values" based on its definition
1011 // "arg_def".
1012 //
1013 // If "arg_def" is a N*T type, *is_type_list is set to false, and
1014 // *dtypes is set to be a vector of size N and each element is T.
1015 //
1016 // If "arg_def" is a list(type), *is_type_list is set to true, and
1017 // *dtypes is set to be a vector of types specified in attrs for
1018 // arg_def.
1019 //
1020 // Otherwise (arg_def is a simple type T), *is_type_list is set to
1021 // false, and *dtypes is set to a single element vector, whose only
1022 // element is T.
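// E.g., for an arg_def of the form "x: N*T" with attrs {N=2, T=DT_FLOAT},
// *is_type_list is set to false and *dtypes becomes {DT_FLOAT, DT_FLOAT}.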
1023 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
1024                   bool* is_type_list, DataTypeVector* dtypes);
1025 
1026 // To register a gradient function for a builtin op, one should use
1027 //   REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
1028 //
// Typically, the C++ grad factory is a plain function that can be
1030 // converted into ::tensorflow::gradient::Creator, which is
1031 //   std::function<Status(const AttrSlice&, FunctionDef*)>.
1032 //
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
// definition of a function which computes the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
1036 //
1037 // E.g.,
1038 //
1039 // Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
1040 //   bool transpose_a;
1041 //   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
1042 //   bool transpose_b;
1043 //   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
1044 //   DataType dtype;
1045 //   TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
1046 //   if (!transpose_a && !transpose_b) {
1047 //     *g = FunctionDefHelper::Define(
1048 //       "MatMulGrad",
//       {"x:T", "y:T", "dz:T"},     // Inputs to this function
1050 //       {"dx:T", "dy:T"},           // Outputs from this function
1051 //       {"T: {float, double}"},     // Attributes needed by this function
1052 //       {
1053 //         {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
1054 //         {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
1055 //         {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
//         {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
1057 //       });
1058 //   } else {
1059 //     ... ...
1060 //   }
1061 //   return Status::OK();
1062 // }
1063 //
1064 // NOTE: $T is substituted with the type variable "T" when the
// gradient function MatMulGrad is instantiated.
1066 //
1067 // TODO(zhifengc): Better documentation somewhere.
1068 
1069 // Macros to define a gradient function factory for a primitive
1070 // operation.
1071 #define REGISTER_OP_GRADIENT(name, fn) \
1072   REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)
1073 
1074 #define REGISTER_OP_NO_GRADIENT(name) \
1075   REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)
1076 
1077 #define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
1078   REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
1079 
1080 #define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
1081   static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
1082       SHOULD_REGISTER_OP_GRADIENT &&                  \
1083       ::tensorflow::gradient::RegisterOp(name, fn)
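// E.g., registering a gradient factory like the MatMulGrad sketch above, or
// marking an op as having no gradient (the op names are illustrative):
//   REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
//   REGISTER_OP_NO_GRADIENT("SomeStatefulOp");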
1084 
1085 namespace gradient {
1086 // Register a gradient creator for the "op".
1087 typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
1088 bool RegisterOp(const std::string& op, Creator func);
1089 
// Returns OK if the gradient creator for the "op" is found (it may be
// nullptr if REGISTER_OP_NO_GRADIENT is used).
1092 Status GetOpGradientCreator(const std::string& op, Creator* creator);
}  // namespace gradient
1094 
1095 // Declare explicit instantiations of GetAttr
1096 #define GET_ATTR(T)                                          \
1097   extern template Status FunctionLibraryDefinition::GetAttr( \
1098       const Node&, const string&, T*) const;                 \
1099   extern template Status FunctionLibraryDefinition::GetAttr( \
1100       const NodeDef&, const string&, T*) const;
1101 GET_ATTR(string)
1102 GET_ATTR(bool)
1103 #undef GET_ATTR
1104 
1105 }  // end namespace tensorflow
1106 
1107 #endif  // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
1108