/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Configuration for distributed TPU jobs
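//
// This file implements two graph rewrite passes:
//  * DistributedTPUConfigurationRewritePass replaces each user-visible
//    ConfigureDistributedTPU node with the internal subgraph that brings up
//    the distributed TPU system (per-host initialization, a wait barrier, and
//    per-host propagation of the global TPU topology).
//  * DistributedTPUShutdownRewritePass replaces each ShutdownDistributedTPU
//    node with the internal shutdown op plus per-host disconnect ops.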

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"

#include <unordered_map>

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/tpu_init_mode.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace {

constexpr char kIdentityOp[] = "Identity";
constexpr char kConfigureOp[] = "ConfigureDistributedTPU";
constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU";
constexpr char kWaitOp[] = "_WaitForDistributedTPU";
constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU";
constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray";
constexpr char kShutdownOp[] = "ShutdownDistributedTPU";
constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU";
constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem";
constexpr char kEmbeddingConfigurationAttr[] = "embedding_config";
constexpr int kDefaultStartupTimeout = 20;

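// Creates the internal _ConfigureDistributedTPU node on
// `configuration_device_name`. The "N" attr records the number of hosts in
// the TPU job; the per-host inputs are wired up by the caller.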
Status AddConfigurationNode(const string& configuration_device_name,
                            int number_of_hosts, Graph* graph,
                            bool enable_whole_mesh_compilations,
                            Node** configuration_node) {
  NodeDef config_def;
  config_def.set_name(graph->NewName("configure_distributed_tpu"));
  config_def.set_op(kInternalConfigureOp);
  config_def.set_device(configuration_device_name);
  AddNodeAttr("N", number_of_hosts, &config_def);
  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
              &config_def);
  // TODO(shikharagarwal): Fill with appropriate original node debug info.

  Status status;
  *configuration_node = graph->AddNode(config_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*configuration_node)->set_assigned_device_name(configuration_device_name);
  return Status::OK();
}

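// Creates a per-host _InitializeHostForDistributedTPU node on
// `host_device_name`, fed by output 0 of the _ConfigureDistributedTPU node.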
Status AddHostConfigNode(const string& host_device_name,
                         Node* configuration_node, Graph* graph,
                         bool enable_whole_mesh_compilations,
                         Node** host_configuration_node) {
  NodeDef host_config_def;
  host_config_def.set_name(graph->NewName("configure_tpu_host"));
  host_config_def.set_op(kHostConfigureOp);
  host_config_def.set_device(host_device_name);
  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
              &host_config_def);
  MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def);

  Status status;
  *host_configuration_node = graph->AddNode(host_config_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*host_configuration_node)->set_assigned_device_name(host_device_name);
  graph->AddEdge(configuration_node, 0, *host_configuration_node, 0);
  return Status::OK();
}

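// Creates a _WaitForDistributedTPU node on the configuration device that
// takes one input from each host configuration node and waits, up to
// kDefaultStartupTimeout seconds, for the whole system to come up.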
Status AddWaitNode(const string& configuration_device_name,
                   const std::vector<Node*>& host_configuration_nodes,
                   Graph* graph, Node** wait_node) {
  NodeDef wait_def;
  wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system"));
  wait_def.set_op(kWaitOp);
  wait_def.set_device(configuration_device_name);
  AddNodeAttr("N", static_cast<int32>(host_configuration_nodes.size()),
              &wait_def);
  AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def);
  if (!host_configuration_nodes.empty()) {
    MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()),
                   &wait_def);
  }

  Status status;
  *wait_node = graph->AddNode(wait_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*wait_node)->set_assigned_device_name(configuration_device_name);
  // Get the inputs from the host configuration nodes.
  for (int i = 0; i < host_configuration_nodes.size(); ++i) {
    graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i);
  }
  return Status::OK();
}

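// Creates a per-host _SetGlobalTPUArray node on `host_device_name`, fed by
// output 0 of the wait node, which sets the global TPU ids at that host.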
Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node,
                             Graph* graph, Node** global_tpu_array_node) {
  NodeDef global_tpu_array_def;
  global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array"));
  global_tpu_array_def.set_op(kGlobalTPUArrayOp);
  global_tpu_array_def.set_device(host_device_name);
  MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def);

  Status status;
  *global_tpu_array_node = graph->AddNode(global_tpu_array_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*global_tpu_array_node)->set_assigned_device_name(host_device_name);
  graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0);
  return Status::OK();
}

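// Creates an Identity node that stands in for the original
// ConfigureDistributedTPU node: it reuses that node's name, forwards output 0
// of the wait node, takes control edges from all _SetGlobalTPUArray nodes,
// and inherits the original node's output edges.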
Status AddSynchronizationNode(
    const NodeDef& sync_node_def, const string& device_name,
    const std::vector<Node*>& global_array_id_nodes, Node* wait_node,
    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
        output_dependencies,
    Graph* graph) {
  NodeDef sync_def;
  sync_def.set_name(sync_node_def.name());
  sync_def.set_op(kIdentityOp);
  sync_def.set_device(device_name);
  AddNodeAttr("T", DT_STRING, &sync_def);
  MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);

  Status status;
  Node* sync_node = graph->AddNode(sync_def, &status);
  if (!status.ok()) {
    return status;
  }
  sync_node->set_assigned_device_name(device_name);
  // Add control edges from the global array id nodes.
  for (auto node : global_array_id_nodes) {
    graph->AddControlEdge(node, sync_node);
  }
  // Forward the data from the wait node.
  graph->AddEdge(wait_node, 0, sync_node, 0);
  // Replace the output edges.
  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
       output_dependencies) {
    if (dep.dst_input == Graph::kControlSlot) {
      graph->AddControlEdge(sync_node, dep.dst);
    } else {
      graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
    }
  }
  return Status::OK();
}

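// Creates the internal _ShutdownDistributedTPU node in place of the original
// ShutdownDistributedTPU node, reusing its name and forwarding its output
// control edges. Returns an error if the original node had data-output edges.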
Status AddShutdownNode(
    const NodeDef& shutdown_node_def, const string& shutdown_device_name,
    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
        output_dependencies,
    Graph* graph, Node** shutdown_node) {
  NodeDef shutdown_def;
  shutdown_def.set_name(shutdown_node_def.name());
  shutdown_def.set_op(kInternalShutdownOp);
  shutdown_def.set_device(shutdown_device_name);
  MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def);

  Status status;
  *shutdown_node = graph->AddNode(shutdown_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*shutdown_node)->set_assigned_device_name(shutdown_device_name);
  // Replace the output control edges.
  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
       output_dependencies) {
    if (dep.dst_input != Graph::kControlSlot) {
      return errors::Internal("Shutdown node had non-control edge output");
    }
    graph->AddControlEdge(*shutdown_node, dep.dst);
  }
  return Status::OK();
}

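// Creates a per-host _DisconnectHostFromDistributedTPUSystem node on
// `host_device_name`, with control edges from `input_dependencies`. Its
// output feeds input `output_index` of `post_disconnect_node`, or a control
// edge if `output_index` is -1.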
Status AddHostDisconnectNode(const string& host_device_name,
                             const std::vector<Node*>& input_dependencies,
                             Node* post_disconnect_node, int output_index,
                             Graph* graph) {
  NodeDef host_disconnect_def;
  host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host"));
  host_disconnect_def.set_op(kHostDisconnectOp);
  host_disconnect_def.set_device(host_device_name);
  MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()),
                 &host_disconnect_def);

  Status status;
  Node* host_disconnect_node = graph->AddNode(host_disconnect_def, &status);
  if (!status.ok()) {
    return status;
  }
  host_disconnect_node->set_assigned_device_name(host_device_name);
  // Replace the input control edges.
  for (Node* src_node : input_dependencies) {
    graph->AddControlEdge(src_node, host_disconnect_node);
  }
  if (output_index == -1) {
    graph->AddControlEdge(host_disconnect_node, post_disconnect_node);
  } else {
    graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index);
  }
  return Status::OK();
}

}  // namespace

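// Replaces each ConfigureDistributedTPU node with the initialization
// subgraph. Schematically, with the per-host data/control fan-out elided:
//
//   _DisconnectHostFromDistributedTPUSystem   (one per host)
//     -> _ConfigureDistributedTPU             (configuration device)
//     -> _InitializeHostForDistributedTPU     (one per host)
//     -> _WaitForDistributedTPU               (configuration device)
//     -> _SetGlobalTPUArray                   (one per host)
//     -> Identity                             (keeps the original node's name
//                                              and output edges)
//
// The pass also checks that embedding_config is empty, sets the TPU init mode
// from the is_global_init attr, and plumbs the
// compilation_failure_closes_chips attr into the compile-op options.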
Status DistributedTPUConfigurationRewritePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run";

  Graph* graph = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_configuration_before", *graph,
                    options.flib_def);
  }

  // This pass can only run in the session master, which should fill
  // in the device_set field of the options.
  TF_RET_CHECK(options.device_set != nullptr);

  TF_RETURN_IF_ERROR(
      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
          kConfigureOp, graph, *options.device_set,
          [](const NodeDef& configuration_node_def,
             const string& configuration_device_name,
             const std::vector<Device*>& host_devices,
             const std::vector<Node*>& input_dependencies,
             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
                 output_dependencies,
             Graph* graph) -> Status {
            const std::string& embedding_attr_string = GetNodeAttrString(
                AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);

            if (!embedding_attr_string.empty()) {
              return errors::InvalidArgument("embedding_config must be empty.");
            }

            bool is_global_init = false;
            bool enable_whole_mesh_compilations = false;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           "is_global_init", &is_global_init));
            TryGetNodeAttr(configuration_node_def,
                           "enable_whole_mesh_compilations",
                           &enable_whole_mesh_compilations);
            TF_RETURN_IF_ERROR(SetTPUInitMode(
                is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular));

            bool compilation_failure_closes_chips;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           "compilation_failure_closes_chips",
                                           &compilation_failure_closes_chips));
            internal::SetTpuCompilationFailureClosesChips(
                compilation_failure_closes_chips);

            // Add the global TPU system configuration node.
            Node* configuration_node;
            TF_RETURN_IF_ERROR(AddConfigurationNode(
                configuration_device_name, host_devices.size(), graph,
                enable_whole_mesh_compilations, &configuration_node));

            // Add the host disconnect nodes.
            for (int i = 0; i < host_devices.size(); ++i) {
              const auto host_device = host_devices[i];
              TF_RETURN_IF_ERROR(
                  AddHostDisconnectNode(host_device->name(), input_dependencies,
                                        configuration_node, i, graph));
            }

            // Add the host configuration nodes.
            std::vector<Node*> host_configuration_nodes;
            for (const auto host_device : host_devices) {
              Node* host_configuration_node;
              TF_RETURN_IF_ERROR(AddHostConfigNode(
                  host_device->name(), configuration_node, graph,
                  enable_whole_mesh_compilations, &host_configuration_node));
              host_configuration_nodes.push_back(host_configuration_node);
            }

            // Add the node to wait for the system configuration to
            // stabilize. Use the name of the original dummy Op in case it was
            // the target of a Session::Run call.
            Node* wait_node;
            TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name,
                                           host_configuration_nodes, graph,
                                           &wait_node));

            // Add the nodes to set the global TPU ids at each host.
            std::vector<Node*> global_array_id_nodes;
            for (const auto host_device : host_devices) {
              Node* global_array_id_node;
              TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(),
                                                       wait_node, graph,
                                                       &global_array_id_node));
              global_array_id_nodes.push_back(global_array_id_node);
            }

            if (host_devices.empty()) {
              return errors::InvalidArgument("TPU job contains no CPU devices");
            }
            TF_RET_CHECK(!host_devices.empty());

            TF_RETURN_IF_ERROR(AddSynchronizationNode(
                configuration_node_def, host_devices.front()->name(),
                global_array_id_nodes, wait_node, output_dependencies, graph));

            return Status::OK();
          }));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_configuration_after", *graph,
                    options.flib_def);
  }

  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished";
  return Status::OK();
}

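// Replaces each ShutdownDistributedTPU node with the internal
// _ShutdownDistributedTPU node plus one
// _DisconnectHostFromDistributedTPUSystem node per host; each disconnect node
// becomes a control-edge predecessor of the shutdown node.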
Status DistributedTPUShutdownRewritePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "DistributedTPUShutdownRewritePass::Run";

  Graph* graph = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_shutdown_before", *graph,
                    options.flib_def);
  }

  // This pass can only run in the session master, which should fill
  // in the device_set field of the options.
  TF_RET_CHECK(options.device_set != nullptr);

  TF_RETURN_IF_ERROR(
      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
          kShutdownOp, graph, *options.device_set,
          [](const NodeDef& shutdown_node_def,
             const string& shutdown_device_name,
             const std::vector<Device*>& host_devices,
             const std::vector<Node*>& input_dependencies,
             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
                 output_dependencies,
             Graph* graph) -> Status {
            Node* shutdown_node;
            TF_RETURN_IF_ERROR(
                AddShutdownNode(shutdown_node_def, shutdown_device_name,
                                output_dependencies, graph, &shutdown_node));

            // Add the host disconnect nodes.
            for (const auto host_device : host_devices) {
              TF_RETURN_IF_ERROR(
                  AddHostDisconnectNode(host_device->name(), input_dependencies,
                                        shutdown_node, -1, graph));
            }

            return Status::OK();
          }));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def);
  }

  VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished";
  return Status::OK();
}

}  // namespace tensorflow