/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/auto_parallel.h"

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace grappler {
const char kAutoParallelPrefix[] = "AutoParallel";

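// Adds a float Const node holding the replica count; it serves as the
// divisor when averaging gradients across replicas.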
NodeDef* AutoParallel::AddNodeDivConst() {
  NodeDef* node = graph_.add_node();
  node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const"));
  node->set_op("Const");

  AttrValue attr_data_type;
  attr_data_type.set_type(DT_FLOAT);
  node->mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  auto tensor = attr_tensor.mutable_tensor();
  tensor->add_float_val(static_cast<float>(num_replicas_));
  tensor->set_dtype(DT_FLOAT);
  node->mutable_attr()->insert({"value", attr_tensor});
  return node;
}

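// Adds a RealDiv node computing input_a / input_b; used to scale each
// gradient tensor down by the replica count.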
NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a,
                                  const string& input_b) {
  NodeDef* node = graph_.add_node();
  node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name));
  node->set_op("RealDiv");
  node->add_input(input_a);
  node->add_input(input_b);
  AttrValue attr_type;
  attr_type.set_type(DT_FLOAT);
  node->mutable_attr()->insert({"T", attr_type});
  return node;
}

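// Adds a NoOp node named `name` that carries a control dependency on every
// node in `deps`.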
NodeDef* AutoParallel::AddNodeControl(const string& name,
                                      const std::set<string>& deps,
                                      GraphDef* graph) {
  NodeDef* node = graph->add_node();
  node->set_name(name);
  node->set_op("NoOp");
  for (const auto& dep : deps) {
    node->add_input(strings::StrCat("^", dep));
  }
  return node;
}

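// Validates the item, rewires every apply-gradients node to consume an
// averaged gradient, and splits the graph into shared and replicated nodes.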
Status AutoParallel::Initialize(const GrapplerItem& item) {
  num_gpus_ = GetNumAvailableGPUs();
  LOG(INFO) << "Number of GPUs: " << num_gpus_;
  item_ = &item;
  graph_ = item.graph;
  LOG(INFO) << "Original graph size: " << graph_.node_size();
  if (item.fetch.empty()) {
    return Status(error::INVALID_ARGUMENT, "No fetch nodes provided.");
  }

  if (item.MainVariables().empty()) {
    return Status(error::INVALID_ARGUMENT, "No variables provided.");
  }

  for (const auto& init : item.init_ops) {
    VLOG(1) << "Init node: " << init;
  }

  for (const auto& fetch : item.fetch) {
    VLOG(1) << "Fetch node: " << fetch;
  }

  for (const auto& var : item.MainVariables()) {
    VLOG(2) << "Variable: " << var->name();
  }

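  // Training ops that consume a gradient input; each such node found in the
  // graph is recorded so its gradient can be averaged across replicas.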
  const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
                                                "ApplyProximalGradientDescent",
                                                "ApplyAdadelta",
                                                "ApplyAdagrad",
                                                "ApplyProximalAdagrad",
                                                "ApplyAdagradDA",
                                                "ApplyFtrl",
                                                "ApplyMomentum",
                                                "ApplyAdam",
                                                "ApplyRMSProp",
                                                "ApplyCenteredRMSProp"};
  for (int i = 0; i < graph_.node_size(); i++) {
    all_nodes_.insert(
        std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
    if (apply_gradients_ops.find(graph_.node(i).op()) !=
        apply_gradients_ops.end()) {
      apply_gradients_nodes_.insert(graph_.node(i).name());
      VLOG(2) << "Apply gradients node: " << graph_.node(i).name();
    }
  }

  auto div_const_node = AddNodeDivConst();
  all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node));
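  // Input index of the gradient tensor for each supported apply-gradients op,
  // e.g. ApplyGradientDescent(var, alpha, delta) takes its gradient at index 2.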
  std::map<string, int> gradient_pos = {{"ApplyGradientDescent", 2},
                                        {"ApplyProximalGradientDescent", 4},
                                        {"ApplyAdadelta", 6},
                                        {"ApplyAdagrad", 3},
                                        {"ApplyProximalAdagrad", 5},
                                        {"ApplyAdagradDA", 3},
                                        {"ApplyFtrl", 3},
                                        {"ApplyMomentum", 3},
                                        {"ApplyAdam", 9},
                                        {"ApplyRMSProp", 7},
                                        {"ApplyCenteredRMSProp", 8}};
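  // Rewire each apply-gradients node: its gradient input is routed through a
  // RealDiv node so that grad / num_replicas_ is applied instead of grad.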
  for (const auto& apply_gradient_node_name : apply_gradients_nodes_) {
    auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op();
    auto apply_gradients_node = all_nodes_[apply_gradient_node_name];

    auto div_node = AddNodeDiv(
        apply_gradient_node_name,
        apply_gradients_node->input(gradient_pos[apply_gradients_op]),
        div_const_node->name());
    all_nodes_.insert(std::make_pair(div_node->name(), div_node));
    *apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) =
        div_node->name();
  }
  LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();

  std::vector<const NodeDef*> train_nodes;
  TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, item.fetch, &train_nodes));
  LOG(INFO) << "Number of training nodes: " << train_nodes.size();

  // Initialized to nullptr so the check below is well defined when the graph
  // contains no dequeue op.
  const NodeDef* dequeue_node = nullptr;
  for (const auto& train_node : train_nodes) {
    if (IsDequeueOp(*train_node)) {
      dequeue_node = train_node;
      break;
    }
  }

  std::vector<const NodeDef*> input_nodes;
  if (dequeue_node) {
    LOG(INFO) << "Dequeue node: " << dequeue_node->name();
    TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, {dequeue_node->name()},
                                              {}, &input_nodes));
  }
  LOG(INFO) << "Number of input nodes: " << input_nodes.size();

  std::set<string> dont_replicate_nodes;
  for (const auto& variable : item.MainVariables()) {
    dont_replicate_nodes.insert(variable->name());
  }

  for (const auto& init : item.init_ops) {
    dont_replicate_nodes.insert(NodeName(init));
  }

  // Keep the input pipeline out of the replicas: every input node stays
  // shared, except the dequeue node itself.
  for (const auto& input_node : input_nodes) {
    if (input_node->name() != dequeue_node->name()) {
      dont_replicate_nodes.insert(input_node->name());
    }
  }

  for (const auto& node : train_nodes) {
    if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) {
      replica_nodes_.insert(node->name());
    }
  }
  LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size();

  for (const auto& node : all_nodes_) {
    if (replica_nodes_.find(node.first) == replica_nodes_.end()) {
      shared_nodes_.insert(node.first);
    }
  }
  LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size();
  return Status::OK();
}

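// Returns true if the named node is replicated rather than shared.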
bool AutoParallel::NotSharedNode(const string& name) {
  return shared_nodes_.find(name) == shared_nodes_.end();
}

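// Emits one copy of every shared node, redirecting any input that comes from
// a replicated node to that node's replica-0 copy.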
void AutoParallel::AddSharedNodes(GraphDef* graph) {
  string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0);
  for (const auto& node : shared_nodes_) {
    auto new_node = graph->add_node();
    *new_node = *all_nodes_[node];
    for (int i = 0; i < new_node->input_size(); i++) {
      if (NotSharedNode(NodeName(new_node->input(i)))) {
        string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
        *new_node->mutable_input(i) = new_name;
      }
    }
  }
}

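// Emits one copy of every replicated node under a per-replica name prefix,
// placing the copies round-robin across the available GPUs.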
void AutoParallel::AddOneReplica(GraphDef* graph, int number) {
  string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number);
  for (const auto& node : replica_nodes_) {
    auto new_node = graph->add_node();
    *new_node = *all_nodes_[node];
    if (NotSharedNode(new_node->name())) {
      new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix));
      if (num_gpus_ > 0) {
        new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_));
      }
      for (int i = 0; i < new_node->input_size(); i++) {
        if (NotSharedNode(NodeName(new_node->input(i)))) {
          string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
          *new_node->mutable_input(i) = new_name;
        }
      }
    }
  }
}

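// Assembles the output graph: the shared nodes, num_replicas_ copies of the
// replicated nodes, and a NoOp per original fetch name whose control input
// fans in all replicated fetches, so callers can keep fetching the old names.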
void AutoParallel::BuildGraph(GraphDef* graph) {
  AddSharedNodes(graph);
  for (int i = 0; i < num_replicas_; i++) {
    AddOneReplica(graph, i);
  }
  std::set<string> fetches;
  for (size_t i = 0; i < item_->fetch.size(); i++) {
    for (int j = 0; j < num_replicas_; j++) {
      string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
      string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
      fetches.insert(fetch);
    }
  }
  string name_control =
      strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch");
  auto control = AddNodeControl(name_control, fetches, graph);

  for (const auto& fetch : item_->fetch) {
    AddNodeControl(fetch, {control->name()}, graph);
  }
  *graph->mutable_library() = item_->graph.library();
  *graph->mutable_versions() = item_->graph.versions();
  LOG(INFO) << "Parallelized graph size: " << graph->node_size();
}

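// Optimizer entry point: builds the internal state from the item, then writes
// the parallelized graph into output.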
Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item,
                              GraphDef* output) {
  TF_RETURN_IF_ERROR(Initialize(item));
  BuildGraph(output);
  return Status::OK();
}

void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item,
                            const GraphDef& optimize_output, double result) {
  // TODO(yaozhang): Add feedback.
}

}  // end namespace grappler
}  // end namespace tensorflow