/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"

#include "tensorflow/core/common_runtime/scoped_allocator.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

// Like TF_RETURN_IF_ERROR, but also logs a WARNING.
#define LOG_WARNING_AND_RETURN_IF_ERROR(...)            \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(WARNING) << "error: " << _status;             \
      return _status;                                   \
    }                                                   \
  } while (0)

namespace tensorflow {
namespace grappler {

namespace {

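// Attr marking a node's output for allocation from a ScopedAllocator
// instance.  Its value is a list of (output_slot, scope_id) pairs; see
// ScopedAllocatorOptimizer::ExtendNodeAttr and ConstructScopedAllocatorNode
// below.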
const char kScopedAllocatorAttrName[] = "_scoped_allocator";

// Node names often have some kind of name_scope prefix, with slashes,
// and a _nn numeric suffix.  Returns true if the main part of the node_name
// matches op_name, i.e. it looks from the name like this node is
// of that op type.
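// For example, HasOpName("while/body/Add_2", "Add") is true, while
// HasOpName("while/body/AddN_2", "Add") is false.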
bool HasOpName(const string& node_name, const string& op_name) {
  size_t begin = node_name.rfind('/');
  if (begin == string::npos) {
    begin = 0;
  } else {
    ++begin;
  }
  size_t end = node_name.rfind('_');
  if (end != string::npos) {
    size_t p = end + 1;
    while (p < node_name.size()) {
      if (!isdigit(node_name[p])) {
        end = node_name.size();
        break;
      }
      ++p;
    }
  } else {
    end = node_name.size();
  }
  return node_name.substr(begin, end - begin) == op_name;
}

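// Returns in *dtype the data type of output `output_index` from
// `output_props`, or an error if the index is out of range.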
Status GetOutputDataType(
    const std::vector<OpInfo::TensorProperties>& output_props,
    int output_index, DataType* dtype) {
  int output_props_size = output_props.size();
  if (output_index >= output_props_size) {
    return errors::Internal("Invalid output index ", output_index,
                            " size of output_props ", output_props.size());
  }
  *dtype = output_props[output_index].dtype();
  return Status::OK();
}

// After shape inference has been done each op should be annotated
// with its output shape(s).  This function iterates over a collection
// of ops that are a potential application of a ScopedAllocator.  It
// verifies whether they all have the same output type and if so
// gathers a vector of their output shapes.  It returns an error if
// any of the ops doesn't have type or shape data, or if it has more
// than one output, or if the output type of all ops is not the same.
// If it returns OK then *type and *shapes should be correctly populated.
Status CheckTypesAndGetShapes(const GraphProperties& graph_properties,
                              const std::vector<NodeDef*>& ops, DataType* type,
                              std::vector<TensorShape>* shapes) {
  VLOG(1) << "CheckTypesAndGetShapes";
  *type = DT_INVALID;
  for (NodeDef* n : ops) {
    AttrSlice n_attrs = AttrSlice(*n);
    DataType dtype;
    // Check that op has an explicit data type attr "T".
    LOG_WARNING_AND_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
    VLOG(2) << "op " << n->name() << " has type " << dtype << " shapes.size() "
            << shapes->size();
    if (!graph_properties.HasOutputProperties(n->name())) {
      LOG(ERROR) << "Node " << n->DebugString() << " lacks output shape.";
      return errors::Aborted("Node ", n->name(), " lacks output shape.");
    }
    const std::vector<OpInfo::TensorProperties>& prop_list =
        graph_properties.GetOutputProperties(n->name());
    if (prop_list.size() != 1) {
      return errors::Aborted("Node ", n->name(),
                             " does not have exactly one output as expected "
                             "by ScopedAllocatorOptimizer");
    }
    const OpInfo::TensorProperties& props = prop_list[0];
    if (shapes->empty()) {
      *type = props.dtype();
    } else if (*type != props.dtype()) {
      return errors::Aborted("Group ops don't all have same type");
    }
    if (*type != dtype) {
      return errors::Internal(
          "Type mismatch: type in op attr = ", DataTypeString(dtype),
          ", type in output props = ", DataTypeString(*type));
    }
    if (!TensorShape::IsValid(props.shape()) || props.shape().unknown_rank()) {
      // TensorShape::IsValid may return true even if unknown_rank is true,
      // i.e. the number of dimensions is unknown.  But for
      // ScopedAllocatorOptimizer we need to know the shape fully.
      return errors::Aborted("Complete shape not known for ", n->name());
    }
    VLOG(2) << "Adding shape " << props.shape().DebugString();
    shapes->push_back(TensorShape(props.shape()));
  }
  return Status::OK();
}

// Describes an existing input edge in the graph.
struct InputDesc {
  NodeDef* from_node_def;
  int output_slot;
  NodeDef* to_node_def;
  InputDesc(NodeDef* f, int os, NodeDef* t)
      : from_node_def(f), output_slot(os), to_node_def(t) {}
};

// Remove the NodeDef nd from node_map and graph.  It must be the case
// that nd no longer has any input or output edges, though that is not
// checked.
void RemoveNode(NodeDef* nd, GraphDef* graph, NodeMap* node_map) {
  node_map->RemoveNode(nd->name());
  // TODO(tucker): The efficiency of this routine is poor.
  // Change to accumulate and do a bulk removal, maybe refactoring
  // some code from dependency_optimizer.
  protobuf::RepeatedPtrField<NodeDef>* nodes = graph->mutable_node();
  for (int i = 0; i < nodes->size(); ++i) {
    if (nd->name() == (*nodes)[i].name()) {
      nodes->SwapElements(i, nodes->size() - 1);
      nodes->RemoveLast();
      return;
    }
  }
  LOG(FATAL) << "Failed to find node " << nd->name() << " in graph";
}

// Removes a named edge from between two nodes.
Status RemoveEdge(const string& input_edge_name, const string& from_node_name,
                  NodeDef* to_node, NodeMap* node_map) {
  protobuf::RepeatedPtrField<string>* inputs = to_node->mutable_input();
  int edge_index = -1;
  for (edge_index = 0; edge_index < inputs->size(); ++edge_index) {
    VLOG(2) << " consider edge " << (*inputs)[edge_index];
    if ((*inputs)[edge_index] == input_edge_name) {
      break;
    }
  }
  if (edge_index >= inputs->size()) {
    return errors::Internal("Could not find input name ", input_edge_name,
                            " at node ", to_node->name());
  }
  if (node_map) {
    node_map->RemoveOutput(from_node_name, to_node->name());
  }
  inputs->DeleteSubrange(edge_index, 1);
  return Status::OK();
}

// In certain cases, we would like to insert an identity op between `input`
// and `op` to ensure correctness.  We currently do this in 2 cases: when
// `input` is an Exit node, or when `input` is already marked for allocation
// with another scoped allocator op.
//
// If `input` is an Exit node, we add an identity to avoid the case when Exit
// has inputs from different frames.
//
// If `input` is in `sa_opti->repeated_outputs()`, this means that it will
// potentially be used by multiple scope ids.  Since there can be only one
// scope id per output, we insert an identity between the input and op.  This
// ensures that the identity becomes the new input to op, and this identity
// can be marked with a new scope id different from `input`.
//
// If the graph is rewritten, this function will perform the following change:
//
//    input                    input
//      |                        |
//      op                    Identity
//                               |
//                               op
//
// This function returns the input to op in `new_input`, and the output index
// from input to op in `new_output_index`.
// `edge_name` gives the name of the edge from `input` to `op`, and
// `output_index` is the output index of this edge on `input`.
Status MaybeRewriteInput(ScopedAllocatorOptimizer* sa_opti,
                         int64 invocation_count, GraphDef* graph,
                         NodeMap* node_map, const DataType& dtype,
                         NodeDef* input, const string& edge_name,
                         int output_index, NodeDef* op, NodeDef** new_input,
                         int* new_output_index, bool* rewrite) {
  *rewrite = IsConstant(*input) || IsExit(*input) ||
             (sa_opti->repeated_outputs().find(edge_name) !=
              sa_opti->repeated_outputs().end());
  if (!(*rewrite)) {
    *new_input = input;
    *new_output_index = output_index;
    return Status::OK();
  }

  // Create new Identity op.
  int unique_id;
  LOG_WARNING_AND_RETURN_IF_ERROR(sa_opti->NewIdentityId(&unique_id));
  string identity_name = strings::StrCat("scoped_allocator_identity_",
                                         unique_id, "_", invocation_count);
  NodeDefBuilder identity_builder(identity_name, "Identity");
  identity_builder.Device(op->device());
  identity_builder.Attr("T", dtype);
  // Connect output at `output_index` from `input` to `identity`.
  identity_builder.Input(
      NodeDefBuilder::NodeOut(input->name(), output_index, dtype));
  NodeDef* identity = graph->add_node();
  LOG_WARNING_AND_RETURN_IF_ERROR(identity_builder.Finalize(identity));
  node_map->AddNode(identity_name, identity);
  node_map->AddOutput(input->name(), identity_name);
  node_map->UpdateInput(op->name(), input->name(), identity_name);
  *op->mutable_input(0) = identity_name;
  *new_input = identity;
  *new_output_index = 0;
  VLOG(1) << "Rewrite input " << edge_name << " op " << op->name()
          << " old output index " << output_index << " with identity "
          << identity_name << " new output index 0";
  return Status::OK();
}

// Populates *inputs with all of the non-control inputs of ops.
// Returns an error if it fails to find exactly one input for each op,
// or if some input is not of type dtype.
Status GetInputs(ScopedAllocatorOptimizer* sa_opti, int64 invocation_count,
                 GraphDef* graph, const GraphProperties& graph_properties,
                 NodeMap* node_map, const std::vector<NodeDef*>& ops,
                 DataType dtype, std::vector<InputDesc>* inputs) {
  VLOG(1) << "GetInputs";
  for (NodeDef* n : ops) {
    NodeDef* inode = nullptr;
    int output_index = 0;
    DataType inode_dtype = DT_INVALID;
    VLOG(2) << "for node " << n->name();
    for (const auto& input_name : n->input()) {
      if (!IsControlInput(input_name)) {
        if (inode) {
          return errors::Internal("Found more than one input for node ",
                                  n->name());
        }
        ParseNodeName(input_name, &output_index);
        inode = node_map->GetNode(input_name);
        if (inode == nullptr) {
          return errors::Internal("Did not find node ", input_name);
        }
        VLOG(2) << "inode " << inode->DebugString() << " output_index "
                << output_index;
        bool rewrite;
        LOG_WARNING_AND_RETURN_IF_ERROR(MaybeRewriteInput(
            sa_opti, invocation_count, graph, node_map, dtype, inode,
            input_name, output_index, n, &inode, &output_index, &rewrite));
        // If `inode` was rewritten, don't try to get output properties from
        // the input node below.
        if (rewrite) {
          inode_dtype = dtype;
        }
        VLOG(2) << "inode after rewrite " << inode->DebugString()
                << " output_index " << output_index;
      }
    }
    if (inode_dtype == DT_INVALID) {
      if (!graph_properties.HasOutputProperties(inode->name())) {
        return errors::Internal("Input node ", inode->name(),
                                " does not have output properties");
      }
      const auto& inode_output_props =
          graph_properties.GetOutputProperties(inode->name());
      LOG_WARNING_AND_RETURN_IF_ERROR(
          GetOutputDataType(inode_output_props, output_index, &inode_dtype));
    }
    if (inode_dtype != dtype) {
      return errors::Aborted("ScopedAllocatorOptimizer expected input type ",
                             dtype, " but found ", inode_dtype);
    }
    inputs->emplace_back(inode, output_index, n);
  }
  return Status::OK();
}

// Return non-control inputs of `op` in `inputs`.
Status GetDataInputs(GraphDef* graph, NodeMap* node_map, NodeDef* op,
                     std::vector<InputDesc>* inputs) {
  VLOG(2) << "GetDataInputs for node " << op->name();
  NodeDef* inode = nullptr;
  int output_index = 0;
  for (const auto& input_name : op->input()) {
    if (IsControlInput(input_name)) {
      continue;
    }
    ParseNodeName(input_name, &output_index);
    inode = node_map->GetNode(input_name);
    if (inode == nullptr) {
      return errors::Internal("Did not find node ", input_name);
    }
    VLOG(2) << "inode " << inode->DebugString() << " output_index "
            << output_index;
    inputs->emplace_back(inode, output_index, op);
  }
  return Status::OK();
}

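// Logs `graph` to VLOG at `log_level`, one line at a time.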
void DumpGraphToVLOG(const GraphDef& graph, int log_level) {
  if (VLOG_IS_ON(log_level)) {
    // VLOG may truncate lines so we print line by line.
    for (const auto& line : str_util::Split(graph.DebugString(), "\n\r")) {
      VLOG(log_level) << line;
    }
  }
}

}  // namespace

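// Appends `values` to the list(int) attr `name` on `node_def`, creating the
// attr if it does not already exist.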
void ScopedAllocatorOptimizer::ExtendNodeAttr(StringPiece name,
                                              const std::vector<int32>& values,
                                              NodeDef* node_def) {
  if (HasNodeAttr(*node_def, name)) {
    VLOG(2) << "extending";
    AttrValue* existing = &(*node_def->mutable_attr())[string(name)];
    for (int32 i : values) {
      existing->mutable_list()->add_i(i);
    }
  } else {
    VLOG(2) << "setting new attr value";
    AddNodeAttr(name, values, node_def);
  }
}

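// Rewriter that coalesces a group of identical, logically parallel ops (by
// default CollectiveReduce; see the ScopedAllocatorOptimizer constructor
// below) so that their inputs are sub-allocated from a single ScopedAllocator
// backing tensor.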
class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
 public:
  ~UnaryElementwiseRewriter() override {}

  // Return non-OK if any input is an op that does not use the
  // AllocatorAttributes set by executor to allocate its output.
  Status CheckUsesAllocatorAttributes(const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      if (IsConstant(*nd.from_node_def)) {
        return errors::Aborted(
            "Abandoning ScopedAllocatorOptimizer because input ",
            nd.from_node_def->name(),
            " is a Const op which does not use AllocatorAttributes");
      }
    }
    return Status::OK();
  }

  // Return non-OK if any input is already committed to a ScopedAllocator.
  //
  // We insert an identity to ensure that inputs are not committed to
  // different scope ids in `MaybeRewriteInput`, so this function is basically
  // a sanity check.
  Status CheckExistingScopedAllocator(const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      VLOG(2) << "get attrs for " << nd.from_node_def->name();
      AttrSlice n_attrs = AttrSlice(*nd.from_node_def);
      std::vector<int32> scope_ids;
      Status ss = GetNodeAttr(n_attrs, kScopedAllocatorAttrName, &scope_ids);
      // Check that both output name and output slot match.  It is okay to
      // have different outputs of the input committed to different scope ids.
      if (ss.ok() && scope_ids[0] == nd.output_slot) {
        LOG(INFO) << "Abandoning ScopedAllocatorOptimizer because input "
                  << nd.from_node_def->name() << " output " << scope_ids[0]
                  << " is already assigned to scope_id " << scope_ids[1];
        return errors::Aborted(
            "Abandoning ScopedAllocatorOptimizer because input ",
            nd.from_node_def->name(), " output ", scope_ids[0], " is already ",
            "assigned to scope_id ", scope_ids[1]);
      }
    }
    return Status::OK();
  }

  // Return non-OK if any input is a member of op_set.
  Status CheckInternalDataDependency(const std::set<string>& op_set,
                                     const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      if (op_set.find(nd.from_node_def->name()) != op_set.end()) {
        if (nd.output_slot != tensorflow::Graph::kControlSlot) {
          return errors::Aborted("Data edge exists between ",
                                 nd.from_node_def->name(),
                                 " and another node in the set");
        }
      }
    }
    return Status::OK();
  }

  // Remove all control edges between members of ops.
  void ClearInternalControlInputs(const std::set<string>& op_set,
                                  const std::vector<NodeDef*>& ops,
                                  NodeMap* node_map) {
    for (NodeDef* n : ops) {
      for (const auto& input_name : n->input()) {
        if (IsControlInput(input_name)) {
          int position = 0;
          string input_node_name = ParseNodeName(input_name, &position);
          CHECK_EQ(position, -1);
          if (op_set.find(input_node_name) != op_set.end()) {
            // This is an internal control edge.  Remove it.
            VLOG(1) << "Remove control output from " << input_node_name
                    << " via edge " << input_name << " to " << n->name();
            TF_CHECK_OK(RemoveEdge(input_name, input_node_name, n, node_map));
          }
        }
      }
    }
  }

  // Examine the input set of an op set, gathering their shapes and types
  // and checking whether there are any considerations that prevent use
  // of a single ScopedAllocator for all of those inputs.
  Status AnalyzeInputs(ScopedAllocatorOptimizer* sa_opti,
                       int64 invocation_count, GraphDef* graph,
                       NodeMap* node_map, const std::vector<NodeDef*>& ops,
                       const std::set<string>& op_instance_names,
                       string* device_name, DataType* dtype,
                       std::vector<TensorShape>* input_shapes,
                       std::vector<InputDesc>* inputs, TensorShape* sa_shape) {
    CHECK(graph_properties_);
    LOG_WARNING_AND_RETURN_IF_ERROR(
        CheckTypesAndGetShapes(*graph_properties_, ops, dtype, input_shapes));
    LOG_WARNING_AND_RETURN_IF_ERROR(
        GetInputs(sa_opti, invocation_count, graph, *graph_properties_,
                  sa_opti->node_map(), ops, *dtype, inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(CheckUsesAllocatorAttributes(*inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(CheckExistingScopedAllocator(*inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(
        CheckInternalDataDependency(op_instance_names, *inputs));
    ClearInternalControlInputs(op_instance_names, ops, node_map);
    *device_name = ops[0]->device();
    CHECK(!device_name->empty());
    CHECK(!input_shapes->empty());
    CHECK_EQ(0, Allocator::kAllocatorAlignment % DataTypeSize(*dtype))
        << "ScopedAllocatorOptimizer only applies to types that evenly "
        << "divide kAllocatorAlignment";
    std::vector<ScopedAllocator::Field> sa_fields;
    // Calculate the field embedding boundaries and thereby the
    // required size of the backing tensor.
    int64 num_bytes = ScopedAllocatorMgr::PopulateFields(
        0 /*scope_id*/, *input_shapes, *dtype, &sa_fields);
    int64 num_elts = num_bytes / DataTypeSize(*dtype);
    VLOG(2) << "num_bytes " << num_bytes << " num_elts=" << num_elts;
    *sa_shape = TensorShape({num_elts});
    return Status::OK();
  }

  // Returns the set of all nodes that are transitively reachable via data or
  // control edges starting at `source_nodes`.  Stop at the boundary of a
  // frame.
  Status TransitiveFanoutWithinFrame(
      GraphDef* graph, NodeMap* node_map,
      const std::vector<const NodeDef*>& source_nodes,
      absl::flat_hash_set<const NodeDef*>* fanout) {
    std::deque<const NodeDef*> queue(source_nodes.begin(), source_nodes.end());
    absl::flat_hash_set<const NodeDef*> visited;
    while (!queue.empty()) {
      const NodeDef* node = queue.front();
      queue.pop_front();
      if (!visited.insert(node).second) {
        continue;
      }
      fanout->insert(node);
      for (const NodeDef* output : node_map->GetOutputs(node->name())) {
        if (!ModifiesFrameInfo(*output)) {
          queue.push_back(output);
        }
        VLOG(2) << "TransitiveFanout parent: " << node->name()
                << " child: " << output->name() << " of type " << output->op();
      }
    }

    return Status::OK();
  }

  // Build the ScopedAllocator node that will be assigned to allocate
  // the output tensors of the input node set.
  Status ConstructScopedAllocatorNode(
      ScopedAllocatorOptimizer* sa_opti, GraphDef* graph, NodeMap* node_map,
      const std::vector<NodeDef*>& ops, const string& device_name,
      DataType dtype, int sa_id, const string& sa_name,
      const std::vector<TensorShape>& input_shapes,
      const std::vector<InputDesc>& inputs, const TensorShape& sa_shape) {
    VLOG(2) << "ConstructScopedAllocatorNode " << sa_name;
    NodeDefBuilder sa_builder(sa_name, "_ScopedAllocator");
    sa_builder.Device(device_name);
    sa_builder.Attr("sa_name", sa_name);
    sa_builder.Attr("T", dtype);
    sa_builder.Attr("id", sa_id);
    sa_builder.Attr("shapes", input_shapes);
    sa_builder.Attr("shape", sa_shape);
    sa_builder.Attr("expected_call_count", static_cast<int64>(ops.size()));
    NodeDef* sa_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sa_builder.Finalize(sa_node));
    node_map->AddNode(sa_name, sa_node);

    std::vector<const NodeDef*> fanout_sources;
    fanout_sources.reserve(inputs.size());
    for (const auto& input : inputs) {
      fanout_sources.push_back(input.from_node_def);
    }
    absl::flat_hash_set<const NodeDef*> fanout;
    TF_RETURN_IF_ERROR(
        TransitiveFanoutWithinFrame(graph, node_map, fanout_sources, &fanout));

    // Add control edges from the ScopedAllocatorOp to all of the
    // input nodes and mark them for allocation from backing tensor.
    for (int i = 0, end = inputs.size(); i < end; ++i) {
      auto& nd = inputs[i];
      if (IsArg(*nd.from_node_def)) {
        return errors::Aborted(
            "ScopedAllocatorOptimizer does not work well when the op inputs "
            "are _Arg ops; skipping this optimizer for this function");
      }
      VLOG(2) << "To input " << i << ": " << nd.from_node_def->name()
              << " add control input "
              << "^" << sa_name;
      nd.from_node_def->add_input(strings::StrCat("^", sa_name));
      // This attribute says: allocate output_slot from
      // ScopedAllocator instance sa_id + 1 + i.
      ScopedAllocatorOptimizer::ExtendNodeAttr(kScopedAllocatorAttrName,
                                               {nd.output_slot, sa_id + 1 + i},
                                               nd.from_node_def);
      node_map->AddOutput(sa_name, nd.from_node_def->name());
    }

    // We add control edges in order to (1) delay execution of the
    // ScopedAllocatorOp until just before first use in order to conserve
    // memory, and (2) ensure correctness in the presence of control flow
    // related ops.
    bool added_delay_edge = false;
    for (auto& nd : inputs) {
      std::vector<InputDesc> inputs_to_first;
      LOG_WARNING_AND_RETURN_IF_ERROR(GetDataInputs(
          graph, sa_opti->node_map(), nd.from_node_def, &inputs_to_first));
      for (int i = 0, end = inputs_to_first.size(); i < end; ++i) {
        if (fanout.find(inputs_to_first[i].from_node_def) != fanout.end()) {
          VLOG(2) << "Found node " << inputs_to_first[i].from_node_def->name()
                  << " in the fanout of " << sa_name;
          continue;
        }
        sa_node->add_input(
            strings::StrCat("^", inputs_to_first[i].from_node_def->name()));
        node_map->AddOutput(inputs_to_first[i].from_node_def->name(), sa_name);
        added_delay_edge = true;
        VLOG(2) << "Adding control dependency from "
                << inputs_to_first[i].from_node_def->name() << " to "
                << sa_node->name();
        break;
      }
      if (added_delay_edge) {
        break;
      }
    }

    if (!added_delay_edge) {
      LOG(WARNING)
          << "Found no node from which a control edge can be added to "
             "scoped allocator node.  If you run into issues with graphs "
             "that contain control flow, turn off the "
             "ScopedAllocatorOptimizer and file a bug.";
    }

    return Status::OK();
  }

  Status BuildSAConcatNode(GraphDef* graph, NodeMap* node_map,
                           const std::vector<NodeDef*>& ops,
                           const std::set<string>& op_instance_names,
                           const string& device_name, DataType dtype,
                           int sa_id, const string& sa_name,
                           const string& sac_name, const TensorShape& sa_shape,
                           std::vector<NodeDefBuilder::NodeOut>* sac_inputs) {
    VLOG(2) << "BuildSAConcatNode " << sac_name;
    // Control input: edge name -> source node name.
    absl::flat_hash_map<string, string> sac_ctl_inputs;
    for (int i = 0, end = ops.size(); i < end; ++i) {
      NodeDef* old_op = ops[i];
      for (const string& old_op_input : old_op->input()) {
        int position = 0;
        string input_name = ParseNodeName(old_op_input, &position);
        if (position == -1) {
          // A control input: drop if from another member of the op set.
          if (op_instance_names.find(old_op_input) ==
              op_instance_names.end()) {
            sac_ctl_inputs.emplace(old_op_input, input_name);
          }
        } else {
          // TODO(tucker): remove redundant check.
          // A data input: illegal if from another member of the op set.
          if (op_instance_names.find(old_op_input) !=
              op_instance_names.end()) {
            LOG(ERROR) << "Data edge between " << old_op_input << " and "
                       << old_op->name() << " cannot build ScopedAllocator.";
            return errors::Aborted("Data edge between ", old_op_input, " and ",
                                   old_op->name(),
                                   " cannot build ScopedAllocator.");
          }
          sac_inputs->push_back(
              NodeDefBuilder::NodeOut(old_op_input, 0, dtype));
        }
        VLOG(3) << "from op " << i << ": " << old_op->name()
                << " sac_inputs append " << old_op_input;
      }
    }
    NodeDefBuilder sac_builder(sac_name, "_ScopedAllocatorConcat");
    VLOG(2) << "New sac_name " << sac_name << " shape "
            << sa_shape.DebugString();
    sac_builder.Device(device_name);
    sac_builder.Attr("sa_name", sa_name);
    sac_builder.Attr("id", sa_id);
    sac_builder.Attr("T", dtype);
    sac_builder.Attr("shape", sa_shape);
    sac_builder.Attr("N", static_cast<int>(sac_inputs->size()));
    sac_builder.Input(NodeDefBuilder::NodeOut(sa_name, 0, dtype));
    sac_builder.Input(*sac_inputs);
    NodeDef* sac_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sac_builder.Finalize(sac_node));
    node_map->AddNode(sac_name, sac_node);
    node_map->AddOutput(sa_name, sac_name);

    // Attach the old control inputs to the new sac node.
    for (const auto& ctl_input : sac_ctl_inputs) {
      const auto& ctl_edge = ctl_input.first;
      const auto& input_name = ctl_input.second;
      sac_node->add_input(ctl_edge);
      node_map->AddOutput(input_name, sac_node->name());
    }
    return Status::OK();
  }

  Status BuildReplacementOp(GraphDef* graph, NodeMap* node_map,
                            const std::vector<NodeDef*>& ops,
                            const string& device_name, DataType dtype,
                            const string& op_name, const string& sac_name,
                            const string& sa_op_name) {
    VLOG(2) << "BuildReplacementOp " << sa_op_name;
    NodeDefBuilder op_builder(sa_op_name, op_name);
    op_builder.Device(device_name);

    // Transfer the Node Attr from the first replaced Node to the new
    // Node.  TODO(tucker): In principle we should verify that
    // the Attr are consistent and compatible across all op instances.
    // Unfortunately that will probably require op-specific tests, so
    // punt on that for the time being.
    AttrSlice first_slice(*ops[0]);
    for (auto& it : first_slice) {
      op_builder.Attr(it.first, it.second);
    }
    op_builder.Attr("_forward_input", {0, 0});
    op_builder.Input(sac_name, 0, dtype);
    NodeDef* sa_op_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(op_builder.Finalize(sa_op_node));
    node_map->AddNode(sa_op_name, sa_op_node);
    node_map->AddOutput(sac_name, sa_op_name);
    return Status::OK();
  }

  Status BuildSplitNode(GraphDef* graph, NodeMap* node_map,
                        const std::vector<NodeDef*>& ops,
                        const std::vector<TensorShape>& input_shapes,
                        const std::vector<NodeDefBuilder::NodeOut>& sac_inputs,
                        const string& device_name, DataType dtype,
                        const string& op_name, int sa_id,
                        const string& sas_name, const string& sa_name,
                        const string& sa_op_name) {
    VLOG(2) << "new ScopedAllocatorSplit " << sas_name;
    NodeDefBuilder sas_builder(sas_name, "_ScopedAllocatorSplit");
    sas_builder.Device(device_name);
    sas_builder.Attr("sa_name", sa_name);
    sas_builder.Attr("id", sa_id);
    sas_builder.Attr("T", dtype);
    sas_builder.Attr("shapes", input_shapes);
    std::vector<NodeDefBuilder::NodeOut> sas_inputs = sac_inputs;
    sas_builder.Attr("N", static_cast<int>(sas_inputs.size()));
    sas_builder.Input(NodeDefBuilder::NodeOut(sa_op_name, 0, dtype));
    sas_builder.Input(sas_inputs);
    NodeDef* sas_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sas_builder.Finalize(sas_node));
    node_map->AddNode(sas_name, sas_node);
    node_map->AddOutput(sa_op_name, sas_name);
    for (const auto& input : sas_inputs) {
      node_map->AddOutput(input.node, sas_name);
    }
    return Status::OK();
  }

  // After the new ScopedAllocator and its corresponding Concat and
  // Split nodes have been built, and a new single Op instance
  // constructed, rewire the graph: Remove input edges to the old Op
  // nodes and replace the old Op node outputs with the corresponding
  // ScopedAllocatorSplit node outputs.  After this the old Op nodes
  // should no longer have any input or output edges and they can be
  // removed from the graph.
  Status RewireSubgraph(GraphDef* graph, NodeMap* node_map,
                        const std::vector<NodeDef*>& ops,
                        const std::set<string>& op_instance_names,
                        const string& op_name, const string& sas_name) {
    VLOG(2) << "RewireSubgraph";
    for (int op_idx = 0, idx_limit = ops.size(); op_idx < idx_limit;
         ++op_idx) {
      NodeDef* old_op = ops[op_idx];
      // Copy the output node set since we'll be modifying the version
      // maintained by NodeMap in the loop.
      auto output_nodes = node_map->GetOutputs(old_op->name());
      VLOG(3) << "old_op " << old_op->name() << " had " << output_nodes.size()
              << " outputs.  Moving them to the ScopedAllocatorSplit node.";
      if (VLOG_IS_ON(2)) {
        for (NodeDef* n : output_nodes) {
          VLOG(3) << "  output: " << n->name();
        }
      }
      for (NodeDef* n : output_nodes) {
        VLOG(3) << "really checking old output " << n->name()
                << " for corresponding input.";
        if (op_instance_names.find(n->name()) != op_instance_names.end()) {
          // If this output node is a member of the ops set, it must have
          // been an internal control edge, so drop it.
          VLOG(3) << "Dropping control output from " << old_op->name()
                  << " to " << n->name();
          // However, we may already have dropped it at the clear() below,
          // so if we fail to find it, that's okay.
          Status ignore = RemoveEdge(strings::StrCat("^", old_op->name()),
                                     old_op->name(), n, node_map);
          continue;
        }
        bool found = false;
        VLOG(3) << "about to iterate over " << n->input_size() << " inputs";
        for (int i = 0; i < n->input_size(); ++i) {
          VLOG(3) << "input " << n->input(i);
          int position = 0;
          string input_node = ParseNodeName(n->input(i), &position);
          if (input_node == old_op->name()) {
            found = true;
            VLOG(3) << "match pos=" << position;
            if (position == -1) {
              // It was a control edge.
              *n->mutable_input(i) = strings::StrCat("^", sas_name);
            } else {
              CHECK_EQ(0, position)
                  << "name " << n->input(i) << " pos " << position;
              *n->mutable_input(i) = strings::StrCat(sas_name, ":", op_idx);
            }
            node_map->UpdateInput(n->name(), old_op->name(), sas_name);
            VLOG(3) << "breaking on success";
            break;
          } else {
            VLOG(3) << "other input " << n->input(i);
          }
        }
        // In general it's required that we found the output node's old
        // input and replaced it, but one exception is if the output node
        // is of the same type being coalesced and the edge is a control
        // input.  In that case it probably got eliminated in an earlier
        // pass.
        VLOG(3) << "before HasOp";
        if (!HasOpName(n->name(), op_name)) {
          CHECK(found) << "old_op " << old_op->name()
                       << " could not find input edge on " << n->DebugString()
                       << " to replace."
                       << " " << op_name << " not in " << n->name();
        }
        VLOG(3) << "bottom of for output_nodes";
      }
      VLOG(3) << "Clearing all inputs of " << old_op->name();
      node_map->RemoveInputs(old_op->name());
      old_op->clear_input();
      node_map->RemoveOutputs(old_op->name());
      VLOG(3) << "after clear: " << old_op->DebugString();
      // old_op should now be dead, with no remaining input or output edges,
      // so remove it from the graph.
      RemoveNode(old_op, graph, node_map);
    }
    return Status::OK();
  }

  // Given a collection of instances of op_name, presumed to be
  // logically parallel and operating on tensors of the same type,
  // replace them by a single instance.  First find the upstream Ops
  // generating their inputs.  Create a new ScopedAllocatorOp that
  // outputs a single backing_tensor pre-arranged for sub-allocation
  // of all of those input tensors.  Then insert a new
  // ScopedAllocatorConcatOp below the upstream Ops to make explicit
  // the materialization of a concatenation of their outputs.  Put the
  // new op_name instance below the new concat op and follow with a
  // ScopedAllocatorSplitOp that restores the correct shape outputs
  // for the consumers of the old op_name instances.
  //
  // There must be no non-control edges between Nodes in 'ops'.
  // Control edges among these nodes will be dropped.
  Status Rewrite(ScopedAllocatorOptimizer* sa_opti, int64 invocation_count,
                 GraphDef* graph, const string& op_name,
                 const std::vector<NodeDef*>& ops, bool* applied) override {
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Rewrite";
      string op_names;
      for (auto& nd : ops) {
        strings::StrAppend(&op_names, nd->name(), ", ");
      }
      VLOG(1) << "UnaryElementwiseRewriter::Rewrite " << op_name
              << " to: " << op_names;
    }
    NodeMap* node_map = sa_opti->node_map();

    // Make a set of the node names for faster membership testing.
    std::set<string> op_instance_names;
    for (auto& nd : ops) {
      op_instance_names.insert(nd->name());
      VLOG(2) << "op_instance_name " << nd->name();
    }
    DataType dtype;
    std::vector<TensorShape> input_shapes;
    std::vector<InputDesc> inputs;
    TensorShape sa_shape;
    string device_name;

    TF_RETURN_IF_ERROR(AnalyzeInputs(
        sa_opti, invocation_count, graph, node_map, ops, op_instance_names,
        &device_name, &dtype, &input_shapes, &inputs, &sa_shape));

    int sa_id = sa_opti->NewScopedAllocatorId(input_shapes.size());
    string sa_name =
        strings::StrCat("scoped_allocator_", sa_id, "_", invocation_count);
    TF_RETURN_IF_ERROR(ConstructScopedAllocatorNode(
        sa_opti, graph, node_map, ops, device_name, dtype, sa_id, sa_name,
        input_shapes, inputs, sa_shape));

    // Build a ScopedAllocatorConcat below all of the input nodes.
    std::vector<NodeDefBuilder::NodeOut> sac_inputs;
    string sac_name = strings::StrCat("scoped_allocator_concat_", sa_id, "_",
                                      invocation_count);
    TF_RETURN_IF_ERROR(BuildSAConcatNode(
        graph, node_map, ops, op_instance_names, device_name, dtype, sa_id,
        sa_name, sac_name, sa_shape, &sac_inputs));

    // Construct a new instance of the parallel op and insert it
    // immediately below the new ScopedAllocatorConcat.
    string sa_op_name = strings::StrCat(sa_name, "_", op_name);
    TF_RETURN_IF_ERROR(BuildReplacementOp(graph, node_map, ops, device_name,
                                          dtype, op_name, sac_name,
                                          sa_op_name));

    // Build a ScopedAllocatorSplit below the new Op.
    string sas_name = strings::StrCat("scoped_allocator_split_", sa_id, "_",
                                      invocation_count);
    TF_RETURN_IF_ERROR(BuildSplitNode(graph, node_map, ops, input_shapes,
                                      sac_inputs, device_name, dtype, op_name,
                                      sa_id, sas_name, sa_name, sa_op_name));

    // Rewire the graph.
    TF_RETURN_IF_ERROR(RewireSubgraph(graph, node_map, ops, op_instance_names,
                                      op_name, sas_name));

    *applied = true;
    return Status::OK();
  }
};

ScopedAllocatorOptimizer::ScopedAllocatorOptimizer(
    RewriterConfig::Toggle opt_level, const ScopedAllocatorOptions& opts)
    : opt_level_(opt_level) {
  VLOG(1) << "ScopedAllocatorOptimizer::ScopedAllocatorOptimizer";
  Rewriter* r = new UnaryElementwiseRewriter();
  to_delete_.push_back(r);
  if (opts.enable_op_size() == 0) {
    // Ops handled by default:
    for (const auto& op_name : {"CollectiveReduce"}) {
      op_name_set_.insert(op_name);
      rewriters_[op_name] = r;
    }
  } else {
    for (const auto& op_name : opts.enable_op()) {
      op_name_set_.insert(op_name);
      rewriters_[op_name] = r;
    }
  }
}

Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
                                          const GrapplerItem& item,
                                          GraphDef* optimized_graph) {
  VLOG(3) << "Input graph:";
  DumpGraphToVLOG(item.graph, /*log_level=*/3);

  // Nodes that cannot be removed from the graph without damaging correctness,
  // typically fetch nodes.
  nodes_to_preserve_ = item.NodesToPreserve();

  GraphProperties graph_properties(item);
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  LOG_WARNING_AND_RETURN_IF_ERROR(graph_properties.InferStatically(
      assume_valid_feeds, /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/false));
  *optimized_graph = item.graph;
  node_map_ = absl::make_unique<NodeMap>(optimized_graph);

  LOG_WARNING_AND_RETURN_IF_ERROR(ScopedAllocatorOptimizer::ProcessGraphDef(
      optimized_graph, graph_properties));

  VLOG(1) << "ScopedAllocatorOptimizer::Optimize() done";
  VLOG(3) << "Optimized graph:";
  DumpGraphToVLOG(*optimized_graph, /*log_level=*/3);
  return Status::OK();
}

ScopedAllocatorOptimizer::Rewriter* ScopedAllocatorOptimizer::GetRewriter(
    const string& op_name) {
  auto it = rewriters_.find(op_name);
  if (it != rewriters_.end()) {
    return it->second;
  }
  return nullptr;
}

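// Returns a fresh scope id, reserving num_fields + 1 consecutive ids: one for
// the ScopedAllocator backing tensor and one for each sub-allocated field
// (see ConstructScopedAllocatorNode above).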
int ScopedAllocatorOptimizer::NewScopedAllocatorId(int num_fields) {
  CHECK_GT(num_fields, 0);
  int id = next_sa_id_;
  next_sa_id_ += (num_fields + 1);
  CHECK_GT(next_sa_id_, 0);
  return id;
}

Status ScopedAllocatorOptimizer::NewIdentityId(int* id) {
  *id = next_identity_id_++;
  if (next_identity_id_ < 0) {
    return errors::Aborted("NewIdentityId overflow");
  }
  return Status::OK();
}

ScopedAllocatorOptimizer::~ScopedAllocatorOptimizer() {
  for (auto ptr : to_delete_) {
    delete ptr;
  }
}

void ScopedAllocatorOptimizer::FindOpOccurrences(GraphDef* graph,
                                                 const OpNameSet& op_names,
                                                 GraphOpOccurrences* occs) {
  VLOG(1) << "FindOpOccurrences ";
  for (const auto& it : op_names) {
    VLOG(1) << "search target " << it;
  }
  for (int ni = 0; ni < graph->node_size(); ++ni) {
    NodeDef* node = graph->mutable_node(ni);
    const string& op_name = node->op();
    if (op_names.find(op_name) != op_names.end()) {
      VLOG(1) << "found " << op_name << " on dev " << node->device();
      (*occs)[node->device()][op_name].push_back(node);
    }
  }
}

namespace {
struct OpNameOrder {
  bool operator()(const NodeDef* a, const NodeDef* b) {
    return a->name() < b->name();
  }
};

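// A Tree node represents one level of the name_scope hierarchy (node names
// split on '/'); nodes_ holds the ops whose scope path ends at this subtree.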
class Tree {
 public:
  Tree(const string& edge, int depth) : edge_(edge), depth_(depth) {}
  ~Tree() {
    for (const auto& it : subtrees_) delete it.second;
  }

  Tree* GetSubTree(const string& edge) {
    auto it = subtrees_.find(edge);
    if (it != subtrees_.end()) {
      return it->second;
    }
    Tree* t = new Tree(edge, depth_ + 1);
    subtrees_[edge] = t;
    return t;
  }

  void InsertNode(NodeDef* n) { nodes_.push_back(n); }

  string edge_;
  int depth_;
  std::vector<NodeDef*> nodes_;
  absl::flat_hash_map<string, Tree*> subtrees_;
};

// Applies a function to every Tree in DFS order.  Terminates early
// on any non-OK Status.
Status ApplyToAll(Tree* tree, const std::function<Status(Tree*)>& func) {
  Status s;
  for (const auto& it : tree->subtrees_) {
    s = ApplyToAll(it.second, func);
    if (!s.ok()) return s;
  }
  s = func(tree);
  return s;
}

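// Builds a Tree describing the name_scope hierarchy of node_vec: each node
// is inserted under the subtree addressed by the '/'-separated prefix of its
// name.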
Tree* ComputeScopeTree(const string& op_name,
                       const std::vector<NodeDef*>& node_vec) {
  Tree* root = new Tree("", 0);
  for (NodeDef* n : node_vec) {
    std::vector<string> pieces = str_util::Split(n->name(), "/");
    // The last piece is the node name proper.
    int depth = pieces.size() - 1;
    Tree* subtree = root;
    for (int i = 0; i < depth; ++i) {
      subtree = subtree->GetSubTree(pieces[i]);
    }
    subtree->InsertNode(n);
  }
  return root;
}

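// Partitions `nodes` into groups that share identical loop (frame) nesting,
// appending one vector per group to `loop_groups`.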
void PartitionByLoopStructure(const FrameView& frame_view,
                              std::vector<NodeDef*> nodes,
                              std::vector<std::vector<NodeDef*>>* loop_groups) {
  // It is assumed that two nodes with identical loop containment have
  // identical integer vectors.  Represent each vector by a 64-bit hash.
  absl::flat_hash_map<uint64, std::vector<NodeDef*>> loop_sets;
  for (NodeDef* nd : nodes) {
    uint64 hash = 0;
    const std::vector<int>& loop_ids = frame_view.Frames(*nd);
    for (int id : loop_ids) {
      hash = Hash64Combine(hash, static_cast<uint64>(id));
    }
    loop_sets[hash].push_back(nd);
  }
  for (auto& it : loop_sets) {
    loop_groups->push_back(std::move(it.second));
  }
}

// Identify outputs that are inputs to multiple sets of nodes.  Any input
// name seen more than once across calls (tracked in `seen_outputs`) is
// recorded in `repeated_outputs`.
void IdentifyRepeatedInputs(const std::vector<NodeDef*>& nodes,
                            absl::flat_hash_set<string>* seen_outputs,
                            absl::flat_hash_set<string>* repeated_outputs) {
  for (NodeDef* node : nodes) {
    for (const auto& input_name : node->input()) {
      if (!seen_outputs->insert(input_name).second) {
        repeated_outputs->insert(input_name);
      }
    }
  }
}

}  // namespace

Status ScopedAllocatorOptimizer::ProcessGraphDef(
    GraphDef* graph, const GraphProperties& graph_properties) {
  // Nodes created by this optimizer have the IsStateful() property
  // which means their names must be globally unique within a process,
  // so we include an optimizer invocation count in every generated
  // name.
  static std::atomic<int64> invocation_counter(1);
  const int64 invocation_count =
      invocation_counter.fetch_add(1, std::memory_order_seq_cst);
  VLOG(1) << "ProcessGraphDef " << invocation_count;
  Status status;
  GraphOpOccurrences occ;
  FindOpOccurrences(graph, op_name_set_, &occ);
  if (!occ.empty()) {
    FrameView frame_view;
    // TODO(ezhulenev): Pass a GraphView when this optimizer is migrated
    // away from NodeMap.
    LOG_WARNING_AND_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph));

    for (auto& dt : occ) {
      VLOG(2) << "Processing device " << dt.first;
      const DevOpOccurrences& dev_occ = dt.second;
      for (auto& it : dev_occ) {
        string op_name = it.first;
        VLOG(1) << "Processing " << op_name << " set size "
                << it.second.size();
        Rewriter* rewriter = GetRewriter(op_name);
        if (!rewriter) {
          LOG(ERROR) << "Failed to find Rewriter in ScopedAllocatorOptimizer "
                     << "for op_name " << op_name;
          continue;
        }
        rewriter->SetGraphProperties(graph_properties);
        std::unique_ptr<Tree> root(ComputeScopeTree(it.first, it.second));
        // Record outputs that are inputs to multiple Tree nodes.
        absl::flat_hash_set<string> seen_outputs;
        status = ApplyToAll(root.get(), [this, &seen_outputs](Tree* t) {
          IdentifyRepeatedInputs(t->nodes_, &seen_outputs, &repeated_outputs_);
          return Status::OK();
        });
        if (!status.ok()) {
          break;
        }
        // Nodes with a common depth and root path are now grouped
        // in the same Tree struct.  Split those groups into subgroups that
        // share identical loop nesting.
        status = ApplyToAll(root.get(), [this, rewriter, graph, &frame_view,
                                         &op_name, invocation_count](Tree* t) {
          VLOG(2) << "applied to tree node " << t->edge_ << " at depth "
                  << t->depth_ << " of size " << t->nodes_.size();
          if (t->nodes_.size() > 1) {
            std::vector<std::vector<NodeDef*>> loop_groups;
            PartitionByLoopStructure(frame_view, t->nodes_, &loop_groups);
            for (auto& lg : loop_groups) {
              if (lg.size() > 1) {
                bool applied = false;
                Status s = OrderNodeSet(&lg);
                TF_RETURN_IF_ERROR(s);
                VLOG(1) << "Applying Rewriter for " << op_name;
                s = rewriter->Rewrite(this, invocation_count, graph, op_name,
                                      lg, &applied);
                LOG_WARNING_AND_RETURN_IF_ERROR(s);
              }
            }
          }
          return Status::OK();
        });
        if (!status.ok()) {
          break;
        }
      }
      if (!status.ok()) {
        break;
      }
    }
  }
  VLOG(1) << "ScopedAllocatorOptimizer returning " << status;
  if (!status.ok()) {
    LOG(ERROR) << "ScopedAllocatorOptimizer: " << status;
  }
  return status;
}

namespace {
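// Comparators used by OrderNodeSet below.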
struct InstanceKeyLess {
  bool operator()(const NodeDef* a, const NodeDef* b) const {
    AttrSlice a_attrs = AttrSlice(*a);
    AttrSlice b_attrs = AttrSlice(*b);
    int32 a_key = -1;
    int32 b_key = -1;
    Status s = GetNodeAttr(a_attrs, "instance_key", &a_key);
    CHECK(s.ok());
    s = GetNodeAttr(b_attrs, "instance_key", &b_key);
    CHECK(s.ok());
    return a_key < b_key;
  }
};

struct NameLess {
  bool operator()(const NodeDef* a, const NodeDef* b) const {
    return a->name() < b->name();
  }
};

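// Returns true if `n` is a collective op carrying a non-negative
// instance_key attr.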
bool IsCollectiveNode(const NodeDef& n) {
  AttrSlice attrs = AttrSlice(n);
  int key = -1;
  if (!IsCollective(n)) return false;
  Status s = GetNodeAttr(attrs, "instance_key", &key);
  if (s.ok() && key >= 0) {
    return true;
  }
  return false;
}
}  // namespace

Status ScopedAllocatorOptimizer::OrderNodeSet(
    std::vector<NodeDef*>* nodes) const {
  // Nodes should all be of identical type.  The default order is by name,
  // but for collectives we order by increasing instance_key so that each
  // group gets the same instance_key.
  if (nodes->size() <= 1) return Status::OK();
  if (IsCollectiveNode(*nodes->at(0))) {
    std::sort(nodes->begin(), nodes->end(), InstanceKeyLess());
  } else {
    std::sort(nodes->begin(), nodes->end(), NameLess());
  }
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow

#undef LOG_WARNING_AND_RETURN_IF_ERROR