/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"

#include "tensorflow/core/common_runtime/scoped_allocator.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

// Like TF_RETURN_IF_ERROR, but also logs a WARNING.
#define LOG_WARNING_AND_RETURN_IF_ERROR(...)            \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(WARNING) << "error: " << _status;             \
      return _status;                                   \
    }                                                   \
  } while (0)

namespace tensorflow {
namespace grappler {

namespace {

const char kScopedAllocatorAttrName[] = "_scoped_allocator";

// Node names often have some kind of name_scope prefix, with slashes,
// and a _nn numeric suffix.  Returns true if the main part of the node_name
// matches op_name, i.e. it looks from the name like this node is
// of that op type.
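// For example, HasOpName("a/b/mul_2", "mul") returns true, while
// HasOpName("a/b/mul_x", "mul") returns false.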
bool HasOpName(const string& node_name, const string& op_name) {
  size_t begin = node_name.rfind('/');
  if (begin == string::npos) {
    begin = 0;
  } else {
    ++begin;
  }
  size_t end = node_name.rfind('_');
  if (end != string::npos) {
    size_t p = end + 1;
    while (p < node_name.size()) {
      if (!isdigit(node_name[p])) {
        end = node_name.size();
        break;
      }
      ++p;
    }
  } else {
    end = node_name.size();
  }
  return node_name.substr(begin, end - begin) == op_name;
}

Status GetOutputDataType(
    const std::vector<OpInfo::TensorProperties>& output_props, int output_index,
    DataType* dtype) {
  int output_props_size = output_props.size();
  if (output_index >= output_props_size) {
    return errors::Internal("Invalid output index ", output_index,
                            " size of output_props ", output_props.size());
  }
  *dtype = output_props[output_index].dtype();
  return Status::OK();
}

// After shape inference has been done each op should be annotated
// with its output shape(s).  This function iterates over a collection
// of ops that are a potential application of a ScopedAllocator.  It
// verifies whether they all have the same output type and if so
// gathers a vector of their output shapes.  It returns an error if
// any of the ops doesn't have type or shape data, or if it has more
// than one output, or if the output type of all ops is not the same.
// If it returns OK then *type and *shapes should be correctly populated.
Status CheckTypesAndGetShapes(const GraphProperties& graph_properties,
                              const std::vector<NodeDef*>& ops, DataType* type,
                              std::vector<TensorShape>* shapes) {
  VLOG(1) << "CheckTypesAndGetShapes";
  *type = DT_INVALID;
  for (NodeDef* n : ops) {
    AttrSlice n_attrs = AttrSlice(*n);
    DataType dtype;
    // Check that op has an explicit data type attr "T".
    LOG_WARNING_AND_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
    VLOG(2) << "op " << n->name() << " has type " << dtype << " shapes.size() "
            << shapes->size();
    if (!graph_properties.HasOutputProperties(n->name())) {
      LOG(ERROR) << "Node " << n->DebugString() << " lacks output shape.";
      return errors::Aborted("Node ", n->name(), " lacks output shape.");
    }
    const std::vector<OpInfo::TensorProperties>& prop_list =
        graph_properties.GetOutputProperties(n->name());
    if (prop_list.size() != 1) {
      return errors::Aborted("Node ", n->name(),
                             " does not have exactly one output as expected "
                             "by ScopedAllocatorOptimizer");
    }
    const OpInfo::TensorProperties& props = prop_list[0];
    if (shapes->empty()) {
      *type = props.dtype();
    } else if (*type != props.dtype()) {
      return errors::Aborted("Group ops don't all have same type");
    }
    if (*type != dtype) {
      return errors::Internal(
          "Type mismatch: type in op attr = ", DataTypeString(dtype),
          ", type in output props = ", DataTypeString(*type));
    }
    if (!TensorShape::IsValid(props.shape()) || props.shape().unknown_rank()) {
      // TensorShape::IsValid may return true if unknown_rank is true, i.e.
      // number of dimensions is unknown.  But for ScopedAllocatorOptimizer we
      // need to know the shape fully.
      return errors::Aborted("Complete shape not known for ", n->name());
    }
    VLOG(2) << "Adding shape " << props.shape().DebugString();
    shapes->push_back(TensorShape(props.shape()));
  }
  return Status::OK();
}

// Describes an existing input edge in the graph.
struct InputDesc {
  NodeDef* from_node_def;
  int output_slot;
  NodeDef* to_node_def;
  InputDesc(NodeDef* f, int os, NodeDef* t)
      : from_node_def(f), output_slot(os), to_node_def(t) {}
};

// Remove the NodeDef nd from node_map and graph.  It must be the case
// that nd no longer has any input or output edges, though that is not
// checked.
void RemoveNode(NodeDef* nd, GraphDef* graph, NodeMap* node_map) {
  node_map->RemoveNode(nd->name());
  // TODO(tucker): The efficiency of this routine is poor.
  // Change to accumulate and do a bulk removal, maybe refactoring
  // some code from dependency_optimizer.
  protobuf::RepeatedPtrField<NodeDef>* nodes = graph->mutable_node();
  for (int i = 0; i < nodes->size(); ++i) {
    if (nd->name() == (*nodes)[i].name()) {
      nodes->SwapElements(i, nodes->size() - 1);
      nodes->RemoveLast();
      return;
    }
  }
  LOG(FATAL) << "Failed to find node " << nd->name() << " in graph";
}

// Removes a named edge from between two nodes.
Status RemoveEdge(const string& input_edge_name, const string& from_node_name,
                  NodeDef* to_node, NodeMap* node_map) {
  protobuf::RepeatedPtrField<string>* inputs = to_node->mutable_input();
  int edge_index = -1;
  for (edge_index = 0; edge_index < inputs->size(); ++edge_index) {
    VLOG(2) << " consider edge " << (*inputs)[edge_index];
    if ((*inputs)[edge_index] == input_edge_name) {
      break;
    }
  }
  if (edge_index >= inputs->size()) {
    return errors::Internal("Could not find input name ", input_edge_name,
                            " at node ", to_node->name());
  }
  if (node_map) {
    node_map->RemoveOutput(from_node_name, to_node->name());
  }
  inputs->DeleteSubrange(edge_index, 1);
  return Status::OK();
}

// In certain cases, we would like to insert an identity op between `input`
// and `op` to ensure correctness.  We currently do this in three cases: when
// `input` is a Const op (whose output does not use the AllocatorAttributes
// set by the executor), when `input` is an Exit node, or when `input` is
// already marked for allocation with another scoped allocator op.
//
// If `input` is an Exit node, we add an identity to avoid the case when Exit
// has inputs from different frames.
//
// If `input` is in `sa_opti->repeated_outputs()`, this means that it will be
// potentially used by multiple scope ids.  Since there can be only one scope
// id per output, we insert an identity between the input and op.  This will
// ensure that the identity becomes the new input to op, and this identity
// can be marked with a new scope id different from `input`.
//
// If the graph is rewritten, this function will perform the following change:
//
//  input                                  input
//   |                                      |
//   op                                  Identity
//                                          |
//                                          op
//
// This function returns the input to op in `new_input`, and the output index
// from input to op in `new_output_index`.
// `edge_name` gives the name of the edge from `input` to `op`, and
// `output_index` is the output index of this edge on `input`.
Status MaybeRewriteInput(ScopedAllocatorOptimizer* sa_opti,
                         int64 invocation_count, GraphDef* graph,
                         NodeMap* node_map, const DataType& dtype,
                         NodeDef* input, const string& edge_name,
                         int output_index, NodeDef* op, NodeDef** new_input,
                         int* new_output_index, bool* rewrite) {
  *rewrite = IsConstant(*input) || IsExit(*input) ||
             (sa_opti->repeated_outputs().find(edge_name) !=
              sa_opti->repeated_outputs().end());
  if (!(*rewrite)) {
    *new_input = input;
    *new_output_index = output_index;
    return Status::OK();
  }

  // Create new Identity op.
  int unique_id;
  LOG_WARNING_AND_RETURN_IF_ERROR(sa_opti->NewIdentityId(&unique_id));
  string identity_name = strings::StrCat("scoped_allocator_identity_",
                                         unique_id, "_", invocation_count);
  NodeDefBuilder identity_builder(identity_name, "Identity");
  identity_builder.Device(op->device());
  identity_builder.Attr("T", dtype);
  // Connect output at `output_index` from `input` to `identity`.
  identity_builder.Input(
      NodeDefBuilder::NodeOut(input->name(), output_index, dtype));
  NodeDef* identity = graph->add_node();
  LOG_WARNING_AND_RETURN_IF_ERROR(identity_builder.Finalize(identity));
  node_map->AddNode(identity_name, identity);
  node_map->AddOutput(input->name(), identity_name);
  node_map->UpdateInput(op->name(), input->name(), identity_name);
  *op->mutable_input(0) = identity_name;
  *new_input = identity;
  *new_output_index = 0;
  VLOG(1) << "Rewrite input " << edge_name << " op " << op->name()
          << " old output index " << output_index << " with identity "
          << identity_name << " new output index 0";
  return Status::OK();
}

// Populates *inputs with all of the non-control inputs of ops.
// Returns error if it fails to find exactly one input for each op,
// or if some input is not of type dtype.
Status GetInputs(ScopedAllocatorOptimizer* sa_opti, int64 invocation_count,
                 GraphDef* graph, const GraphProperties& graph_properties,
                 NodeMap* node_map, const std::vector<NodeDef*>& ops,
                 DataType dtype, std::vector<InputDesc>* inputs) {
  VLOG(1) << "GetInputs";
  for (NodeDef* n : ops) {
    NodeDef* inode = nullptr;
    int output_index = 0;
    DataType inode_dtype = DT_INVALID;
    VLOG(2) << "for node " << n->name();
    for (const auto& input_name : n->input()) {
      if (!IsControlInput(input_name)) {
        if (inode) {
          return errors::Internal("Found more than one input for node ",
                                  n->name());
        }
        ParseNodeName(input_name, &output_index);
        inode = node_map->GetNode(input_name);
        if (inode == nullptr) {
          return errors::Internal("Did not find node ", input_name);
        }
        VLOG(2) << "inode " << inode->DebugString() << " output_index "
                << output_index;
        bool rewrite;
        LOG_WARNING_AND_RETURN_IF_ERROR(MaybeRewriteInput(
            sa_opti, invocation_count, graph, node_map, dtype, inode,
            input_name, output_index, n, &inode, &output_index, &rewrite));
        // If `inode` was rewritten, don't try to get output properties from
        // the input node below.
        if (rewrite) {
          inode_dtype = dtype;
        }
        VLOG(2) << "inode after rewrite " << inode->DebugString()
                << " output_index " << output_index;
      }
    }
    if (inode_dtype == DT_INVALID) {
      if (!graph_properties.HasOutputProperties(inode->name())) {
        return errors::Internal("Input node ", inode->name(),
                                " does not have output properties");
      }
      const auto& inode_output_props =
          graph_properties.GetOutputProperties(inode->name());
      LOG_WARNING_AND_RETURN_IF_ERROR(
          GetOutputDataType(inode_output_props, output_index, &inode_dtype));
    }
    if (inode_dtype != dtype) {
      return errors::Aborted("ScopedAllocatorOptimizer expected input type ",
                             dtype, " but found ", inode_dtype);
    }
    inputs->emplace_back(inode, output_index, n);
  }
  return Status::OK();
}

// Return non-control inputs of `op` in `inputs`.
Status GetDataInputs(GraphDef* graph, NodeMap* node_map, NodeDef* op,
                     std::vector<InputDesc>* inputs) {
  VLOG(2) << "GetDataInputs for node " << op->name();
  NodeDef* inode = nullptr;
  int output_index = 0;
  for (const auto& input_name : op->input()) {
    if (IsControlInput(input_name)) {
      continue;
    }
    ParseNodeName(input_name, &output_index);
    inode = node_map->GetNode(input_name);
    if (inode == nullptr) {
      return errors::Internal("Did not find node ", input_name);
    }
    VLOG(2) << "inode " << inode->DebugString() << " output_index "
            << output_index;
    inputs->emplace_back(inode, output_index, op);
  }
  return Status::OK();
}

void DumpGraphToVLOG(const GraphDef& graph, int log_level) {
  if (VLOG_IS_ON(log_level)) {
    // VLOG may truncate lines so we print line by line.
    for (const auto& line : str_util::Split(graph.DebugString(), "\n\r")) {
      VLOG(log_level) << line;
    }
  }
}

}  // namespace

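// Appends the given values to the list(int) attr `name` on node_def,
// creating the attr if it is not already present.  For example, extending
// an attr whose list already holds [0, 7] with {1, 8} yields [0, 7, 1, 8].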
void ScopedAllocatorOptimizer::ExtendNodeAttr(StringPiece name,
                                              const std::vector<int32>& values,
                                              NodeDef* node_def) {
  if (HasNodeAttr(*node_def, name)) {
    VLOG(2) << "extending";
    AttrValue* existing = &(*node_def->mutable_attr())[string(name)];
    for (int32 i : values) {
      existing->mutable_list()->add_i(i);
    }
  } else {
    VLOG(2) << "setting new attr value";
    AddNodeAttr(name, values, node_def);
  }
}

class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
 public:
  ~UnaryElementwiseRewriter() override {}

  // Return non-OK if any input is an op that does not use the
  // AllocatorAttributes set by the executor to allocate its output.
  Status CheckUsesAllocatorAttributes(const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      if (IsConstant(*nd.from_node_def)) {
        return errors::Aborted(
            "Abandoning ScopedAllocatorOptimizer because input ",
            nd.from_node_def->name(),
            " is a Const op which does not use AllocatorAttributes");
      }
    }
    return Status::OK();
  }

  // Return non-OK if any input is already committed to a ScopedAllocator.
  //
  // We insert an identity to ensure that inputs are not committed to
  // different scope ids in `MaybeRewriteInput`, so this function is
  // basically a sanity check.
  Status CheckExistingScopedAllocator(const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      VLOG(2) << "get attrs for " << nd.from_node_def->name();
      AttrSlice n_attrs = AttrSlice(*nd.from_node_def);
      std::vector<int32> scope_ids;
      Status ss = GetNodeAttr(n_attrs, kScopedAllocatorAttrName, &scope_ids);
      // Check that both output name and output slot match.  It is okay to
      // have different outputs of the input committed to different scope ids.
      if (ss.ok() && scope_ids[0] == nd.output_slot) {
        LOG(INFO) << "Abandoning ScopedAllocatorOptimizer because input "
                  << nd.from_node_def->name() << " output " << scope_ids[0]
                  << " is already assigned to scope_id " << scope_ids[1];
        return errors::Aborted(
            "Abandoning ScopedAllocatorOptimizer because input ",
            nd.from_node_def->name(), " output ", scope_ids[0], " is already ",
            "assigned to scope_id ", scope_ids[1]);
      }
    }
    return Status::OK();
  }

  // Return non-OK if any input is a member of op_set.
  Status CheckInternalDataDependency(const std::set<string>& op_set,
                                     const std::vector<InputDesc>& inputs) {
    for (const InputDesc& nd : inputs) {
      if (op_set.find(nd.from_node_def->name()) != op_set.end()) {
        if (nd.output_slot != tensorflow::Graph::kControlSlot) {
          return errors::Aborted("Data edge exists between ",
                                 nd.from_node_def->name(),
                                 " and another node in the set");
        }
      }
    }
    return Status::OK();
  }

  // Remove all control edges between members of ops.
  void ClearInternalControlInputs(const std::set<string>& op_set,
                                  const std::vector<NodeDef*>& ops,
                                  NodeMap* node_map) {
    for (NodeDef* n : ops) {
      for (const auto& input_name : n->input()) {
        if (IsControlInput(input_name)) {
          int position = 0;
          string input_node_name = ParseNodeName(input_name, &position);
          CHECK_EQ(position, -1);
          if (op_set.find(input_node_name) != op_set.end()) {
            // This is an internal control edge.  Remove it.
            VLOG(1) << "Remove control output from " << input_node_name
                    << " via edge " << input_name << " to " << n->name();
            TF_CHECK_OK(RemoveEdge(input_name, input_node_name, n, node_map));
          }
        }
      }
    }
  }

  // Examine the input set of an op set, gathering their shapes and types
  // and checking whether there are any considerations that prevent use
  // of a single ScopedAllocator for all of those inputs.
  Status AnalyzeInputs(ScopedAllocatorOptimizer* sa_opti,
                       int64 invocation_count, GraphDef* graph,
                       NodeMap* node_map, const std::vector<NodeDef*>& ops,
                       const std::set<string>& op_instance_names,
                       string* device_name, DataType* dtype,
                       std::vector<TensorShape>* input_shapes,
                       std::vector<InputDesc>* inputs, TensorShape* sa_shape) {
    CHECK(graph_properties_);
    LOG_WARNING_AND_RETURN_IF_ERROR(
        CheckTypesAndGetShapes(*graph_properties_, ops, dtype, input_shapes));
    LOG_WARNING_AND_RETURN_IF_ERROR(
        GetInputs(sa_opti, invocation_count, graph, *graph_properties_,
                  sa_opti->node_map(), ops, *dtype, inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(CheckUsesAllocatorAttributes(*inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(CheckExistingScopedAllocator(*inputs));
    LOG_WARNING_AND_RETURN_IF_ERROR(
        CheckInternalDataDependency(op_instance_names, *inputs));
    ClearInternalControlInputs(op_instance_names, ops, node_map);
    *device_name = ops[0]->device();
    CHECK(!device_name->empty());
    CHECK(!input_shapes->empty());
    CHECK_EQ(0, Allocator::kAllocatorAlignment % DataTypeSize(*dtype))
        << "ScopedAllocatorOptimizer only applies to types that evenly "
        << "divide kAllocatorAlignment";
    std::vector<ScopedAllocator::Field> sa_fields;
    // Calculate the field embedding boundaries and thereby the
    // required size of the backing tensor.
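    // E.g. for two DT_FLOAT inputs of shapes [2,3] and [5], the two fields
    // occupy 24 and 20 bytes respectively, and each field after the first
    // starts at an offset aligned to Allocator::kAllocatorAlignment, so
    // num_bytes also accounts for any alignment padding between fields.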
    int64 num_bytes = ScopedAllocatorMgr::PopulateFields(
        0 /*scope_id*/, *input_shapes, *dtype, &sa_fields);
    int64 num_elts = num_bytes / DataTypeSize(*dtype);
    VLOG(2) << "num_bytes " << num_bytes << " num_elts=" << num_elts;
    *sa_shape = TensorShape({num_elts});
    return Status::OK();
  }

  // Returns the set of all nodes that are transitively reachable via data or
  // control edges starting at `source_nodes`.  Stop at the boundary of a frame.
  Status TransitiveFanoutWithinFrame(
      GraphDef* graph, NodeMap* node_map,
      const std::vector<const NodeDef*>& source_nodes,
      absl::flat_hash_set<const NodeDef*>* fanout) {
    std::deque<const NodeDef*> queue(source_nodes.begin(), source_nodes.end());
    absl::flat_hash_set<const NodeDef*> visited;
    while (!queue.empty()) {
      const NodeDef* node = queue.front();
      queue.pop_front();
      if (!visited.insert(node).second) {
        continue;
      }
      fanout->insert(node);
      for (const NodeDef* output : node_map->GetOutputs(node->name())) {
        if (!ModifiesFrameInfo(*output)) {
          queue.push_back(output);
        }
        VLOG(2) << "TransitiveFanout parent: " << node->name()
                << " child: " << output->name() << " of type " << output->op();
      }
    }

    return Status::OK();
  }

  // Build the ScopedAllocator node that will be assigned to allocate
  // the output tensors of the input node set.
  Status ConstructScopedAllocatorNode(
      ScopedAllocatorOptimizer* sa_opti, GraphDef* graph, NodeMap* node_map,
      const std::vector<NodeDef*>& ops, const string& device_name,
      DataType dtype, int sa_id, const string& sa_name,
      const std::vector<TensorShape>& input_shapes,
      const std::vector<InputDesc>& inputs, const TensorShape& sa_shape) {
    VLOG(2) << "ConstructScopedAllocatorNode " << sa_name;
    NodeDefBuilder sa_builder(sa_name, "_ScopedAllocator");
    sa_builder.Device(device_name);
    sa_builder.Attr("sa_name", sa_name);
    sa_builder.Attr("T", dtype);
    sa_builder.Attr("id", sa_id);
    sa_builder.Attr("shapes", input_shapes);
    sa_builder.Attr("shape", sa_shape);
    sa_builder.Attr("expected_call_count", static_cast<int64>(ops.size()));
    NodeDef* sa_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sa_builder.Finalize(sa_node));
    node_map->AddNode(sa_name, sa_node);

    std::vector<const NodeDef*> fanout_sources;
    fanout_sources.reserve(inputs.size());
    for (const auto& input : inputs) {
      fanout_sources.push_back(input.from_node_def);
    }
    absl::flat_hash_set<const NodeDef*> fanout;
    TF_RETURN_IF_ERROR(
        TransitiveFanoutWithinFrame(graph, node_map, fanout_sources, &fanout));

    // Add control edges from the ScopedAllocatorOp to all of the
    // input nodes and mark them for allocation from backing tensor.
    for (int i = 0, end = inputs.size(); i < end; ++i) {
      auto& nd = inputs[i];
      if (IsArg(*nd.from_node_def)) {
        return errors::Aborted(
            "ScopedAllocatorOptimizer does not work well when the op inputs "
            "are _Arg ops; skipping this optimizer for this function");
      }
      VLOG(2) << "To input " << i << ": " << nd.from_node_def->name()
              << " add control input "
              << "^" << sa_name;
      nd.from_node_def->add_input(strings::StrCat("^", sa_name));
      // This attribute says: allocate output_slot from
      // ScopedAllocator instance sa_id + 1 + i.
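      // E.g. with sa_id == 5, the first input (i == 0), whose tensor comes
      // from output_slot 0, gets the attr value {0, 6} appended.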
      ScopedAllocatorOptimizer::ExtendNodeAttr(kScopedAllocatorAttrName,
                                               {nd.output_slot, sa_id + 1 + i},
                                               nd.from_node_def);
      node_map->AddOutput(sa_name, nd.from_node_def->name());
    }

    // We add control edges in order to (1) delay execution of the
    // ScopedAllocatorOp until just before first use in order to conserve
    // memory, and (2) ensure correctness in the presence of control flow
    // related ops.
    bool added_delay_edge = false;
    for (auto& nd : inputs) {
      std::vector<InputDesc> inputs_to_first;
      LOG_WARNING_AND_RETURN_IF_ERROR(GetDataInputs(
          graph, sa_opti->node_map(), nd.from_node_def, &inputs_to_first));
      for (int i = 0, end = inputs_to_first.size(); i < end; ++i) {
        if (fanout.find(inputs_to_first[i].from_node_def) != fanout.end()) {
          VLOG(2) << "Found node " << inputs_to_first[i].from_node_def->name()
                  << " in the fanout of " << sa_name;
          continue;
        }
        sa_node->add_input(
            strings::StrCat("^", inputs_to_first[i].from_node_def->name()));
        node_map->AddOutput(inputs_to_first[i].from_node_def->name(), sa_name);
        added_delay_edge = true;
        VLOG(2) << "Adding control dependency from "
                << inputs_to_first[i].from_node_def->name() << " to "
                << sa_node->name();
        break;
      }
      if (added_delay_edge) {
        break;
      }
    }

    if (!added_delay_edge) {
      LOG(WARNING) << "Found no node from which a control edge can be added to "
                      "scoped allocator node.  If you run into issues with "
                      "graphs that contain control flow, turn off the "
                      "ScopedAllocatorOptimizer and file a bug.";
    }

    return Status::OK();
  }

  Status BuildSAConcatNode(GraphDef* graph, NodeMap* node_map,
                           const std::vector<NodeDef*>& ops,
                           const std::set<string>& op_instance_names,
                           const string& device_name, DataType dtype, int sa_id,
                           const string& sa_name, const string& sac_name,
                           const TensorShape& sa_shape,
                           std::vector<NodeDefBuilder::NodeOut>* sac_inputs) {
    VLOG(2) << "BuildSAConcatNode " << sac_name;
    // Maps control input: edge name -> source node name.
    absl::flat_hash_map<string, string> sac_ctl_inputs;
    for (int i = 0, end = ops.size(); i < end; ++i) {
      NodeDef* old_op = ops[i];
      for (const string& old_op_input : old_op->input()) {
        int position = 0;
        string input_name = ParseNodeName(old_op_input, &position);
        if (position == -1) {
          // A control input: drop it if it comes from another member of the
          // op set.
          if (op_instance_names.find(old_op_input) == op_instance_names.end()) {
            sac_ctl_inputs.emplace(old_op_input, input_name);
          }
        } else {
          // TODO(tucker): remove redundant check.
          // A data input: illegal if from another member of the op set.
          if (op_instance_names.find(old_op_input) != op_instance_names.end()) {
            LOG(ERROR) << "Data edge between " << old_op_input << " and "
                       << old_op->name() << " cannot build ScopedAllocator.";
            return errors::Aborted("Data edge between ", old_op_input, " and ",
                                   old_op->name(),
                                   " cannot build ScopedAllocator.");
          }
          sac_inputs->push_back(
              NodeDefBuilder::NodeOut(old_op_input, 0, dtype));
        }
        VLOG(3) << "from op " << i << ": " << old_op->name()
                << " sac_inputs append " << old_op_input;
      }
    }
    NodeDefBuilder sac_builder(sac_name, "_ScopedAllocatorConcat");
    VLOG(2) << "New sac_name " << sac_name << " shape "
            << sa_shape.DebugString();
    sac_builder.Device(device_name);
    sac_builder.Attr("sa_name", sa_name);
    sac_builder.Attr("id", sa_id);
    sac_builder.Attr("T", dtype);
    sac_builder.Attr("shape", sa_shape);
    sac_builder.Attr("N", static_cast<int>(sac_inputs->size()));
    sac_builder.Input(NodeDefBuilder::NodeOut(sa_name, 0, dtype));
    sac_builder.Input(*sac_inputs);
    NodeDef* sac_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sac_builder.Finalize(sac_node));
    node_map->AddNode(sac_name, sac_node);
    node_map->AddOutput(sa_name, sac_name);

    // Attach the old control inputs to the new sac node.
    for (const auto& ctl_input : sac_ctl_inputs) {
      const auto& ctl_edge = ctl_input.first;
      const auto& input_name = ctl_input.second;
      sac_node->add_input(ctl_edge);
      node_map->AddOutput(input_name, sac_node->name());
    }
    return Status::OK();
  }

  Status BuildReplacementOp(GraphDef* graph, NodeMap* node_map,
                            const std::vector<NodeDef*>& ops,
                            const string& device_name, DataType dtype,
                            const string& op_name, const string& sac_name,
                            const string& sa_op_name) {
    VLOG(2) << "BuildReplacementOp " << sa_op_name;
    NodeDefBuilder op_builder(sa_op_name, op_name);
    op_builder.Device(device_name);

    // Transfer the Node Attrs from the first replaced Node to the new
    // Node.  TODO(tucker): In principle we should verify that
    // the Attrs are consistent and compatible across all op instances.
    // Unfortunately that will probably require op-specific tests, so
    // punt on that for the time being.
    AttrSlice first_slice(*ops[0]);
    for (auto& it : first_slice) {
      op_builder.Attr(it.first, it.second);
    }
    op_builder.Attr("_forward_input", {0, 0});
    op_builder.Input(sac_name, 0, dtype);
    NodeDef* sa_op_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(op_builder.Finalize(sa_op_node));
    node_map->AddNode(sa_op_name, sa_op_node);
    node_map->AddOutput(sac_name, sa_op_name);
    return Status::OK();
  }

  Status BuildSplitNode(GraphDef* graph, NodeMap* node_map,
                        const std::vector<NodeDef*>& ops,
                        const std::vector<TensorShape>& input_shapes,
                        const std::vector<NodeDefBuilder::NodeOut>& sac_inputs,
                        const string& device_name, DataType dtype,
                        const string& op_name, int sa_id,
                        const string& sas_name, const string& sa_name,
                        const string& sa_op_name) {
    VLOG(2) << "new ScopedAllocatorSplit " << sas_name;
    NodeDefBuilder sas_builder(sas_name, "_ScopedAllocatorSplit");
    sas_builder.Device(device_name);
    sas_builder.Attr("sa_name", sa_name);
    sas_builder.Attr("id", sa_id);
    sas_builder.Attr("T", dtype);
    sas_builder.Attr("shapes", input_shapes);
    std::vector<NodeDefBuilder::NodeOut> sas_inputs = sac_inputs;
    sas_builder.Attr("N", static_cast<int>(sas_inputs.size()));
    sas_builder.Input(NodeDefBuilder::NodeOut(sa_op_name, 0, dtype));
    sas_builder.Input(sas_inputs);
    NodeDef* sas_node = graph->add_node();
    LOG_WARNING_AND_RETURN_IF_ERROR(sas_builder.Finalize(sas_node));
    node_map->AddNode(sas_name, sas_node);
    node_map->AddOutput(sa_op_name, sas_name);
    for (const auto& input : sas_inputs) {
      node_map->AddOutput(input.node, sas_name);
    }
    return Status::OK();
  }

  // After the new ScopedAllocator and its corresponding Concat and
  // Split nodes have been built, and a new single Op instance
  // constructed, rewire the graph: Remove input edges to the old Op
  // nodes and replace the old Op node outputs with the corresponding
  // ScopedAllocatorSplit node outputs.  After this the old Op nodes
  // should no longer have any input or output edges and they can be
  // removed from the graph.
  Status RewireSubgraph(GraphDef* graph, NodeMap* node_map,
                        const std::vector<NodeDef*>& ops,
                        const std::set<string>& op_instance_names,
                        const string& op_name, const string& sas_name) {
    VLOG(2) << "RewireSubgraph";
    for (int op_idx = 0, idx_limit = ops.size(); op_idx < idx_limit; ++op_idx) {
      NodeDef* old_op = ops[op_idx];
      // Copy the output node set since we'll be modifying the version
      // maintained by NodeMap in the loop.
      auto output_nodes = node_map->GetOutputs(old_op->name());
      VLOG(3) << "old_op " << old_op->name() << " had " << output_nodes.size()
              << " outputs.  Moving them to the ScopedAllocatorSplit node.";
      if (VLOG_IS_ON(2)) {
        for (NodeDef* n : output_nodes) {
          VLOG(3) << "    output: " << n->name();
        }
      }
      for (NodeDef* n : output_nodes) {
        VLOG(3) << "really checking old output " << n->name()
                << " for corresponding input.";
        if (op_instance_names.find(n->name()) != op_instance_names.end()) {
          // If this output node is a member of the ops set, it must have
          // been an internal control edge, so drop it.
          VLOG(3) << "Dropping control output from " << old_op->name() << " to "
                  << n->name();
          // However, we may already have dropped it at the clear() below,
          // so if we fail to find it, that's okay.
          Status ignore = RemoveEdge(strings::StrCat("^", old_op->name()),
                                     old_op->name(), n, node_map);
          continue;
        }
        bool found = false;
        VLOG(3) << "about to iterate over " << n->input_size() << " inputs";
        for (int i = 0; i < n->input_size(); ++i) {
          VLOG(3) << "input " << n->input(i);
          int position = 0;
          string input_node = ParseNodeName(n->input(i), &position);
          if (input_node == old_op->name()) {
            found = true;
            VLOG(3) << "match pos=" << position;
            if (position == -1) {
              // It was a control edge.
              *n->mutable_input(i) = strings::StrCat("^", sas_name);
            } else {
              CHECK_EQ(0, position)
                  << "name " << n->input(i) << " pos " << position;
              *n->mutable_input(i) = strings::StrCat(sas_name, ":", op_idx);
            }
            node_map->UpdateInput(n->name(), old_op->name(), sas_name);
            VLOG(3) << "breaking on success";
            break;
          } else {
            VLOG(3) << "other input " << n->input(i);
          }
        }
        // In general it's required that we found the output node's old
        // input and replaced it, but one exception is if the output node
        // is of the same type being coalesced and the edge is a control
        // input.  In that case it probably got eliminated in an earlier
        // pass.
        VLOG(3) << "before HasOp";
        if (!HasOpName(n->name(), op_name)) {
          CHECK(found) << "old_op " << old_op->name()
                       << " could not find input edge on " << n->DebugString()
                       << " to replace."
                       << " " << op_name << " not in " << n->name();
        }
        VLOG(3) << "bottom of for output_nodes";
      }
      VLOG(3) << "Clearing all inputs of " << old_op->name();
      node_map->RemoveInputs(old_op->name());
      old_op->clear_input();
      node_map->RemoveOutputs(old_op->name());
      VLOG(3) << "after clear: " << old_op->DebugString();
      // old_op should now be dead, with no remaining input or output edges,
      // so it can be removed from the graph.
      RemoveNode(old_op, graph, node_map);
    }
    return Status::OK();
  }

  // Given a collection of instances of op_name, presumed to be
  // logically parallel and operating on tensors of the same type,
  // replace them by a single instance.  First find the upstream Ops
  // generating their inputs. Create a new ScopedAllocatorOp that
  // outputs a single backing_tensor pre-arranged for sub-allocation
  // of all of those input tensors.  Then insert a new
  // ScopedAllocatorConcatOp below the upstream Ops to make explicit
  // the materialization of a concatenation of their outputs.  Put the
  // new op_name instance below the new concat op and follow with a
  // ScopedAllocatorSplitOp that restores the correct shape outputs
  // for the consumers of the old op_name instances.
  //
  // There must be no non-control edges between Nodes in 'ops'.
  // Control edges among these nodes will be dropped.
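  //
  // As a rough sketch, for two parallel instances of a unary op "Mul":
  //
  //   a --> Mul_0 --> x          a ---> _ScopedAllocatorConcat
  //   b --> Mul_1 --> y    =>    b --/            |
  //                                              Mul
  //                                               |
  //                                      _ScopedAllocatorSplit
  //                                           /       \
  //                                          x         y
  //
  // with the _ScopedAllocator node (not shown) supplying the backing
  // tensor to the producers of a and b via control edges.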
  Status Rewrite(ScopedAllocatorOptimizer* sa_opti, int64 invocation_count,
                 GraphDef* graph, const string& op_name,
                 const std::vector<NodeDef*>& ops, bool* applied) override {
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Rewrite";
      string op_names;
      for (auto& nd : ops) {
        strings::StrAppend(&op_names, nd->name(), ", ");
      }
      VLOG(1) << "UnaryElementwiseRewriter::Rewrite " << op_name
              << " to: " << op_names;
    }
    NodeMap* node_map = sa_opti->node_map();

    // Make a set of the node names for faster membership testing.
    std::set<string> op_instance_names;
    for (auto& nd : ops) {
      op_instance_names.insert(nd->name());
      VLOG(2) << "op_instance_name " << nd->name();
    }
    DataType dtype;
    std::vector<TensorShape> input_shapes;
    std::vector<InputDesc> inputs;
    TensorShape sa_shape;
    string device_name;

    TF_RETURN_IF_ERROR(AnalyzeInputs(
        sa_opti, invocation_count, graph, node_map, ops, op_instance_names,
        &device_name, &dtype, &input_shapes, &inputs, &sa_shape));

    int sa_id = sa_opti->NewScopedAllocatorId(input_shapes.size());
    string sa_name =
        strings::StrCat("scoped_allocator_", sa_id, "_", invocation_count);
    TF_RETURN_IF_ERROR(ConstructScopedAllocatorNode(
        sa_opti, graph, node_map, ops, device_name, dtype, sa_id, sa_name,
        input_shapes, inputs, sa_shape));

    // Build a ScopedAllocatorConcat below all of the input nodes.
    std::vector<NodeDefBuilder::NodeOut> sac_inputs;
    string sac_name = strings::StrCat("scoped_allocator_concat_", sa_id, "_",
                                      invocation_count);
    TF_RETURN_IF_ERROR(BuildSAConcatNode(
        graph, node_map, ops, op_instance_names, device_name, dtype, sa_id,
        sa_name, sac_name, sa_shape, &sac_inputs));

    // Construct a new instance of the parallel op and insert it
    // immediately below the new ScopedAllocatorConcat.
    string sa_op_name = strings::StrCat(sa_name, "_", op_name);
    TF_RETURN_IF_ERROR(BuildReplacementOp(graph, node_map, ops, device_name,
                                          dtype, op_name, sac_name,
                                          sa_op_name));

    // Build a ScopedAllocatorSplit below the new Op.
    string sas_name = strings::StrCat("scoped_allocator_split_", sa_id, "_",
                                      invocation_count);
    TF_RETURN_IF_ERROR(BuildSplitNode(graph, node_map, ops, input_shapes,
                                      sac_inputs, device_name, dtype, op_name,
                                      sa_id, sas_name, sa_name, sa_op_name));

    // Rewire the graph.
    TF_RETURN_IF_ERROR(RewireSubgraph(graph, node_map, ops, op_instance_names,
                                      op_name, sas_name));

    *applied = true;
    return Status::OK();
  }
};

ScopedAllocatorOptimizer::ScopedAllocatorOptimizer(
    RewriterConfig::Toggle opt_level, const ScopedAllocatorOptions& opts)
    : opt_level_(opt_level) {
  VLOG(1) << "ScopedAllocatorOptimizer::ScopedAllocatorOptimizer";
  Rewriter* r = new UnaryElementwiseRewriter();
  to_delete_.push_back(r);
  if (opts.enable_op_size() == 0) {
    // Ops enabled by default:
    for (const auto& op_name : {"CollectiveReduce"}) {
      op_name_set_.insert(op_name);
      rewriters_[op_name] = r;
    }
  } else {
    for (const auto& op_name : opts.enable_op()) {
      op_name_set_.insert(op_name);
      rewriters_[op_name] = r;
    }
  }
}

Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
                                          const GrapplerItem& item,
                                          GraphDef* optimized_graph) {
  VLOG(3) << "Input graph:";
  DumpGraphToVLOG(item.graph, /*log_level=*/3);

  // Nodes that cannot be removed from the graph without damaging correctness,
  // typically fetch nodes.
  nodes_to_preserve_ = item.NodesToPreserve();

  GraphProperties graph_properties(item);
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  LOG_WARNING_AND_RETURN_IF_ERROR(graph_properties.InferStatically(
      assume_valid_feeds, /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/false));
  *optimized_graph = item.graph;
  node_map_ = absl::make_unique<NodeMap>(optimized_graph);

  LOG_WARNING_AND_RETURN_IF_ERROR(ScopedAllocatorOptimizer::ProcessGraphDef(
      optimized_graph, graph_properties));

  VLOG(1) << "ScopedAllocatorOptimizer::Optimize() done";
  VLOG(3) << "Optimized graph:";
  DumpGraphToVLOG(*optimized_graph, /*log_level=*/3);
  return Status::OK();
}

ScopedAllocatorOptimizer::Rewriter* ScopedAllocatorOptimizer::GetRewriter(
    const string& op_name) {
  auto it = rewriters_.find(op_name);
  if (it != rewriters_.end()) {
    return it->second;
  }
  return nullptr;
}

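// Reserves a block of num_fields + 1 consecutive ids: the first identifies
// the _ScopedAllocator node itself, and the following num_fields ids are
// assigned to the fields (one per input) carved out of its backing tensor.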
int ScopedAllocatorOptimizer::NewScopedAllocatorId(int num_fields) {
  CHECK_GT(num_fields, 0);
  int id = next_sa_id_;
  next_sa_id_ += (num_fields + 1);
  CHECK_GT(next_sa_id_, 0);
  return id;
}

Status ScopedAllocatorOptimizer::NewIdentityId(int* id) {
  *id = next_identity_id_++;
  if (next_identity_id_ < 0) {
    return errors::Aborted("NewIdentityId overflow");
  }
  return Status::OK();
}

ScopedAllocatorOptimizer::~ScopedAllocatorOptimizer() {
  for (auto ptr : to_delete_) {
    delete ptr;
  }
}

void ScopedAllocatorOptimizer::FindOpOccurrences(GraphDef* graph,
                                                 const OpNameSet& op_names,
                                                 GraphOpOccurrences* occs) {
  VLOG(1) << "FindOpOccurrences ";
  for (const auto& it : op_names) {
    VLOG(1) << "search target " << it;
  }
  for (int ni = 0; ni < graph->node_size(); ++ni) {
    NodeDef* node = graph->mutable_node(ni);
    const string& op_name = node->op();
    if (op_names.find(op_name) != op_names.end()) {
      VLOG(1) << "found " << op_name << " on dev " << node->device();
      (*occs)[node->device()][op_name].push_back(node);
    }
  }
}

namespace {
struct OpNameOrder {
  bool operator()(const NodeDef* a, const NodeDef* b) {
    // Note: must be a strict weak ordering for use with std::sort.
    return a->name() < b->name();
  }
};

class Tree {
 public:
  Tree(const string& edge, int depth) : edge_(edge), depth_(depth) {}
  ~Tree() {
    for (const auto& it : subtrees_) delete it.second;
  }

  Tree* GetSubTree(const string& edge) {
    auto it = subtrees_.find(edge);
    if (it != subtrees_.end()) {
      return it->second;
    }
    Tree* t = new Tree(edge, depth_ + 1);
    subtrees_[edge] = t;
    return t;
  }

  void InsertNode(NodeDef* n) { nodes_.push_back(n); }

  string edge_;
  int depth_;
  std::vector<NodeDef*> nodes_;
  absl::flat_hash_map<string, Tree*> subtrees_;
};

// Applies a function to every Tree in DFS order.  Terminates early
// on any non-OK Status.
Status ApplyToAll(Tree* tree, const std::function<Status(Tree*)>& func) {
  Status s;
  for (const auto& it : tree->subtrees_) {
    s = ApplyToAll(it.second, func);
    if (!s.ok()) return s;
  }
  s = func(tree);
  return s;
}

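// Organizes the nodes in node_vec into a tree keyed by name scope: a node
// named "a/b/mul_1" is inserted at the subtree reached by following edges
// "a" and then "b" from the root.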
Tree* ComputeScopeTree(const string& op_name,
                       const std::vector<NodeDef*>& node_vec) {
  Tree* root = new Tree("", 0);
  for (NodeDef* n : node_vec) {
    std::vector<string> pieces = str_util::Split(n->name(), "/");
    // last piece is node name proper.
    int depth = pieces.size() - 1;
    Tree* subtree = root;
    for (int i = 0; i < depth; ++i) {
      subtree = subtree->GetSubTree(pieces[i]);
    }
    subtree->InsertNode(n);
  }
  return root;
}

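// Partitions `nodes` into groups whose members share identical loop
// nesting, as reported by frame_view.Frames().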
void PartitionByLoopStructure(const FrameView& frame_view,
                              std::vector<NodeDef*> nodes,
                              std::vector<std::vector<NodeDef*>>* loop_groups) {
  // It is assumed that two nodes with identical loop containment have
  // identical integer vectors. Represent those by 64 bit hashes.
  absl::flat_hash_map<uint64, std::vector<NodeDef*>> loop_sets;
  for (NodeDef* nd : nodes) {
    uint64 hash = 0;
    const std::vector<int>& loop_ids = frame_view.Frames(*nd);
    for (int id : loop_ids) {
      hash = Hash64Combine(hash, static_cast<uint64>(id));
    }
    loop_sets[hash].push_back(nd);
  }
  for (auto it : loop_sets) {
    loop_groups->push_back(std::move(it.second));
  }
}

// Identify outputs that are inputs to multiple sets of nodes.
void IdentifyRepeatedInputs(const std::vector<NodeDef*>& nodes,
                            absl::flat_hash_set<string>* seen_outputs,
                            absl::flat_hash_set<string>* repeated_outputs) {
  for (NodeDef* node : nodes) {
    for (const auto& input_name : node->input()) {
      if (!seen_outputs->insert(input_name).second) {
        repeated_outputs->insert(input_name);
      }
    }
  }
}

}  // namespace

Status ScopedAllocatorOptimizer::ProcessGraphDef(
    GraphDef* graph, const GraphProperties& graph_properties) {
  // Nodes created by this optimizer have the IsStateful() property
  // which means their names must be globally unique within a process,
  // so we include an optimizer invocation count in every generated
  // name.
  static std::atomic<int64> invocation_counter(1);
  const int64 invocation_count =
      invocation_counter.fetch_add(1, std::memory_order_seq_cst);
  VLOG(1) << "ProcessGraphDef " << invocation_count;
  Status status;
  GraphOpOccurrences occ;
  FindOpOccurrences(graph, op_name_set_, &occ);
  if (!occ.empty()) {
    FrameView frame_view;
    // TODO(ezhulenev): Pass a GraphView when this optimizer is migrated
    // from NodeMap.
    LOG_WARNING_AND_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph));

    for (auto& dt : occ) {
      VLOG(2) << "Processing device " << dt.first;
      const DevOpOccurrences& dev_occ = dt.second;
      for (auto& it : dev_occ) {
        string op_name = it.first;
        VLOG(1) << "Processing " << op_name << " set size " << it.second.size();
        Rewriter* rewriter = GetRewriter(op_name);
        if (!rewriter) {
          LOG(ERROR) << "Failed to find Rewriter in ScopedAllocatorOptimizer "
                     << "for op_name " << op_name;
          continue;
        }
        rewriter->SetGraphProperties(graph_properties);
        std::unique_ptr<Tree> root(ComputeScopeTree(it.first, it.second));
        // Record outputs that are inputs to multiple Tree nodes.
        absl::flat_hash_set<string> seen_outputs;
        status = ApplyToAll(root.get(), [this, &seen_outputs](Tree* t) {
          IdentifyRepeatedInputs(t->nodes_, &seen_outputs, &repeated_outputs_);
          return Status::OK();
        });
        if (!status.ok()) {
          break;
        }
        // Nodes with a common depth and root path are now grouped
        // in the same Tree struct.  Split those groups into subgroups that
        // share identical loop nesting.
        status = ApplyToAll(root.get(), [this, rewriter, graph, &frame_view,
                                         &op_name, invocation_count](Tree* t) {
          VLOG(2) << "applied to tree node " << t->edge_ << " at depth "
                  << t->depth_ << " of size " << t->nodes_.size();
          if (t->nodes_.size() > 1) {
            std::vector<std::vector<NodeDef*>> loop_groups;
            PartitionByLoopStructure(frame_view, t->nodes_, &loop_groups);
            for (auto& lg : loop_groups) {
              if (lg.size() > 1) {
                bool applied = false;
                Status s = OrderNodeSet(&lg);
                TF_RETURN_IF_ERROR(s);
                VLOG(1) << "Applying Rewriter for " << op_name;
                s = rewriter->Rewrite(this, invocation_count, graph, op_name,
                                      lg, &applied);
                LOG_WARNING_AND_RETURN_IF_ERROR(s);
              }
            }
          }
          return Status::OK();
        });
        if (!status.ok()) {
          break;
        }
      }
      if (!status.ok()) {
        break;
      }
    }
  }
  VLOG(1) << "ScopedAllocatorOptimizer returning " << status;
  if (!status.ok()) {
    LOG(ERROR) << "ScopedAllocatorOptimizer: " << status;
  }
  return status;
}

namespace {
struct InstanceKeyLess {
  bool operator()(const NodeDef* a, const NodeDef* b) const {
    AttrSlice a_attrs = AttrSlice(*a);
    AttrSlice b_attrs = AttrSlice(*b);
    int32 a_key = -1;
    int32 b_key = -1;
    Status s = GetNodeAttr(a_attrs, "instance_key", &a_key);
    CHECK(s.ok());
    s = GetNodeAttr(b_attrs, "instance_key", &b_key);
    CHECK(s.ok());
    return a_key < b_key;
  }
};

struct NameLess {
  bool operator()(const NodeDef* a, const NodeDef* b) const {
    return a->name() < b->name();
  }
};

bool IsCollectiveNode(const NodeDef& n) {
  AttrSlice attrs = AttrSlice(n);
  int key = -1;
  if (!IsCollective(n)) return false;
  Status s = GetNodeAttr(attrs, "instance_key", &key);
  if (s.ok() && key >= 0) {
    return true;
  }
  return false;
}
}  // namespace

Status ScopedAllocatorOptimizer::OrderNodeSet(
    std::vector<NodeDef*>* nodes) const {
  // Nodes should all be of the same type.  Default order is by name, but for
  // collectives we order by increasing instance_key so each group gets
  // the same instance_key.
  if (nodes->size() <= 1) return Status::OK();
  if (IsCollectiveNode(*nodes->at(0))) {
    sort(nodes->begin(), nodes->end(), InstanceKeyLess());
  } else {
    sort(nodes->begin(), nodes->end(), NameLess());
  }
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow

#undef LOG_WARNING_AND_RETURN_IF_ERROR