/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// function alias
constexpr auto AddOutputTensorShapeTypeByTensorShapeMap =
    &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap;

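// Returns a max-heap of (score, index, label) tuples built from the given
// arrays, so that the highest-scoring entries can be popped first.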
/* static */ std::priority_queue<std::tuple<float, int, string>>
GraphTransferUtils::GetTopNFloatResults(const float* const data,
                                        const string* const labels,
                                        const int element_count) {
  CHECK(data != nullptr);
  CHECK(labels != nullptr);
  std::priority_queue<std::tuple<float, int, string>> queue;
  for (int i = 0; i < element_count; ++i) {
    queue.emplace(data[i], i, labels[i]);
  }
  return queue;
}

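// Logs the top_n highest-scoring entries, one line per rank, formatted as
// "rank: index, label, score".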
/* static */ void GraphTransferUtils::DumpTopNFloatResults(
    const float* const data, const string* const labels,
    const int element_count, const int top_n) {
  std::priority_queue<std::tuple<float, int, string>> queue =
      GetTopNFloatResults(data, labels, element_count);
  LOG(INFO) << "=== Dump ranking ===";
  for (int i = 0; i < top_n; ++i) {
    const std::tuple<float, int, string>& entry = queue.top();
    LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry)
              << ", " << std::get<0>(entry);
    queue.pop();
  }
}

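// Assembles the RemoteFusedGraphExecuteInfo proto for the Hexagon executor,
// embedding the graph itself plus the data type and shape of every graph
// input and output tensor.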
/* static */ RemoteFusedGraphExecuteInfo
GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& outputs,
    const RemoteFusedGraphExecuteUtils::TensorShapeMap& tensor_shape_map) {
  RemoteFusedGraphExecuteInfo execute_info;
  execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");

  // Copy the graph to be executed remotely.
  *execute_info.mutable_remote_graph() = graph_def;

  // Record the name, data type and shape of each graph input.
  for (const std::pair<string, Tensor>& input : inputs) {
    execute_info.add_graph_input_node_name(input.first);
    RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
        *execute_info.add_default_graph_input_tensor_shape();
    tensor_shape_type.set_dtype(input.second.dtype());
    TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
    for (const int64 dim : input.second.shape().dim_sizes()) {
      tensor_shape_proto.add_dim()->set_size(dim);
    }
  }

  // Record the name, data type and shape of each graph output, looked up
  // from the inferred tensor_shape_map.
  for (const string& output_name : outputs) {
    const std::pair<DataType, TensorShape>* tensor_shape_type =
        RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
                                                         output_name);
    CHECK_NOTNULL(tensor_shape_type);
    execute_info.add_graph_output_node_name(output_name);
    RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type_proto =
        *execute_info.add_default_graph_output_tensor_shape();
    tensor_shape_type_proto.set_dtype(tensor_shape_type->first);
    TensorShapeProto& tensor_shape_proto =
        *tensor_shape_type_proto.mutable_shape();
    for (const int64 dim : tensor_shape_type->second.dim_sizes()) {
      tensor_shape_proto.add_dim()->set_size(dim);
    }
  }

  return execute_info;
}

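// Builds a GraphDef in which the original graph is folded into a single
// RemoteFusedGraphExecute node fed by one Placeholder per input; the original
// graph travels inside that node's serialized execute-info attribute.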
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const string& remote_graph_execute_name,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& outputs, GraphDef* original_def) {
  // Infer the data type and shape of every node's outputs by a dry run, then
  // attach that information to each NodeDef.
  RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
  Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
      *original_def, inputs, true /* initialize_by_zero */, &tensor_shape_map);
  CHECK(status.ok());
  for (NodeDef& node_def : *original_def->mutable_node()) {
    TF_CHECK_OK(
        AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
  }

  // Create a Placeholder node for each graph input.
  Scope root = Scope::NewRootScope();
  std::vector<Output> output_list;
  DataTypeVector input_types;
  for (const std::pair<string, Tensor>& input_node_info : inputs) {
    const Scope& scope = root.WithOpName(input_node_info.first);
    Node* ret;
    const auto unique_name = scope.GetUniqueNameForOp("Placeholder");
    auto builder = NodeBuilder(unique_name, "Placeholder")
                       .Attr("dtype", input_node_info.second.dtype())
                       .Attr("shape", input_node_info.second.shape());
    scope.UpdateBuilder(&builder);
    scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
    TF_CHECK_OK(scope.status());
    output_list.emplace_back(Output(ret, 0));
    input_types.push_back(input_node_info.second.dtype());
  }

  const RemoteFusedGraphExecuteInfo execute_info =
      BuildRemoteFusedGraphExecuteInfo(*original_def, inputs, outputs,
                                       tensor_shape_map);

  // Look up the data type of each output node; every output's type and shape
  // must have been inferred by the dry run above.
  DataTypeVector output_types;
  for (const string& output_node_name : outputs) {
    const std::pair<DataType, TensorShape>* tst =
        RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
                                                         output_node_name);
    CHECK_NE(tst, nullptr);
    output_types.push_back(tst->first);
  }

  // Create the RemoteFusedGraphExecute node that consumes all input
  // placeholders and carries the serialized execute info.
  const Scope& scope = root.WithOpName(remote_graph_execute_name);
  CHECK(scope.ok());
  auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
  Node* node;
  const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");

  auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
                     .Input(node_out_list)
                     .Attr("Tinputs", input_types)
                     .Attr("Toutputs", output_types)
                     .Attr("serialized_remote_fused_graph_execute_info",
                           StringPiece(execute_info.SerializeAsString()));
  CHECK(scope.ok());
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &node));
  CHECK(scope.ok()) << scope.status();

  GraphDef fused_graph_def;
  TF_CHECK_OK(root.ToGraphDef(&fused_graph_def));
  return fused_graph_def;
}

}  // namespace tensorflow