1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_
18 
19 #include <functional>
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/grappler/op_types.h"
23 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 namespace fusion_utils {
30 
31 // These functions are invoked with first and second function signature,
32 // should set a signature of fused second_function.
33 using SetFunctionSignatureFn = std::function<void(
34     const OpDef& first_function_signature,
35     const OpDef& second_function_signature, OpDef* fused_function_signature)>;
36 
37 using StringCollection = gtl::InlinedVector<string, 2>;
38 
39 // These functions are invoked with nodes from second function that were
40 // previously taking arguments as input. The `arg_num` tells which
41 // function argument node was using as an input, e.g:
42 // node(arg_1, other_node, arg_4)
43 // would be called on the first and third input with arg_num equal 1 and 4.
44 // It should set up inputs based on first function inputs or outputs or
45 // second function inputs.
46 using SetInputFn =
47     std::function<string(const StringCollection& first_function_inputs,
48                          const StringCollection& second_function_inputs,
49                          const StringCollection& parent_outputs, int arg_num)>;
50 
51 // This function is invoked with first and second function ret. It is used to
52 // set up returns of fused function.
53 using SetOutputFn =
54     std::function<void(const protobuf::Map<string, string>& parent_ret,
55                        const protobuf::Map<string, string>& second_function_ret,
56                        protobuf::Map<string, string>* fused_ret)>;
57 
58 using SetNodesFn = std::function<void(
59     const FunctionDef& first_function, const FunctionDef& second_function,
60     FunctionDef* fused_function, FunctionDefLibrary* library)>;
61 
62 void MergeNodes(const FunctionDef& first_function,
63                 const FunctionDef& second_function, FunctionDef* fused_function,
64                 FunctionDefLibrary* library);
65 
66 // Returns true if functions can be composed.
67 bool CanCompose(const OpDef& first_signature, const OpDef& second_signature);
68 
69 void ComposeSignature(const OpDef& first_signature,
70                       const OpDef& second_signature, OpDef* fused_signature);
71 
72 string ComposeInput(const StringCollection& first_inputs,
73                     const StringCollection& second_inputs,
74                     const StringCollection& first_outputs, int arg_num);
75 
76 // Sets output to the composition of first and second function:
77 // second_function(first_function(args...)).
78 void ComposeOutput(const protobuf::Map<string, string>& first_ret,
79                    const protobuf::Map<string, string>& second_ret,
80                    protobuf::Map<string, string>* fused_ret);
81 
82 // Set input signature to `first_function_signature` and output signature
83 // to `first_function_signature` + `second_function_signature`
84 void CombineSignature(const OpDef& first_signature,
85                       const OpDef& second_signature, OpDef* fused_signature);
86 
87 // Apart from first function returns, return values from second function as
88 // extra returns like:
89 // return *first_function(...), *second_function(...)
90 void CombineOutput(const protobuf::Map<string, string>& first_ret,
91                    const protobuf::Map<string, string>& second_ret,
92                    protobuf::Map<string, string>* fused_ret);
93 
94 // Returns true if both signatures have the same number of input and output
95 // args.
96 bool HasSameSignature(const OpDef& first_signature,
97                       const OpDef& second_signature);
98 
99 // Check if both signatures are same and copy it from `first_signature`.
100 void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
101                    OpDef* fused_signature);
102 
103 // Take the same input as first function.
104 string SameInput(const StringCollection& first_inputs,
105                  const StringCollection& second_inputs,
106                  const StringCollection& first_outputs, int arg_num);
107 
108 // Create a fused function that computes the short-circuit logical AND of the
109 // result of the first function and the result of the second function.
110 void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
111                            const protobuf::Map<string, string>& second_ret,
112                            protobuf::Map<string, string>* fused_ret);
113 
114 void LazyConjunctionNodes(const FunctionDef& first_function,
115                           const FunctionDef& second_function,
116                           FunctionDef* fused_function,
117                           FunctionDefLibrary* library);
118 
119 // Fuse `first_function` with `second_function`, setting `fused_name_prefix` as
120 // a name prefix.  The nodes from `first_function` are copied unmodified.  All
121 // of the setup functions are called with a copy of second function having names
122 // that are not conflicting with first function.  This means that copied nodes
123 // from  second function can end up having different names.  For explanation of
124 // set up functions see the documentation of the functions types.
125 FunctionDef* FuseFunctions(
126     const FunctionDef& first_function, const FunctionDef& second_function,
127     StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
128     const SetInputFn& set_input, const SetOutputFn& set_output,
129     const SetNodesFn& set_nodes, FunctionDefLibrary* library);
130 
131 }  // namespace fusion_utils
132 }  // namespace grappler
133 }  // namespace tensorflow
134 
135 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_
136