1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
16 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
17 
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "tensorflow/core/framework/function.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/lib/core/status.h"
28 
29 namespace tensorflow {
30 namespace grappler {
31 class FunctionApiInfo {
32  public:
33   FunctionApiInfo();
34   virtual ~FunctionApiInfo();
35 
36   enum FunctionType {
37     INFERENCE,  // Default type.
38     FORWARD,
39     BACKWARD,
40   };
41 
42   Status Init(const FunctionDef& function_def);
43 
44   const string& interface_name() const;
45   const string& preferred_device() const;
46   const FunctionType function_type() const;
47   const string& pairing_function_name() const;
48   const DataTypeVector& input_arg_dtypes() const;
49   const DataTypeVector& output_arg_dtypes() const;
50 
51  private:
52   string interface_name_;
53   string preferred_device_;
54   FunctionType function_type_;
55   // The pairing function is used to pair between forward and backward function,
56   // which will be useful during function swapping. Inference function won't
57   // have pairing function.
58   string pairing_function_name_;
59   // The following two attributes are useful for forward and backward functions.
60   DataTypeVector input_arg_dtypes_;
61   DataTypeVector output_arg_dtypes_;
62 
63   TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo);
64 };
65 
66 // A collection of information for function and the interface it implements.
67 // A interface is a well defined math operation, eg I1 = 2 * x + y. Multiple
68 // functions could implement the same interface with different behavior based on
69 // different hardware condition and limits,
70 // eg F1 = math_ops.add(math_ops.add(x, x), y), or
71 //    F2 = math_ops.add(math_ops.matmul(x, 2), y).
72 class FunctionLibraryApiInfo {
73  public:
74   FunctionLibraryApiInfo();
75   virtual ~FunctionLibraryApiInfo();
76   // Populate the internal field for the functions within the function_library.
77   Status Init(const FunctionDefLibrary& function_library);
78 
79   Status GetEquivalentImplementations(
80       const string& function_name, std::vector<string>* other_functions) const;
81 
82   const FunctionApiInfo* GetApiInfo(const string& function_name) const;
empty()83   bool empty() const { return func_info_.empty(); }
size()84   std::size_t size() const { return func_info_.size(); }
85 
86  private:
87   // Map between function name to function details.
88   std::unordered_map<string, std::unique_ptr<FunctionApiInfo>> func_info_;
89 
90   // Map between interface name to function names.
91   // Forward/backward function pair usually have different signatures between
92   // each other since forward function could produce extra internal state as
93   // output, and backward will take those extra state as inputs.
94   absl::flat_hash_map<string, std::vector<string>> intf_to_inference_funcs_;
95   absl::flat_hash_map<string, std::vector<string>> intf_to_forward_funcs_;
96   absl::flat_hash_map<string, std::vector<string>> intf_to_backward_funcs_;
97 
98   TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo);
99 };
100 
101 }  // end namespace grappler
102 }  // end namespace tensorflow
103 
104 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
105