/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/auto_parallel.h"

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace grappler {
const char kAutoParallelPrefix[] = "AutoParallel";
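
// Adds a scalar float Const node holding the replica count; it is used as
// the divisor when averaging gradients across replicas.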
NodeDef* AutoParallel::AddNodeDivConst() {
  NodeDef* node = graph_.add_node();
  node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const"));
  node->set_op("Const");

  AttrValue attr_data_type;
  attr_data_type.set_type(DT_FLOAT);
  node->mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  auto tensor = attr_tensor.mutable_tensor();
  tensor->add_float_val(static_cast<float>(num_replicas_));
  tensor->set_dtype(DT_FLOAT);
  node->mutable_attr()->insert({"value", attr_tensor});
  return node;
}
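
// Adds a RealDiv node computing input_a / input_b in float; used to divide a
// gradient tensor by the replica-count constant.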
NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a,
                                  const string& input_b) {
  NodeDef* node = graph_.add_node();
  node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name));
  node->set_op("RealDiv");
  node->add_input(input_a);
  node->add_input(input_b);
  AttrValue attr_type;
  attr_type.set_type(DT_FLOAT);
  node->mutable_attr()->insert({"T", attr_type});
  return node;
}
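
// Adds a NoOp node with the given name and a control dependency on each node
// in deps.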
NodeDef* AutoParallel::AddNodeControl(const string& name,
                                      const std::set<string>& deps,
                                      GraphDef* graph) {
  NodeDef* node = graph->add_node();
  node->set_name(name);
  node->set_op("NoOp");
  for (const auto& dep : deps) {
    node->add_input(strings::StrCat("^", dep));
  }
  return node;
}
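
// Prepares the rewrite: records all nodes, routes each optimizer's gradient
// input through a div node so gradients are averaged over replicas, and
// partitions the graph into replica_nodes_ (copied once per replica) and
// shared_nodes_ (emitted once).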
Status AutoParallel::Initialize(const GrapplerItem& item) {
  num_gpus_ = GetNumAvailableGPUs();
  LOG(INFO) << "Number of GPUs: " << num_gpus_;
  item_ = &item;
  graph_ = item.graph;
  LOG(INFO) << "Original graph size: " << graph_.node_size();
  if (item.fetch.empty()) {
    return Status(error::INVALID_ARGUMENT, "No fetch nodes provided.");
  }

  if (item.MainVariables().empty()) {
    return Status(error::INVALID_ARGUMENT, "No variables provided.");
  }

  for (const auto& init : item.init_ops) {
    VLOG(1) << "Init node: " << init;
  }

  for (const auto& fetch : item.fetch) {
    VLOG(1) << "Fetch node: " << fetch;
  }

  for (const auto& var : item.MainVariables()) {
    VLOG(2) << "Variable: " << var->name();
  }
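
  // Optimizer update ops that consume a gradient input; nodes running one of
  // these ops mark where gradients are applied to variables.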
  const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
                                                "ApplyProximalGradientDescent",
                                                "ApplyAdadelta",
                                                "ApplyAdagrad",
                                                "ApplyProximalAdagrad",
                                                "ApplyAdagradDA",
                                                "ApplyFtrl",
                                                "ApplyMomentum",
                                                "ApplyAdam",
                                                "ApplyRMSProp",
                                                "ApplyCenteredRMSProp"};
  for (int i = 0; i < graph_.node_size(); i++) {
    all_nodes_.insert(
        std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
    if (apply_gradients_ops.find(graph_.node(i).op()) !=
        apply_gradients_ops.end()) {
      apply_gradients_nodes_.insert(graph_.node(i).name());
      VLOG(2) << "Apply gradients node: " << graph_.node(i).name();
    }
  }
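
  // For each optimizer op, gradient_pos records the index of its gradient
  // input. Route that input through a RealDiv by the replica count, so the
  // summed per-replica gradients become an average.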
  auto div_const_node = AddNodeDivConst();
  all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node));
  std::map<string, int> gradient_pos = {{"ApplyGradientDescent", 2},
                                        {"ApplyProximalGradientDescent", 4},
                                        {"ApplyAdadelta", 6},
                                        {"ApplyAdagrad", 3},
                                        {"ApplyProximalAdagrad", 5},
                                        {"ApplyAdagradDA", 3},
                                        {"ApplyFtrl", 3},
                                        {"ApplyMomentum", 3},
                                        {"ApplyAdam", 9},
                                        {"ApplyRMSProp", 7},
                                        {"ApplyCenteredRMSProp", 8}};
  for (const auto& apply_gradient_node_name : apply_gradients_nodes_) {
    auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op();
    auto apply_gradients_node = all_nodes_[apply_gradient_node_name];

    auto div_node = AddNodeDiv(
        apply_gradient_node_name,
        apply_gradients_node->input(gradient_pos[apply_gradients_op]),
        div_const_node->name());
    all_nodes_.insert(std::make_pair(div_node->name(), div_node));
    *apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) =
        div_node->name();
  }
  LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();
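
  // The training subgraph is the transitive fanin of the fetch nodes.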
  std::vector<const NodeDef*> train_nodes;
  TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, item.fetch, &train_nodes));
  LOG(INFO) << "Number of training nodes: " << train_nodes.size();
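
  // Find the first dequeue op among the training nodes; its transitive fanin
  // is the input pipeline, which stays shared across replicas.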
  const NodeDef* dequeue_node = nullptr;
  for (const auto& train_node : train_nodes) {
    if (IsDequeueOp(*train_node)) {
      dequeue_node = train_node;
      break;
    }
  }

  std::vector<const NodeDef*> input_nodes;
  if (dequeue_node) {
    LOG(INFO) << "Dequeue node: " << dequeue_node->name();
    TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, {dequeue_node->name()},
                                              {}, &input_nodes));
  }
  LOG(INFO) << "Number of input nodes: " << input_nodes.size();
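
  // Collect the nodes that must stay shared across replicas.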
  std::set<string> dont_replicate_nodes;
  for (const auto& variable : item.MainVariables()) {
    dont_replicate_nodes.insert(variable->name());
  }

  for (const auto& init : item.init_ops) {
    dont_replicate_nodes.insert(NodeName(init));
  }

  // Don't replicate the input nodes, except for the dequeue node.
  for (const auto& input_node : input_nodes) {
    if (input_node->name() != dequeue_node->name()) {
      dont_replicate_nodes.insert(input_node->name());
    }
  }

  for (const auto& node : train_nodes) {
    if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) {
      replica_nodes_.insert(node->name());
    }
  }
  LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size();

  for (const auto& node : all_nodes_) {
    if (replica_nodes_.find(node.first) == replica_nodes_.end()) {
      shared_nodes_.insert(node.first);
    }
  }
  LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size();
  return Status::OK();
}
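
// Returns true if the named node is replicated, i.e. not in shared_nodes_.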
bool AutoParallel::NotSharedNode(const string& name) {
  return shared_nodes_.find(name) == shared_nodes_.end();
}
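
// Copies the shared nodes into the output graph once. Inputs that come from
// replicated nodes are redirected to the corresponding copies in replica 0.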
void AutoParallel::AddSharedNodes(GraphDef* graph) {
  string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0);
  for (const auto& node : shared_nodes_) {
    auto new_node = graph->add_node();
    *new_node = *all_nodes_[node];
    for (int i = 0; i < new_node->input_size(); i++) {
      if (NotSharedNode(NodeName(new_node->input(i)))) {
        string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
        *new_node->mutable_input(i) = new_name;
      }
    }
  }
}
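
// Copies every replica node into the output graph under a per-replica name
// prefix, assigns it to a GPU round-robin when GPUs are available, and
// rewires its inputs from other replicated nodes to the same replica's
// copies.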
void AutoParallel::AddOneReplica(GraphDef* graph, int number) {
  string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number);
  for (const auto& node : replica_nodes_) {
    auto new_node = graph->add_node();
    *new_node = *all_nodes_[node];
    if (NotSharedNode(new_node->name())) {
      new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix));
      if (num_gpus_ > 0) {
        new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_));
      }
      for (int i = 0; i < new_node->input_size(); i++) {
        if (NotSharedNode(NodeName(new_node->input(i)))) {
          string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
          *new_node->mutable_input(i) = new_name;
        }
      }
    }
  }
}
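
// Assembles the final graph: the shared nodes, num_replicas_ copies of the
// replica nodes, and a control structure that turns each original fetch name
// into a NoOp gated on all replicas' fetches.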
void AutoParallel::BuildGraph(GraphDef* graph) {
  AddSharedNodes(graph);
  for (int i = 0; i < num_replicas_; i++) {
    AddOneReplica(graph, i);
  }
  std::set<string> fetches;
  for (size_t i = 0; i < item_->fetch.size(); i++) {
    for (int j = 0; j < num_replicas_; j++) {
      string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
      string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
      fetches.insert(fetch);
    }
  }
  string name_control =
      strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch");
  auto control = AddNodeControl(name_control, fetches, graph);

  for (const auto& fetch : item_->fetch) {
    AddNodeControl(fetch, {control->name()}, graph);
  }
  *graph->mutable_library() = item_->graph.library();
  *graph->mutable_versions() = item_->graph.versions();
  LOG(INFO) << "Parallelized graph size: " << graph->node_size();
}
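
// Grappler entry point: analyzes the item and writes the parallelized graph
// to *output. A minimal usage sketch (the meta-optimizer normally drives this
// pass; the two-replica count here is illustrative):
//
//   AutoParallel optimizer(2);
//   GraphDef output;
//   TF_RETURN_IF_ERROR(optimizer.Optimize(nullptr, item, &output));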
Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item,
                              GraphDef* output) {
  TF_RETURN_IF_ERROR(Initialize(item));
  BuildGraph(output);
  return Status::OK();
}

void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item,
                            const GraphDef& optimize_output, double result) {
  // TODO(yaozhang): Add feedback.
}

}  // end namespace grappler
}  // end namespace tensorflow