/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
17 #include <iterator>
18 #include "absl/algorithm/container.h"
19 #include "absl/container/inlined_vector.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_replace.h"
22 #include "absl/types/optional.h"
23 #include "tensorflow/cc/framework/scope_internal.h"
24 #include "tensorflow/cc/ops/array_ops.h"
25 #include "tensorflow/cc/ops/const_op.h"
26 #include "tensorflow/cc/ops/math_ops.h"
27 #include "tensorflow/compiler/jit/flags.h"
28 #include "tensorflow/compiler/jit/xla_cluster_util.h"
29 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/core/common_runtime/shape_refiner.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/public/session_options.h"
34 #include "tensorflow/core/util/device_name_utils.h"
35 #include "tensorflow/core/util/dump_graph.h"
36
namespace tensorflow {
namespace {

// StatusOrOptional<T> instances hold
//
//  - A non-OK Status to indicate an error that needs to be propagated out of
//    this pass (e.g. the Graph is malformed).
//
//  - A nullopt to indicate that the function that created the instance failed
//    to do what it set out to do, but this is not actually an error
//    (e.g. TryToGetTensorFromConstOp was passed a non-Const node).
//
//  - A T to indicate a successful operation.
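//
// Typical usage, mirroring the call sites below (illustrative only):
//
//   TF_ASSIGN_OR_RETURN(absl::optional<Tensor> t, TryToGetTensorFromConstOp(n));
//   if (!t.has_value()) {
//     // `n` was not a Const node; skip it.  This is not an error.
//   }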
template <class T>
using StatusOrOptional = xla::StatusOr<absl::optional<T>>;

StatusOrOptional<Tensor> TryToGetTensorFromConstOp(Node* n) {
  if (n->type_string() != "Const") {
    return {absl::nullopt};
  }

  const TensorProto* proto = nullptr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
  Tensor tensor(proto->dtype());
  TF_RET_CHECK(tensor.FromProto(*proto));
  return {tensor};
}

struct SliceInputs {
  Output slice_op;
  Output input;
  Output begin;
  Output size;

  // The size of the TF slice operation as a std::vector.  We can always
  // compute this because we only manipulate slices with a Const size.
  std::vector<int64> size_as_vector;
};

std::vector<int64> IntTensorAsVector(const Tensor& t) {
  DCHECK(t.dtype() == DT_INT32 || t.dtype() == DT_INT64);
  std::vector<int64> result;
  result.reserve(t.NumElements());
  for (int i = 0; i < t.NumElements(); i++) {
    int64 element = t.dtype() == DT_INT32
                        ? static_cast<int64>(t.flat<int32>()(i))
                        : t.flat<int64>()(i);
    result.push_back(element);
  }
  return result;
}

// Packages up the inputs to a Slice operation into an instance of
// `SliceInputs`.
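//
// A TF "Slice" node has three inputs: the tensor being sliced (input 0), the
// begin indices (input 1) and the per-dimension sizes (input 2); the index
// constants below follow that ordering.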
StatusOrOptional<SliceInputs> GetSliceInputs(Node* slice) {
  const int kSliceInputIndex = 0;
  const int kSliceBeginIndex = 1;
  const int kSliceSizeIndex = 2;

  const Edge* slice_input_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceInputIndex, &slice_input_edge));
  const Edge* slice_size_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge));
  const Edge* slice_begin_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge));

  SliceInputs slice_inputs;
  slice_inputs.input =
      Output(slice_input_edge->src(), slice_input_edge->src_output());
  slice_inputs.begin =
      Output(slice_begin_edge->src(), slice_begin_edge->src_output());
  slice_inputs.size =
      Output(slice_size_edge->src(), slice_size_edge->src_output());

  TF_ASSIGN_OR_RETURN(absl::optional<Tensor> tf_slice_size,
                      TryToGetTensorFromConstOp(slice_inputs.size.node()));
  if (!tf_slice_size.has_value()) {
    return {absl::nullopt};
  }

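  // The "size" input of a Slice is expected to be a rank-1 vector with one
  // element per dimension of the input; bail out (without error) on anything
  // else.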
  if (tf_slice_size->dims() != 1) {
    return {absl::nullopt};
  }

  slice_inputs.size_as_vector = IntTensorAsVector(*tf_slice_size);
  return {slice_inputs};
}

// Casts `x` to a DT_INT64 if it isn't one already.
Output MakeInt64(const Scope& host_scope, absl::string_view name,
                 const Output& x) {
  return x.type() == DT_INT64
             ? x
             : ops::Cast(host_scope.WithOpName(name, "_s64"), x, DT_INT64);
}

// Returns `slice_inputs` with the index and size inputs cast to DT_INT64.
SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope,
                                       const SliceInputs& slice_inputs) {
  SliceInputs result;
  result.input = slice_inputs.input;
  result.begin = MakeInt64(host_scope, "begin", slice_inputs.begin);
  result.size = MakeInt64(host_scope, "size", slice_inputs.size);
  result.size_as_vector = slice_inputs.size_as_vector;
  return result;
}

// This class caches emitted constants to avoid creating multiple nodes for the
// same constant value.  This helps make the generated GraphDef more readable.
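//
// For example, if two dimensions of the rewritten slice both need the host
// constant 1, Get1DHostConstant(1) returns the same Const node for both
// callers instead of emitting two identical nodes.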
class ConstantCache {
 public:
  explicit ConstantCache(const Scope& s,
                         const std::vector<const Edge*>& control_deps)
      : scope_(s), control_deps_(control_deps) {}

  Output Get1DHostConstant(int64 constant) {
    auto it = cache_.find(constant);
    if (it == cache_.end()) {
      Output new_const =
          ops::Const(scope_.WithOpName("const_", constant), {constant});
      it = cache_.insert({constant, new_const}).first;
      for (const Edge* e : control_deps_) {
        scope_.graph()->AddControlEdge(e->src(), new_const.node());
      }
    }
    return it->second;
  }

 private:
  Scope scope_;
  std::unordered_map<int64, Output> cache_;
  std::vector<const Edge*> control_deps_;
};

// Computes a node that evaluates to the size of the Slice op with inputs
// `slice_inputs` and stores it in `*size`.
Status ComputeSliceSize(const Scope& host_scope,
                        const SliceInputs& slice_inputs,
                        std::vector<const Edge*> control_deps, Output* size) {
  // If slice_size[i] >= 0 then the size of dimension i in the result is
  // slice_size[i].
  //
  // If slice_size[i] == -1 then the size of dimension i in the result is
  // input_size[i] - begin[i].
  //
  // If slice_size[i] < -1 then executing the slice will throw an error, and we
  // don't do anything here.  We've already filtered these cases out in
  // ShouldRewriteSlice.
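  //
  // For example (illustrative values only): with input_shape = [10, 32],
  // begin = [2, 0] and slice_size = [3, -1], the computed size is
  // [3, 32 - 0] = [3, 32].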

  if (absl::c_all_of(slice_inputs.size_as_vector,
                     [](int64 i) { return i >= 0; })) {
    *size = slice_inputs.size;
    return Status::OK();
  }

  Output input_shape =
      ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input,
                 ops::Shape::OutType(DT_INT64));

  ConstantCache constant_pool(host_scope, control_deps);

  std::vector<Output> slice_size;
  for (int i = 0, end = slice_inputs.size_as_vector.size(); i < end; i++) {
    if (slice_inputs.size_as_vector[i] >= 0) {
      slice_size.push_back(
          constant_pool.Get1DHostConstant(slice_inputs.size_as_vector[i]));
      continue;
    }

    DCHECK_EQ(slice_inputs.size_as_vector[i], -1);

    Output begin_i = ops::Slice(
        host_scope.WithOpName("begin_", i), slice_inputs.begin,
        constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1));

    Output input_shape_i = ops::Slice(
        host_scope.WithOpName("input_shape_", i), input_shape,
        constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1));

    slice_size.push_back(ops::Sub(host_scope.WithOpName("slice_size_", i),
                                  input_shape_i, begin_i));
    DCHECK_EQ(slice_size.back().type(), DT_INT64);
  }

  // Trivial ConcatV2 nodes (with exactly one input) are disallowed.
  if (slice_size.size() == 1) {
    *size = slice_size[0];
  } else {
    auto concat_axis = ops::Const(host_scope.WithOpName("concat_axis"), 0);
    for (const Edge* e : control_deps) {
      host_scope.graph()->AddControlEdge(e->src(), concat_axis.node());
    }
    *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
                        concat_axis);
  }
  return Status::OK();
}

// Terminology: a "static sized" slice is a slice with the
// _XlaCompileTimeConstantInputs attribute set to {2} (i.e. the "size" input).
// The output shape of these slices can be determined solely from their "size"
// input.
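//
// Sketch of the rewrite performed below (node names and values are
// illustrative only):
//
//   Slice(input, begin, size)            // "size" is a Const that contains -1
//
// becomes
//
//   host_size = f(Shape(input), begin)   // computed on the host CPU via
//                                        // ComputeSliceSize above
//   Slice(input, begin, host_size)       // with the "size" input marked as an
//                                        // XLA compile time constant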
Status ConvertTensorFlowSliceToStaticShapedSlice(
    Graph* g, Node* slice, const SliceInputs& slice_inputs,
    absl::string_view cluster_name, Node** result) {
  string host_name;
  TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
      slice->assigned_device_name(), &host_name));

  Status status;
  Scope main_scope =
      NewInternalScope(g, &status, /*refiner=*/nullptr)
          .WithXlaCluster(string(cluster_name))
          .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice"));
  Scope host_scope = main_scope.WithAssignedDevice(host_name);

  // In the future we may want to be clever here and avoid the extra Cast ops.
  SliceInputs slice_inputs_int64 =
      MakeSliceIndexAndSizeInt64(host_scope, slice_inputs);

  // Collect the control dependencies of the old "size" node; they have to be
  // copied onto the nodes that replace it.
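  // Copying these edges onto the newly created constants and concat axis
  // preserves whatever execution-ordering constraints were imposed on the
  // original Const size node.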
  Node* old_size;
  std::vector<const Edge*> old_size_ctrl_deps;
  TF_RETURN_IF_ERROR(slice->input_node(2, &old_size));
  absl::c_copy_if(old_size->in_edges(), std::back_inserter(old_size_ctrl_deps),
                  [](const Edge* e) { return e->IsControlEdge(); });

  Output slice_size;
  TF_RETURN_IF_ERROR(ComputeSliceSize(host_scope, slice_inputs_int64,
                                      old_size_ctrl_deps, &slice_size));

  *result =
      ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name())
                     .WithOpName("static_shaped_slice"),
                 slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
          .node();

  TF_RETURN_IF_ERROR(main_scope.status());

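  // Mark the "size" input of the new Slice as an XLA compile time constant so
  // that the cluster containing it treats the rewritten slice as having a
  // statically known output shape.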
  std::vector<string> compile_time_const_inputs;
  compile_time_const_inputs.push_back("size");
  (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
                     compile_time_const_inputs);
  return status;
}

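// Re-routes all users of `slice` to `static_shaped_slice`: moves the outgoing
// data and control edges of `slice` onto `static_shaped_slice`, copies the
// incoming control edges, and finally removes `slice` from the graph.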
void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice,
                                                 Node* static_shaped_slice) {
  std::vector<const Edge*> slice_out_edges;
  absl::c_copy(slice->out_edges(), std::back_inserter(slice_out_edges));
  for (const Edge* e : slice_out_edges) {
    DCHECK(e->src_output() == 0 || e->src_output() == Graph::kControlSlot);

    int src_output = e->src_output();
    int dst_input = e->dst_input();
    Node* dst = e->dst();
    g->RemoveEdge(e);
    g->AddEdge(static_shaped_slice, src_output, dst, dst_input);
  }

  for (const Edge* e : slice->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), static_shaped_slice);
    }
  }

  g->RemoveNode(slice);
}

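// Rewrites `slice` into an equivalent "static shaped" slice (see the
// terminology comment above) and replaces all of its uses in `g`.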
Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
                    absl::string_view cluster_name) {
  VLOG(3) << "Rewriting slice " << slice->name()
          << " to a \"static shaped\" Slice";
  Node* static_shaped_slice;
  TF_RETURN_IF_ERROR(ConvertTensorFlowSliceToStaticShapedSlice(
      g, slice, slice_inputs, cluster_name, &static_shaped_slice));
  ReplaceTensorFlowSliceWithStaticShapedSlice(g, slice, static_shaped_slice);
  return Status::OK();
}

// Returns true if `n` is a slice we should rewrite to have a static shape
// (i.e. have the output shape only depend on the "size" input).
xla::StatusOr<bool> ShouldRewriteSlice(Node* n) {
  if (n->type_string() != "Slice") {
    return false;
  }

  if (!GetXlaClusterForNode(*n).has_value()) {
    // There is no need to change slice ops outside XLA clusters.
    return false;
  }

  TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
                      GetSliceInputs(n));
  if (!slice_inputs.has_value()) {
    return false;
  }

  // If slice_size[i] < -1 for any i then executing the slice will throw an
  // error, and we don't do anything here.
  bool slice_size_ok = absl::c_all_of(
      slice_inputs->size_as_vector, [](int64 size_i) { return size_i >= -1; });
  if (!slice_size_ok) {
    return false;
  }

  // No point in rewriting slices that have both size and begin as constants.
  return !slice_inputs->begin.node()->IsConstant();
}

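// Finds every Slice node that ShouldRewriteSlice approves of and rewrites it.
// Sets `*changed` to true if at least one slice was rewritten.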
Status FindAndRewriteSlices(Graph* g, bool* changed) {
  std::vector<Node*> slices_to_rewrite;
  for (Node* n : g->nodes()) {
    TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n));
    if (is_rewritable) {
      slices_to_rewrite.push_back(n);
    }
  }

  for (Node* n : slices_to_rewrite) {
    TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
                        GetSliceInputs(n));
    TF_RET_CHECK(slice_inputs.has_value());
    TF_RETURN_IF_ERROR(
        RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n)));
  }

  if (!slices_to_rewrite.empty()) {
    // We've added constants to the graph; hook them up to _SOURCE.
    FixupSourceAndSinkEdges(g);
  }

  *changed = !slices_to_rewrite.empty();

  return Status::OK();
}
}  // namespace

Status IncreaseDynamismForAutoJitPass::Run(
    const GraphOptimizationPassOptions& options) {
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  if (flags->tf_xla_clustering_debug) {
    DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
                    **options.graph, options.flib_def);
  }

  bool changed;
  TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed));
  if (changed && flags->tf_xla_clustering_debug) {
    DumpGraphToFile("increase_dynamism_for_auto_jit_pass", **options.graph,
                    options.flib_def);
  }

  return Status::OK();
}

}  // namespace tensorflow