/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"

#include <iterator>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/types/optional.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

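// This pass increases the amount of "dynamism" representable by XLA clusters:
// it rewrites Slice ops that have a constant size (but a possibly dynamic
// begin) into "static shaped" slices whose output shape depends only on the
// size input, so that the cluster does not need to be recompiled every time
// the begin input changes.
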
namespace tensorflow {
namespace {

// StatusOrOptional<T> instances hold
//
//  - A non-OK Status to indicate an error that needs to be propagated out of
//    this pass (e.g. the Graph is malformed).
//
//  - A nullopt to indicate that the function that created the instance failed
//    to do what it set out to do, but this is not actually an error
//    (e.g. TryToGetTensorFromConstOp was passed a non-Const node).
//
//  - A T to indicate a successful operation.
template <class T>
using StatusOrOptional = xla::StatusOr<absl::optional<T>>;

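// Returns the value of the Const node `n` as a Tensor, or nullopt if `n` is
// not a Const node.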
StatusOrOptional<Tensor> TryToGetTensorFromConstOp(Node* n) {
  if (n->type_string() != "Const") {
    return {absl::nullopt};
  }

  const TensorProto* proto = nullptr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
  Tensor tensor(proto->dtype());
  TF_RET_CHECK(tensor.FromProto(*proto));
  return {tensor};
}

struct SliceInputs {
  Output slice_op;
  Output input;
  Output begin;
  Output size;

  // The size of the TF slice operation as a std::vector.  We can always
  // compute this because we only manipulate slices with a Const size.
  std::vector<int64> size_as_vector;
};

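// Converts the int32 or int64 tensor `t` into a std::vector<int64>.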
std::vector<int64> IntTensorAsVector(const Tensor& t) {
  DCHECK(t.dtype() == DT_INT32 || t.dtype() == DT_INT64);
  std::vector<int64> result;
  result.reserve(t.NumElements());
  for (int i = 0; i < t.NumElements(); i++) {
    int64 element = t.dtype() == DT_INT32
                        ? static_cast<int64>(t.flat<int32>()(i))
                        : t.flat<int64>()(i);
    result.push_back(element);
  }
  return result;
}

// Packages up the inputs to a Slice operation into an instance of
// `SliceInputs`.  Returns nullopt if the slice's size input is not a 1-D
// Const, since the rewrite below only handles slices with a constant size.
StatusOrOptional<SliceInputs> GetSliceInputs(Node* slice) {
  const int kSliceInputIndex = 0;
  const int kSliceBeginIndex = 1;
  const int kSliceSizeIndex = 2;

  const Edge* slice_input_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceInputIndex, &slice_input_edge));
  const Edge* slice_size_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge));
  const Edge* slice_begin_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge));

  SliceInputs slice_inputs;
  slice_inputs.input =
      Output(slice_input_edge->src(), slice_input_edge->src_output());
  slice_inputs.begin =
      Output(slice_begin_edge->src(), slice_begin_edge->src_output());
  slice_inputs.size =
      Output(slice_size_edge->src(), slice_size_edge->src_output());

  TF_ASSIGN_OR_RETURN(absl::optional<Tensor> tf_slice_size,
                      TryToGetTensorFromConstOp(slice_inputs.size.node()));
  if (!tf_slice_size.has_value()) {
    return {absl::nullopt};
  }

  if (tf_slice_size->dims() != 1) {
    return {absl::nullopt};
  }

  slice_inputs.size_as_vector = IntTensorAsVector(*tf_slice_size);
  return {slice_inputs};
}


// Casts `x` to a DT_INT64 if it isn't one already.
Output MakeInt64(const Scope& host_scope, absl::string_view name,
                 const Output& x) {
  return x.type() == DT_INT64
             ? x
             : ops::Cast(host_scope.WithOpName(name, "_s64"), x, DT_INT64);
}


// Returns `slice_inputs` with the begin and size inputs cast to DT_INT64.
SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope,
                                       const SliceInputs& slice_inputs) {
  SliceInputs result;
  result.input = slice_inputs.input;
  result.begin = MakeInt64(host_scope, "begin", slice_inputs.begin);
  result.size = MakeInt64(host_scope, "size", slice_inputs.size);
  result.size_as_vector = slice_inputs.size_as_vector;
  return result;
}


// This class caches emitted constants to avoid creating multiple nodes for the
// same constant value.  This helps make the generated GraphDef more readable.
class ConstantCache {
 public:
  explicit ConstantCache(const Scope& s,
                         const std::vector<const Edge*>& control_deps)
      : scope_(s), control_deps_(control_deps) {}

  Output Get1DHostConstant(int64 constant) {
    auto it = cache_.find(constant);
    if (it == cache_.end()) {
      Output new_const =
          ops::Const(scope_.WithOpName("const_", constant), {constant});
      it = cache_.insert({constant, new_const}).first;
      for (const Edge* e : control_deps_) {
        scope_.graph()->AddControlEdge(e->src(), new_const.node());
      }
    }
    return it->second;
  }

 private:
  Scope scope_;
  // Keyed on int64 to match Get1DHostConstant's parameter; an int key would
  // silently truncate 64-bit constants.
  std::unordered_map<int64, Output> cache_;
  std::vector<const Edge*> control_deps_;
};


// Computes a node that evaluates to the size of the Slice op with inputs
// `slice_inputs`, and returns it in `*size`.
Status ComputeSliceSize(const Scope& host_scope,
                        const SliceInputs& slice_inputs,
                        std::vector<const Edge*> control_deps, Output* size) {
  // If slice_size[i] >= 0 then slice_size[i] = slice_size[i].
  //
  // If slice_size[i] == -1 then slice_size[i] = input_size[i] - begin[i].
  //
  // If slice_size[i] < -1 then executing the slice will throw an error, and we
  // don't do anything here.  We've already filtered these cases out in
  // ShouldRewriteSlice.
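  //
  // For example, if the input has shape [10, 20], begin is [3, 0] and size is
  // [-1, 20], then the emitted graph computes the slice size
  // [10 - 3, 20] = [7, 20] at runtime.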

  if (absl::c_all_of(slice_inputs.size_as_vector,
                     [](int64 i) { return i >= 0; })) {
    *size = slice_inputs.size;
    return Status::OK();
  }

  Output input_shape =
      ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input,
                 ops::Shape::OutType(DT_INT64));

  ConstantCache constant_pool(host_scope, control_deps);

  std::vector<Output> slice_size;
  for (int i = 0, end = slice_inputs.size_as_vector.size(); i < end; i++) {
    if (slice_inputs.size_as_vector[i] >= 0) {
      slice_size.push_back(
          constant_pool.Get1DHostConstant(slice_inputs.size_as_vector[i]));
      continue;
    }

    DCHECK_EQ(slice_inputs.size_as_vector[i], -1);

    Output begin_i = ops::Slice(
        host_scope.WithOpName("begin_", i), slice_inputs.begin,
        constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1));

    Output input_shape_i = ops::Slice(
        host_scope.WithOpName("input_shape_", i), input_shape,
        constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1));

    slice_size.push_back(ops::Sub(host_scope.WithOpName("slice_size_", i),
                                  input_shape_i, begin_i));
    DCHECK_EQ(slice_size.back().type(), DT_INT64);
  }

  // Trivial ConcatV2 nodes (with exactly one input) are disallowed.
  if (slice_size.size() == 1) {
    *size = slice_size[0];
  } else {
    auto concat_axis = ops::Const(host_scope.WithOpName("concat_axis"), 0);
    for (const Edge* e : control_deps) {
      host_scope.graph()->AddControlEdge(e->src(), concat_axis.node());
    }
    *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
                        concat_axis);
  }
  return Status::OK();
}


// Terminology: a "static sized" slice is a slice whose
// _XlaCompileTimeConstantInputs attribute is set to {"size"}.  The output
// shape of these slices can be solely determined by their "size" input.
Status ConvertTensorFlowSliceToStaticShapedSlice(
    Graph* g, Node* slice, const SliceInputs& slice_inputs,
    absl::string_view cluster_name, Node** result) {
  string host_name;
  TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
      slice->assigned_device_name(), &host_name));

  Status status;
  Scope main_scope =
      NewInternalScope(g, &status, /*refiner=*/nullptr)
          .WithXlaCluster(string(cluster_name))
          .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice"));
  Scope host_scope = main_scope.WithAssignedDevice(host_name);

  // In the future we may want to be clever here and avoid the extra Cast ops.
  SliceInputs slice_inputs_int64 =
      MakeSliceIndexAndSizeInt64(host_scope, slice_inputs);

  // Collect the control dependencies on the old size input so they can be
  // copied to the nodes that replace it.
  Node* old_size;
  std::vector<const Edge*> old_size_ctrl_deps;
  TF_RETURN_IF_ERROR(slice->input_node(2, &old_size));
  absl::c_copy_if(old_size->in_edges(), std::back_inserter(old_size_ctrl_deps),
                  [](const Edge* e) { return e->IsControlEdge(); });

  Output slice_size;
  TF_RETURN_IF_ERROR(ComputeSliceSize(host_scope, slice_inputs_int64,
                                      old_size_ctrl_deps, &slice_size));

  *result =
      ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name())
                     .WithOpName("static_shaped_slice"),
                 slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
          .node();

  TF_RETURN_IF_ERROR(main_scope.status());

  std::vector<string> compile_time_const_inputs;
  compile_time_const_inputs.push_back("size");
  (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
                     compile_time_const_inputs);
  return status;
}

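// Reroutes all data and control edges out of `slice` to come from
// `static_shaped_slice`, copies `slice`'s incoming control edges onto
// `static_shaped_slice`, and removes `slice` from the graph.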
void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice,
                                                 Node* static_shaped_slice) {
  std::vector<const Edge*> slice_out_edges;
  absl::c_copy(slice->out_edges(), std::back_inserter(slice_out_edges));
  for (const Edge* e : slice_out_edges) {
    DCHECK(e->src_output() == 0 || e->src_output() == Graph::kControlSlot);

    int src_output = e->src_output();
    int dst_input = e->dst_input();
    Node* dst = e->dst();
    g->RemoveEdge(e);
    g->AddEdge(static_shaped_slice, src_output, dst, dst_input);
  }

  for (const Edge* e : slice->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), static_shaped_slice);
    }
  }

  g->RemoveNode(slice);
}

Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
                    absl::string_view cluster_name) {
  VLOG(3) << "Rewriting slice " << slice->name()
          << " to a \"static shaped\" Slice";
  Node* static_shaped_slice;
  TF_RETURN_IF_ERROR(ConvertTensorFlowSliceToStaticShapedSlice(
      g, slice, slice_inputs, cluster_name, &static_shaped_slice));
  ReplaceTensorFlowSliceWithStaticShapedSlice(g, slice, static_shaped_slice);
  return Status::OK();
}

// Returns true if `n` is a slice we should rewrite to have a static shape
// (i.e. have the output shape only depend on the "size" input).
xla::StatusOr<bool> ShouldRewriteSlice(Node* n) {
  if (n->type_string() != "Slice") {
    return false;
  }

  if (!GetXlaClusterForNode(*n).has_value()) {
    // There is no need to change slice ops outside XLA clusters.
    return false;
  }

  TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
                      GetSliceInputs(n));
  if (!slice_inputs.has_value()) {
    return false;
  }

  // If slice_size[i] < -1 for any i then executing the slice will throw an
  // error, so we don't rewrite such slices.
  bool slice_size_ok = absl::c_all_of(
      slice_inputs->size_as_vector, [](int64 size_i) { return size_i >= -1; });
  if (!slice_size_ok) {
    return false;
  }

  // No point in rewriting slices that have both size and begin as constants.
  return !slice_inputs->begin.node()->IsConstant();
}

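// Finds all the Slice nodes eligible for the rewrite and rewrites them.  Sets
// `*changed` to true if any slices were rewritten.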
Status FindAndRewriteSlices(Graph* g, bool* changed) {
  std::vector<Node*> slices_to_rewrite;
  for (Node* n : g->nodes()) {
    TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n));
    if (is_rewritable) {
      slices_to_rewrite.push_back(n);
    }
  }

  for (Node* n : slices_to_rewrite) {
    TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
                        GetSliceInputs(n));
    TF_RET_CHECK(slice_inputs.has_value());
    TF_RETURN_IF_ERROR(
        RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n)));
  }

  if (!slices_to_rewrite.empty()) {
    // We've added constants to the graph; hook them up to _SOURCE.
    FixupSourceAndSinkEdges(g);
  }

  *changed = !slices_to_rewrite.empty();

  return Status::OK();
}

}  // namespace

Status IncreaseDynamismForAutoJitPass::Run(
    const GraphOptimizationPassOptions& options) {
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  if (flags->tf_xla_clustering_debug) {
    DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
                    **options.graph, options.flib_def);
  }

  bool changed;
  TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed));
  if (changed && flags->tf_xla_clustering_debug) {
    DumpGraphToFile("increase_dynamism_for_auto_jit_pass", **options.graph,
                    options.flib_def);
  }

  return Status::OK();
}

}  // namespace tensorflow