/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"

#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace tensorflow {
namespace grappler {
namespace vectorization_utils {

namespace {

// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;

constexpr char kRetValOp[] = "_Retval";

void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
                        Graph* graph) {
  // NOTE: We need two for loops here because we can't mutate the set of output
  // edges as we iterate over them.
  std::vector<const Edge*> edges_to_replace;
  for (auto edge : old_src.first->out_edges()) {
    if (edge->src_output() == old_src.second) {
      edges_to_replace.push_back(edge);
    }
  }
  for (auto edge : edges_to_replace) {
    graph->AddEdge(new_src.first, new_src.second, edge->dst(),
                   edge->dst_input());
    graph->RemoveEdge(edge);
  }
}

// Updates the MapDefun node's attrs to keep them consistent with its function.
void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
  map_defun_node->AddAttr("output_types", map_defun_fn->ret_types);

  // TODO(rachelim): Propagate precise shapes if they're known, which may enable
  // subsequent optimizations.
  map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>(
                                               map_defun_fn->ret_types.size()));
}

Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
                         const TensorDesc& output) {
  DataType type = output.first->output_type(output.second);
  int index = map_defun_fn->ret_nodes.size();

  NodeDef ret_node_def;
  ret_node_def.set_name("map_out");
  ret_node_def.set_op(kRetValOp);
  AddNodeAttr("T", type, &ret_node_def);
  AddNodeAttr("index", index, &ret_node_def);

  Status s;
  Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
  TF_RETURN_IF_ERROR(s);

  map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
  map_defun_fn->ret_nodes.push_back(ret_node);
  map_defun_fn->ret_types.push_back(type);
  UpdateMapDefunAttrs(map_defun_fn, map_defun_node);

  return s;
}

void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
                          FunctionBody* map_defun_fn, Node* map_defun_node) {
  DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
      << "Trying to remove output that doesn't exist. Output number: "
      << output_position;

  int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;

  // Modify map_defun_fn's signature and remove the output node from its graph
  map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
  map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
                                output_position);
  map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
                                output_position);
  UpdateMapDefunAttrs(map_defun_fn, map_defun_node);

  // Renumber the nodes and edges that come after
  for (int i = 0; i < num_later_outputs; ++i) {
    ReplaceEdgeSources({map_defun_node, output_position + i + 1},
                       {map_defun_node, output_position + i}, outer_scope);
    // Each ret node has an "index" attr that has to be updated
    map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
                                                          output_position + i);
  }
}

// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
// This class transforms the input FunctionDefs into their corresponding
// Graph objects and works on the graphs directly, then converts them back
// to FunctionDefs when GetResult is called.
// TODO(rachelim): Move this to its own header.
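// Illustrative sketch (not an exhaustive description): if the MapDefun body
// computes
//   y = Cast(x)
// on each 0th-dimension slice of a stacked input X, vectorization adds a
// single Cast on all of X to the outer graph, reroutes consumers of the
// MapDefun output to it, and removes the MapDefun node once all of its
// outputs have been converted this way.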
class Vectorization {
 public:
  explicit Vectorization(FunctionDefLibrary* lib)
      : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}

  // Adds the vectorized function and new map_defun_fn to lib, and points
  // vectorized_function to the former. Returns an error status if
  // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
  // along the way.
  Status Vectorize(const FunctionDef& outer_scope,
                   const NodeDef& map_defun_node, FunctionDef** result);

 private:
  // Converts FunctionDefs to Graphs and adds mappings from
  // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
  Status Initialize(const FunctionDef& outer_scope,
                    const NodeDef& map_defun_node);

  // Converts Graphs back to FunctionDefs and adds them to `lib_`.
  Status GetResult(FunctionDef** vectorized_function);

  // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
  // `outer_scope_`, until there are no convertible outputs remaining.
  void VectorizeHelper();

  // Vectorizes map_defun_fn's output at output_position.
  Status ConvertOutput(int output_position);

  // Adds mappings from the node's output tensors to converted output tensors,
  // creating the necessary new node(s). Generally, the steps to convert an op
  // are:
  // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
  //    These operations collectively compute the same value as what running
  //    the original operation on slices of the input tensors would produce.
  //    For example, a Cast op in MapDefun translates to a Cast op in
  //    `outer_scope_`, since the vectorized version of Cast is Cast itself.
  // 2) Promote the inputs of the op to outputs of the
  //    `map_defun_node_` and `map_defun_fn_`.
  // 3) Add edges between the promoted inputs (that are now outputs of
  //    `map_defun_node_`) and the input ports of the new node(s).
  // 4) For each output of the old node, add the mapping of output tensors to
  //    the conversion map.
  Status AddConversionMapping(Node* op_node);

  // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
  // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
  // inputs to `map_defun_node_`. This stacked tensor will be compatible with
  // the expected output shape of `map_defun_node_`.
  // This is equivalent to the _stack function in python Pfor.
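  // For example (illustrative only): an unstacked scalar constant c with loop
  // length n = 3 is stacked into a tensor of shape [3] whose entries are all
  // c, matching the shape the corresponding MapDefun output would have had.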
  Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);

  // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
  // doing a depth-first search from the ret nodes. Lifts tensors that are
  // unstacked (i.e. don't derive from arg tensors) into `outer_scope_` directly
  // and adds mappings to `conversion_map_`.
  // Note that this function may have false negatives, i.e. not
  // add mappings for some tensors that are unstacked. This may happen in the
  // following cases: 1) a vectorized op produces unstacked outputs from stacked
  // inputs (e.g. the vectorized "Shape" op), 2) the tensors are in a cycle, or
  // 3) the unstacked op could not be lifted into `outer_scope_`.
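  // For example (illustrative only): a Const node in the MapDefun body does
  // not derive from any arg tensor, so it can be copied unchanged into
  // `outer_scope_` and its outputs recorded as unstacked in `conversion_map_`.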
  Status AddUnstackedTensorMappings();

  // Recursive helper for `AddUnstackedTensorMappings`. If an op node is
  // unstacked, lifts its output tensors into `outer_scope_`, adding the
  // mappings to `conversion_map_`. Returns true if the unstacked mappings were
  // added.
  bool AddUnstackedTensorMappingsHelper(
      TensorDesc&& tensor, absl::flat_hash_set<const Edge*>* visited);

  // Adds mappings from `map_defun_fn_` arg tensors to the corresponding
  // `map_defun_node_` input tensors to `conversion_map_`.
  Status AddArgTensorMappings();

  // Maps a tensor to the corresponding WrappedTensor. For example,
  // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
  std::map<TensorDesc, WrappedTensor> conversion_map_;

  // Unconvertible ret nodes
  std::set<Node*> unconvertible_;

  FunctionDefLibrary* lib_;  // Not owned
  FunctionLibraryDefinition lib_def_;
  // Note that FunctionBody has a pointer to a Graph object that corresponds
  // to the function's subgraph, with additional kArgOp and kRetValOp nodes
  // that denote the function's arguments and return values. These nodes have
  // the attrs "T" for the type, and "index" for the argument / retval index,
  // respectively. FunctionBody also keeps track of arg/ret_nodes and
  // arg/ret_types, which should be ordered according to argument/output
  // indices.
  std::unique_ptr<Graph> outer_scope_;
  std::unique_ptr<FunctionBody> map_defun_fn_;
  Node* map_defun_node_ = nullptr;  // Owned by `outer_scope_`

  // Caches the loop_len_node_ needed for tiling unstacked outputs. This
  // corresponds to a vector with one element.
  Node* loop_len_node_ = nullptr;  // Owned by `outer_scope_`
  Status status_;
};

Status Vectorization::AddConversionMapping(Node* op_node) {
  for (auto edge : op_node->in_edges()) {
    if (edge->IsControlEdge()) {
      return errors::InvalidArgument(
          "Vectorizing outputs with control inputs is currently not "
          "supported.");
    }
  }

  auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
  if (vectorizer == nullptr) {
    return errors::Unimplemented("No vectorizer registered for op: ",
                                 op_node->type_string());
  }
  std::vector<WrappedTensor> inputs, outputs;
  inputs.reserve(op_node->num_inputs());
  outputs.reserve(op_node->num_outputs());

  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));

  // The inputs for the node to be converted may already have been converted
  // themselves. For those that are not, we promote them to MapDefun outputs.
  for (int i = 0; i < op_node->num_inputs(); ++i) {
    auto edge = input_edges[i];
    if (auto found = gtl::FindOrNull(conversion_map_,
                                     {edge->src(), edge->src_output()})) {
      inputs.push_back(*found);
    } else {
      // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
      // We assume that all unconverted inputs will be stacked, since we
      // converted all unstacked nodes in `Initialize`. However, it's actually
      // possible that yet-unconverted nodes may produce unstacked outputs after
      // they are vectorized. (For example, see the "Shape" converter in
      // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
      // an unstacked input but receives a stacked one, vectorizer->Vectorize
      // will return an error.
      TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
                                           {edge->src(), edge->src_output()}));
      int output_index = map_defun_fn_->ret_nodes.size() - 1;
      inputs.push_back({map_defun_node_, output_index, true});
    }
  }

  Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(),
                                   std::move(inputs), &outputs);
  if (!s.ok()) {
    VLOG(2) << "Vectorizer for op \"" << op_node->type_string()
            << "\" failed with error: " << s;
    return s;
  }
  const int64 op_node_num_outputs = op_node->num_outputs();
  if (op_node_num_outputs != outputs.size()) {
    return errors::Internal(
        "Number of vectorizer outputs does not match. Expected: ",
        op_node->num_outputs(), " Actual: ", outputs.size());
  }

  // Add output mappings.
  for (int i = 0; i < op_node->num_outputs(); ++i) {
    conversion_map_.insert({{op_node, i}, outputs[i]});
  }

  return Status::OK();
}

Status Vectorization::ConvertOutput(int output_position) {
  // ret_edge->src() is the actual op that generated the retval, and
  // ret_edge->dst() is the retval node whose op is "_Retval"
  const Edge* ret_edge;
  TF_RETURN_IF_ERROR(
      map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));

  TensorDesc output({ret_edge->src(), ret_edge->src_output()});
  TensorDesc converted_output;

  // It's possible the output already has a mapping, if it comes from a node
  // that has already been converted.
  auto found = gtl::FindOrNull(conversion_map_, output);
  if (!found) {
    TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
    found = &conversion_map_.at(output);
  }

  if (found->stacked) {
    converted_output = {found->node, found->output_index};
  } else {
    // Some outputs may be unstacked if they don't derive from arg nodes
    // (for example, if a function returns a constant). For these, we
    // have to add extra nodes to tile them in the 0th dimension.
    TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
  }

  ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
                     outer_scope_.get());
  RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
                       map_defun_node_);

  return Status::OK();
}

Status Vectorization::Vectorize(const FunctionDef& outer_scope,
                                const NodeDef& map_defun_node,
                                FunctionDef** result) {
  TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
  VectorizeHelper();
  return GetResult(result);
}

void Vectorization::VectorizeHelper() {
  while (true) {
    int output_position = graph_utils::GetFirstElementIndexWithPredicate(
        [this](Node* n) {
          return this->unconvertible_.find(n) == this->unconvertible_.end();
        },
        map_defun_fn_->ret_nodes);

    // No outputs left to convert
    if (output_position == -1) break;

    Status s = ConvertOutput(output_position);
    if (!s.ok()) {
      Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
      VLOG(2) << "Could not convert the output at node: "
              << output_node->DebugString() << "\nError: " << s;
      unconvertible_.insert(output_node);
    }
  }

  // If we've converted all the outputs of the MapDefun function, we no longer
  // need the MapDefun node and can delete it.
  if (map_defun_fn_->ret_nodes.empty()) {
    outer_scope_->RemoveNode(map_defun_node_);
  }
}

Status Vectorization::Initialize(const FunctionDef& outer_scope,
                                 const NodeDef& map_defun_node) {
  // Convert outer_scope and map_defun_fn to FunctionBodys so we can
  // work on Graphs directly.
  const FunctionDef* map_defun_fn =
      lib_def_.Find(map_defun_node.attr().at("f").func().name());

  if (map_defun_fn == nullptr) {
    return errors::NotFound("Could not find function with name ",
                            map_defun_node.attr().at("f").func().name(),
                            " in function library.");
  }

  std::unique_ptr<FunctionBody> outer_fn;
  TF_RETURN_IF_ERROR(
      FunctionDefToBodyHelper(outer_scope, {}, &lib_def_, &outer_fn));
  // We don't need outer_fn, just the graph
  outer_scope_.reset(outer_fn->graph);
  outer_fn->graph = nullptr;

  TF_RETURN_IF_ERROR(
      FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_, &map_defun_fn_));

  // Find the MapDefun node in outer_scope_
  int node_id = graph_utils::GetFirstElementIndexWithPredicate(
      [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
      outer_scope_->nodes());
  if (node_id == -1) {
    return errors::NotFound("Could not find node with name ",
                            map_defun_node.name(), " in outer_scope.");
  }
  map_defun_node_ = outer_scope_->FindNodeId(node_id);

  TF_RETURN_IF_ERROR(AddArgTensorMappings());
  TF_RETURN_IF_ERROR(AddUnstackedTensorMappings());
  loop_len_node_ = nullptr;

  return Status::OK();
}

// TODO(rachelim): It might be profitable to use the C++ API for this instead of
// NodeBuilder
Status Vectorization::StackTensor(WrappedTensor* unstacked,
                                  TensorDesc* result) {
  if (unstacked->node->output_type(unstacked->output_index) == DT_VARIANT) {
    // TODO(b/124069171): "ExpandDims" doesn't work with Variant tensors.
    return errors::Unimplemented("Cannot stack tensor with Variant type.");
  }
  // Note that all these nodes are necessary as the size of the batch may not be
  // constant.
  if (unstacked->stacked) {
    return errors::Internal("Can only stack unstacked tensor.");
  }

  Graph* g = outer_scope_.get();
  auto node_builder = [](StringPiece op) {
    return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
  };

  auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
                                    Node** result) {
    TF_RETURN_IF_ERROR(val.status);
    return node_builder("Const")
        .Attr("value", val.tensor)
        .Attr("dtype", val.tensor.dtype())
        .Finalize(graph, result);
  };

  // If loop_len_node_ hasn't been created yet, add the node and cache it.
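  // The subgraph built below computes, in effect,
  //   loop_len = Reshape(StridedSlice(Shape(input_0), [0], [1], [1]), [1]),
  // i.e. a length-1 vector holding dimension 0 of the first MapDefun input
  // (the number of iterations the MapDefun would have run).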
  if (loop_len_node_ == nullptr) {
    Node* input_node;
    TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));

    Node* shape_node;
    TF_RETURN_IF_ERROR(
        node_builder("Shape").Input(input_node).Finalize(g, &shape_node));

    Node* const_vec_0;
    TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
    Node* const_vec_1;
    TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));

    Node* strided_slice_node;
    TF_RETURN_IF_ERROR(node_builder("StridedSlice")
                           .Input(shape_node)   // input
                           .Input(const_vec_0)  // begin
                           .Input(const_vec_1)  // end
                           .Input(const_vec_1)  // strides
                           .Finalize(g, &strided_slice_node));

    // Produces a vector of length 1
    TF_RETURN_IF_ERROR(node_builder("Reshape")
                           .Input(strided_slice_node)  // tensor
                           .Input(const_vec_1)         // shape
                           .Finalize(g, &loop_len_node_));
  }

  Node* ones_shape;
  TF_RETURN_IF_ERROR(node_builder("Shape")
                         .Input(unstacked->node)  // input
                         .Finalize(g, &ones_shape));

  Node* ones;
  TF_RETURN_IF_ERROR(
      node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));

  Node* const_0;
  TF_RETURN_IF_ERROR(make_const(0, g, &const_0));

  Node* multiples;
  TF_RETURN_IF_ERROR(node_builder("Concat")
                         .Input(const_0)                           // concat_dim
                         .Input({{loop_len_node_, 0}, {ones, 0}})  // values
                         .Finalize(g, &multiples));

  Node* expand_dims;
  TF_RETURN_IF_ERROR(node_builder("ExpandDims")
                         .Input(unstacked->node)  // input
                         .Input(const_0)          // dim
                         .Finalize(g, &expand_dims));

  TF_RETURN_IF_ERROR(node_builder("Tile")
                         .Input(expand_dims)  // input
                         .Input(multiples)    // multiples
                         .Finalize(g, &result->first));
  result->second = 0;
  return Status::OK();
}

Status Vectorization::AddArgTensorMappings() {
  // Note that inputs to map_defun_fn_ are either regular arguments (for which
  // the operations are mapped across their 0th dimension) or captured inputs
  // (for which the operations apply to the argument wholesale).
  int num_args =
      map_defun_node_->attrs().Find("Targuments")->list().type_size();

  auto add_conversion = [this](Node* arg_node, bool stacked) {
    Node* input_node;
    TF_RETURN_IF_ERROR(map_defun_node_->input_node(
        arg_node->attrs().Find("index")->i(), &input_node));

    conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}});

    // Control inputs
    conversion_map_.insert({{arg_node, Graph::kControlSlot},
                            {input_node, Graph::kControlSlot, stacked}});

    return Status::OK();
  };

  // Regular arguments
  for (int i = 0; i < num_args; ++i) {
    TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true));
  }

  // Captured inputs. These are applied (without slicing) to every iteration of
  // the map function, hence are mapped to unstacked nodes.
  for (int i = num_args, end = map_defun_fn_->arg_nodes.size(); i < end; ++i) {
    TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false));
  }

  return Status::OK();
}

bool Vectorization::AddUnstackedTensorMappingsHelper(
    TensorDesc&& tensor, absl::flat_hash_set<const Edge*>* visited) {
  if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
    return !found->stacked;
  }

  if (tensor.first->op_def().is_stateful()) {
    // We don't lift stateful nodes directly out of the MapDefun, since they may
    // have to be executed N times.
    return false;
  }

  bool is_unstacked = true;
  for (const auto& edge : tensor.first->in_edges()) {
    // Ignore Source nodes. Note that these are also ignored in the
    // GraphToFunctionDef conversion.
    if (edge->src()->IsSource()) continue;

    if (visited->find(edge) != visited->end()) {
      // If we've visited this edge already, we're in a cycle. In this case, we
      // are conservative and don't mark the node as unstacked.
      is_unstacked = false;
      continue;
    }
    visited->insert(edge);

    // A node is unstacked if all of its inputs are unstacked
    is_unstacked &= AddUnstackedTensorMappingsHelper(
        {edge->src(), edge->src_output()}, visited);
  }

  if (!is_unstacked) {
    return false;
  }

  // If the node is unstacked, we copy it into outer_scope_ and
  // add it to the map. Note that we don't clean up the nodes that are copied
  // in map_defun_fn_, and rely on them being pruned out later.
  Status status;
  Node* node = outer_scope_->AddNode(tensor.first->def(), &status);
  if (!status.ok()) return false;

  // Add input edges to nodes that should already have been lifted.
  for (const auto& edge : tensor.first->in_edges()) {
    // Ignore Source nodes. Note that these are also ignored in the
    // GraphToFunctionDef conversion.
    if (edge->src()->IsSource()) continue;

    if (auto found = gtl::FindOrNull(conversion_map_,
                                     {edge->src(), edge->src_output()})) {
      outer_scope_->AddEdge(found->node, found->output_index, node,
                            edge->dst_input());
    } else {
      return false;
    }
  }

  // Add output mappings
  for (int i = 0; i < tensor.first->num_outputs(); ++i) {
    conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
  }
  conversion_map_.insert({{tensor.first, Graph::kControlSlot},
                          WrappedTensor(node, Graph::kControlSlot, false)});

  return true;
}

Status Vectorization::AddUnstackedTensorMappings() {
  absl::flat_hash_set<const Edge*> visited;
  for (const auto& ret_node : map_defun_fn_->ret_nodes) {
    const Edge* in_edge = nullptr;
    TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
    AddUnstackedTensorMappingsHelper({in_edge->src(), in_edge->src_output()},
                                     &visited);
  }
  return Status::OK();
}

Status Vectorization::GetResult(FunctionDef** vectorized_function) {
  TF_RETURN_IF_ERROR(status_);
  TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
  TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));

  if (!map_defun_fn_->ret_nodes.empty()) {
    FunctionDef* map_defun_fn = lib_->add_function();
    graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
    TF_RETURN_IF_ERROR(GraphToFunctionDef(
        *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));

    AttrValue func_attr;
    func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
    map_defun_node_->AddAttr("f", func_attr);
  }

  *vectorized_function = lib_->add_function();
  graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
                                          *vectorized_function);
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *outer_scope_, (*vectorized_function)->signature().name(),
      *vectorized_function));
  return Status::OK();
}

}  // namespace

Status VectorizeMapDefun(const FunctionDef& outer_scope,
                         const NodeDef& map_defun_node, FunctionDefLibrary* lib,
                         FunctionDef** result) {
  *result = nullptr;
  return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}
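// Example call site (a sketch; `outer_fn`, `map_defun_def`, and `library` are
// hypothetical caller-side names):
//   FunctionDef* vectorized = nullptr;
//   Status s = vectorization_utils::VectorizeMapDefun(
//       outer_fn, map_defun_def, &library, &vectorized);
//   if (s.ok()) {
//     // `vectorized` now names a function in `library` that either no longer
//     // calls the MapDefun, or calls a reduced MapDefun for any outputs that
//     // could not be converted.
//   }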

}  // namespace vectorization_utils
}  // namespace grappler
}  // namespace tensorflow