1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Configuration for distributed TPU jobs
17 
18 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
19 
20 #include <unordered_map>
21 
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/core/common_runtime/device_set.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/graph_constructor.h"
26 #include "tensorflow/core/common_runtime/optimization_registry.h"
27 #include "tensorflow/core/framework/node_def_builder.h"
28 #include "tensorflow/core/framework/partial_tensor_shape.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/public/session_options.h"
34 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
35 #include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
36 #include "tensorflow/core/tpu/tpu_init_mode.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 #include "tensorflow/core/util/dump_graph.h"
39 
40 namespace tensorflow {
41 namespace {
42 
// Op type names used by this rewrite. The public, user-visible ops
// (ConfigureDistributedTPU / ShutdownDistributedTPU) are replaced by the
// underscore-prefixed internal implementation ops below.
constexpr char kIdentityOp[] = "Identity";
constexpr char kConfigureOp[] = "ConfigureDistributedTPU";
constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU";
constexpr char kWaitOp[] = "_WaitForDistributedTPU";
constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU";
constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray";
constexpr char kShutdownOp[] = "ShutdownDistributedTPU";
constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU";
constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem";
// Attribute on the public configure op; must be empty (see Run below).
constexpr char kEmbeddingConfigurationAttr[] = "embedding_config";
// Seconds the _WaitForDistributedTPU op waits for the system to stabilize.
constexpr int kDefaultStartupTimeout = 20;
54 
AddConfigurationNode(const string & configuration_device_name,int number_of_hosts,Graph * graph,bool enable_whole_mesh_compilations,Node ** configuration_node)55 Status AddConfigurationNode(const string& configuration_device_name,
56                             int number_of_hosts, Graph* graph,
57                             bool enable_whole_mesh_compilations,
58                             Node** configuration_node) {
59   NodeDef config_def;
60   config_def.set_name(graph->NewName("configure_distributed_tpu"));
61   config_def.set_op(kInternalConfigureOp);
62   config_def.set_device(configuration_device_name);
63   AddNodeAttr("N", number_of_hosts, &config_def);
64   AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
65               &config_def);
66   // TODO(shikharagarwal): Fill with appropriate original node debug info.
67 
68   Status status;
69   *configuration_node = graph->AddNode(config_def, &status);
70   if (!status.ok()) {
71     return status;
72   }
73   (*configuration_node)->set_assigned_device_name(configuration_device_name);
74   return Status::OK();
75 }
76 
AddHostConfigNode(const string & host_device_name,Node * configuration_node,Graph * graph,bool enable_whole_mesh_compilations,Node ** host_configuration_node)77 Status AddHostConfigNode(const string& host_device_name,
78                          Node* configuration_node, Graph* graph,
79                          bool enable_whole_mesh_compilations,
80                          Node** host_configuration_node) {
81   NodeDef host_config_def;
82   host_config_def.set_name(graph->NewName("configure_tpu_host"));
83   host_config_def.set_op(kHostConfigureOp);
84   host_config_def.set_device(host_device_name);
85   AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
86               &host_config_def);
87   MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def);
88 
89   Status status;
90   *host_configuration_node = graph->AddNode(host_config_def, &status);
91   if (!status.ok()) {
92     return status;
93   }
94   (*host_configuration_node)->set_assigned_device_name(host_device_name);
95   graph->AddEdge(configuration_node, 0, *host_configuration_node, 0);
96   return Status::OK();
97 }
98 
AddWaitNode(const string & configuration_device_name,const std::vector<Node * > & host_configuration_nodes,Graph * graph,Node ** wait_node)99 Status AddWaitNode(const string& configuration_device_name,
100                    const std::vector<Node*>& host_configuration_nodes,
101                    Graph* graph, Node** wait_node) {
102   NodeDef wait_def;
103   wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system"));
104   wait_def.set_op(kWaitOp);
105   wait_def.set_device(configuration_device_name);
106   AddNodeAttr("N", static_cast<int32>(host_configuration_nodes.size()),
107               &wait_def);
108   AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def);
109   if (!host_configuration_nodes.empty()) {
110     MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()),
111                    &wait_def);
112   }
113 
114   Status status;
115   *wait_node = graph->AddNode(wait_def, &status);
116   if (!status.ok()) {
117     return status;
118   }
119   (*wait_node)->set_assigned_device_name(configuration_device_name);
120   // Get the inputs from the host configuration nodes.
121   for (int i = 0; i < host_configuration_nodes.size(); ++i) {
122     graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i);
123   }
124   return Status::OK();
125 }
126 
AddGlobalTPUArrayNode(const string & host_device_name,Node * wait_node,Graph * graph,Node ** global_tpu_array_node)127 Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node,
128                              Graph* graph, Node** global_tpu_array_node) {
129   NodeDef global_tpu_array_def;
130   global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array"));
131   global_tpu_array_def.set_op(kGlobalTPUArrayOp);
132   global_tpu_array_def.set_device(host_device_name);
133   MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def);
134 
135   Status status;
136   *global_tpu_array_node = graph->AddNode(global_tpu_array_def, &status);
137   if (!status.ok()) {
138     return status;
139   }
140   (*global_tpu_array_node)->set_assigned_device_name(host_device_name);
141   graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0);
142   return Status::OK();
143 }
144 
AddSynchronizationNode(const NodeDef & sync_node_def,const string & device_name,const std::vector<Node * > & global_array_id_nodes,Node * wait_node,const std::vector<DistributedTPURewriteHelpers::OutputDependency> & output_dependencies,Graph * graph)145 Status AddSynchronizationNode(
146     const NodeDef& sync_node_def, const string& device_name,
147     const std::vector<Node*>& global_array_id_nodes, Node* wait_node,
148     const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
149         output_dependencies,
150     Graph* graph) {
151   NodeDef sync_def;
152   sync_def.set_name(sync_node_def.name());
153   sync_def.set_op(kIdentityOp);
154   sync_def.set_device(device_name);
155   AddNodeAttr("T", DT_STRING, &sync_def);
156   MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);
157 
158   Status status;
159   Node* sync_node = graph->AddNode(sync_def, &status);
160   if (!status.ok()) {
161     return status;
162   }
163   sync_node->set_assigned_device_name(device_name);
164   // Add control edges from the global array id nodes.
165   for (auto node : global_array_id_nodes) {
166     graph->AddControlEdge(node, sync_node);
167   }
168   // Forward the data from the wait node.
169   graph->AddEdge(wait_node, 0, sync_node, 0);
170   // Replace the output edges.
171   for (const DistributedTPURewriteHelpers::OutputDependency& dep :
172        output_dependencies) {
173     if (dep.dst_input == Graph::kControlSlot) {
174       graph->AddControlEdge(sync_node, dep.dst);
175     } else {
176       graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
177     }
178   }
179   return Status::OK();
180 }
181 
182 
AddShutdownNode(const NodeDef & shutdown_node_def,const string & shutdown_device_name,const std::vector<DistributedTPURewriteHelpers::OutputDependency> & output_dependencies,Graph * graph,Node ** shutdown_node)183 Status AddShutdownNode(
184     const NodeDef& shutdown_node_def, const string& shutdown_device_name,
185     const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
186         output_dependencies,
187     Graph* graph, Node** shutdown_node) {
188   NodeDef shutdown_def;
189   shutdown_def.set_name(shutdown_node_def.name());
190   shutdown_def.set_op(kInternalShutdownOp);
191   shutdown_def.set_device(shutdown_device_name);
192   MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def);
193 
194   Status status;
195   *shutdown_node = graph->AddNode(shutdown_def, &status);
196   if (!status.ok()) {
197     return status;
198   }
199   (*shutdown_node)->set_assigned_device_name(shutdown_device_name);
200   // Replace the output control edges.
201   for (const DistributedTPURewriteHelpers::OutputDependency& dep :
202        output_dependencies) {
203     if (dep.dst_input != Graph::kControlSlot) {
204       return errors::Internal("Shutdown node had non-control edge output");
205     }
206     graph->AddControlEdge(*shutdown_node, dep.dst);
207   }
208   return Status::OK();
209 }
210 
AddHostDisconnectNode(const string & host_device_name,const std::vector<Node * > & input_dependencies,Node * post_disconnect_node,int output_index,Graph * graph)211 Status AddHostDisconnectNode(const string& host_device_name,
212                              const std::vector<Node*>& input_dependencies,
213                              Node* post_disconnect_node, int output_index,
214                              Graph* graph) {
215   NodeDef host_disconnect_def;
216   host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host"));
217   host_disconnect_def.set_op(kHostDisconnectOp);
218   host_disconnect_def.set_device(host_device_name);
219   MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()),
220                  &host_disconnect_def);
221 
222   Status status;
223   Node* host_disconnect_node = graph->AddNode(host_disconnect_def, &status);
224   if (!status.ok()) {
225     return status;
226   }
227   host_disconnect_node->set_assigned_device_name(host_device_name);
228   // Replace the input control edges.
229   for (Node* src_node : input_dependencies) {
230     graph->AddControlEdge(src_node, host_disconnect_node);
231   }
232   if (output_index == -1) {
233     graph->AddControlEdge(host_disconnect_node, post_disconnect_node);
234   } else {
235     graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index);
236   }
237   return Status::OK();
238 }
239 
240 }  // namespace
241 
Run(const GraphOptimizationPassOptions & options)242 Status DistributedTPUConfigurationRewritePass::Run(
243     const GraphOptimizationPassOptions& options) {
244   VLOG(1) << "DistributedTPUConfigurationRewritePass::Run";
245 
246   Graph* graph = options.graph->get();
247 
248   if (VLOG_IS_ON(1)) {
249     DumpGraphToFile("distributed_tpu_configuration_before", *graph,
250                     options.flib_def);
251   }
252 
253   // This pass can only run in the session master, which should fill
254   // in the device_set field to the options.
255   TF_RET_CHECK(options.device_set != nullptr);
256 
257   TF_RETURN_IF_ERROR(
258       DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
259           kConfigureOp, graph, *options.device_set,
260           [](const NodeDef& configuration_node_def,
261              const string& configuration_device_name,
262              const std::vector<Device*>& host_devices,
263              const std::vector<Node*>& input_dependencies,
264              const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
265                  output_dependencies,
266              Graph* graph) -> Status {
267             const std::string& embedding_attr_string = GetNodeAttrString(
268                 AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);
269 
270             if (!embedding_attr_string.empty()) {
271               return errors::InvalidArgument("embedding_config must be empty.");
272             }
273 
274             bool is_global_init = false;
275             bool enable_whole_mesh_compilations = false;
276             TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
277                                            "is_global_init", &is_global_init));
278             TryGetNodeAttr(configuration_node_def,
279                            "enable_whole_mesh_compilations",
280                            &enable_whole_mesh_compilations);
281             TF_RETURN_IF_ERROR(SetTPUInitMode(
282                 is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular));
283 
284             bool compilation_failure_closes_chips;
285             TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
286                                            "compilation_failure_closes_chips",
287                                            &compilation_failure_closes_chips));
288             internal::SetTpuCompilationFailureClosesChips(
289                 compilation_failure_closes_chips);
290 
291             // Add the global TPU system configuration node.
292             Node* configuration_node;
293             TF_RETURN_IF_ERROR(AddConfigurationNode(
294                 configuration_device_name, host_devices.size(), graph,
295                 enable_whole_mesh_compilations, &configuration_node));
296 
297             // Add the host disconnect nodes.
298             for (int i = 0; i < host_devices.size(); ++i) {
299               const auto host_device = host_devices[i];
300               TF_RETURN_IF_ERROR(
301                   AddHostDisconnectNode(host_device->name(), input_dependencies,
302                                         configuration_node, i, graph));
303             }
304 
305             // Add the host configuration nodes.
306             std::vector<Node*> host_configuration_nodes;
307             for (const auto host_device : host_devices) {
308               Node* host_configuration_node;
309               TF_RETURN_IF_ERROR(AddHostConfigNode(
310                   host_device->name(), configuration_node, graph,
311                   enable_whole_mesh_compilations, &host_configuration_node));
312               host_configuration_nodes.push_back(host_configuration_node);
313             }
314 
315             // Add the node to wait for the system configuration to
316             // stabilize. Use the name of the original dummy Op in case it was
317             // the target of a Session::Run call.
318             Node* wait_node;
319             TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name,
320                                            host_configuration_nodes, graph,
321                                            &wait_node));
322 
323             // Add the nodes to set the global TPU ids at each host.
324             std::vector<Node*> global_array_id_nodes;
325             for (const auto host_device : host_devices) {
326               Node* global_array_id_node;
327               TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(),
328                                                        wait_node, graph,
329                                                        &global_array_id_node));
330               global_array_id_nodes.push_back(global_array_id_node);
331             }
332 
333             if (host_devices.empty()) {
334               return errors::InvalidArgument("TPU job contains no CPU devices");
335             }
336             TF_RET_CHECK(!host_devices.empty());
337 
338             TF_RETURN_IF_ERROR(AddSynchronizationNode(
339                 configuration_node_def, host_devices.front()->name(),
340                 global_array_id_nodes, wait_node, output_dependencies, graph));
341 
342             return Status::OK();
343           }));
344 
345   if (VLOG_IS_ON(1)) {
346     DumpGraphToFile("distributed_tpu_configuration_after", *graph,
347                     options.flib_def);
348   }
349 
350   VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished";
351   return Status::OK();
352 }
353 
Run(const GraphOptimizationPassOptions & options)354 Status DistributedTPUShutdownRewritePass::Run(
355     const GraphOptimizationPassOptions& options) {
356   VLOG(1) << "DistributedTPUShutdownRewritePass::Run";
357 
358   Graph* graph = options.graph->get();
359 
360   if (VLOG_IS_ON(1)) {
361     DumpGraphToFile("distributed_tpu_shutdown_before", *graph,
362                     options.flib_def);
363   }
364 
365   // This pass can only run in the session master, which should fill
366   // in the device_set field to the options.
367   TF_RET_CHECK(options.device_set != nullptr);
368 
369   TF_RETURN_IF_ERROR(
370       DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
371           kShutdownOp, graph, *options.device_set,
372           [](const NodeDef& shutdown_node_def,
373              const string& shutdown_device_name,
374              const std::vector<Device*>& host_devices,
375              const std::vector<Node*>& input_dependencies,
376              const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
377                  output_dependencies,
378              Graph* graph) -> Status {
379             Node* shutdown_node;
380             TF_RETURN_IF_ERROR(
381                 AddShutdownNode(shutdown_node_def, shutdown_device_name,
382                                 output_dependencies, graph, &shutdown_node));
383 
384             // Add the host disconnect nodes.
385             for (const auto host_device : host_devices) {
386               TF_RETURN_IF_ERROR(
387                   AddHostDisconnectNode(host_device->name(), input_dependencies,
388                                         shutdown_node, -1, graph));
389             }
390 
391             return Status::OK();
392           }));
393 
394   if (VLOG_IS_ON(1)) {
395     DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def);
396   }
397 
398   VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished";
399   return Status::OK();
400 }
401 
402 }  // namespace tensorflow
403