1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/slicing.h"
17 #include "tensorflow/compiler/xla/client/xla_builder.h"
18 
19 namespace xla {
20 
SliceInMinorDims(XlaOp x,absl::Span<const int64> start,absl::Span<const int64> end)21 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
22                        absl::Span<const int64> end) {
23   XlaBuilder* builder = x.builder();
24   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
25     TF_RET_CHECK(start.size() == end.size());
26     int64 n_minor_dims = start.size();
27 
28     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
29 
30     const int64 n_dims = shape.rank();
31     TF_RET_CHECK(n_minor_dims <= n_dims);
32     auto major_dims = AsInt64Slice(shape.dimensions())
33                           .subspan(
34                               /*pos=*/0,
35                               /*len=*/n_dims - n_minor_dims);
36 
37     // Prepends 0s in the major dim
38     std::vector<int64> padded_start(n_dims, 0);
39     std::copy(start.begin(), start.end(),
40               padded_start.begin() + major_dims.size());
41 
42     // Prepends the shape of the major dims.
43     std::vector<int64> padded_end(n_dims);
44     std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
45     std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
46 
47     std::vector<int64> strides(n_dims, 1);
48     return Slice(x, padded_start, padded_end, strides);
49   });
50 }
51 
UpdateSlice(XlaOp x,XlaOp update,absl::Span<const int64> start)52 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
53   XlaBuilder* builder = x.builder();
54   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
55     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
56     const int64 n_dims = shape.rank();
57     TF_RET_CHECK(start.size() == n_dims);
58 
59     // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
60     std::vector<int32> start_as_int32(start.begin(), start.end());
61     std::vector<XlaOp> start_ops(start.size());
62     for (int i = 0; i < start.size(); ++i) {
63       start_ops[i] = ConstantR0(builder, start_as_int32[i]);
64     }
65     return DynamicUpdateSlice(x, update, start_ops);
66   });
67 }
68 
UpdateSliceInMinorDims(XlaOp x,XlaOp update,absl::Span<const int64> start)69 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
70                              absl::Span<const int64> start) {
71   XlaBuilder* builder = x.builder();
72   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
73     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
74     const int64 n_dims = shape.rank();
75     const int64 n_minor_dims = start.size();
76     TF_RET_CHECK(n_minor_dims <= n_dims);
77     std::vector<int64> padded_start(n_dims, 0);
78     std::copy(start.begin(), start.end(),
79               padded_start.begin() + (n_dims - n_minor_dims));
80     return UpdateSlice(x, update, padded_start);
81   });
82 }
83 
84 namespace {
85 
ConcatVectors(absl::Span<const int64> xs,absl::Span<const int64> ys)86 std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
87                                  absl::Span<const int64> ys) {
88   std::vector<int64> output(xs.size() + ys.size());
89   std::copy(xs.begin(), xs.end(), output.begin());
90   std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
91   return output;
92 }
93 
PrependZerosInMajorDims(XlaOp x,absl::Span<const XlaOp> starts)94 StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims(
95     XlaOp x, absl::Span<const XlaOp> starts) {
96   XlaBuilder* builder = x.builder();
97   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
98   const int64 n_dims = shape.rank();
99   auto zero = ConstantR0<int32>(builder, 0);
100   std::vector<XlaOp> padded_starts(n_dims, zero);
101   for (int i = 0; i < starts.size(); ++i) {
102     padded_starts[n_dims - starts.size() + i] = starts[i];
103   }
104   return padded_starts;
105 }
106 
107 }  // namespace
108 
DynamicSliceInMinorDims(XlaOp x,absl::Span<const XlaOp> starts,absl::Span<const int64> sizes)109 XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
110                               absl::Span<const int64> sizes) {
111   XlaBuilder* builder = x.builder();
112   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
113     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
114     const int64 n_dims = shape.rank();
115     int64 n_minor_dims = starts.size();
116     TF_RET_CHECK(n_minor_dims == sizes.size());
117     TF_RET_CHECK(n_minor_dims <= n_dims);
118     auto major_dims = AsInt64Slice(shape.dimensions())
119                           .subspan(
120                               /*pos=*/0,
121                               /*len=*/n_dims - sizes.size());
122     TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
123     auto padded_sizes = ConcatVectors(major_dims, sizes);
124     return DynamicSlice(x, padded_starts, padded_sizes);
125   });
126 }
127 
DynamicUpdateSliceInMinorDims(XlaOp x,XlaOp update,absl::Span<const XlaOp> starts)128 XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
129                                     absl::Span<const XlaOp> starts) {
130   XlaBuilder* builder = x.builder();
131   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
132     TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
133     return DynamicUpdateSlice(x, update, padded_starts);
134   });
135 }
136 
TorchGather(XlaOp input,XlaOp index,int64 dim)137 XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) {
138   XlaBuilder* builder = input.builder();
139   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
140     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
141     ShapeUtil::AppendMajorDimension(1, &index_shape);
142     std::vector<XlaOp> to_concat;
143     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
144     to_concat.reserve(input_shape.rank());
145     for (int64 i = 0; i < input_shape.rank(); ++i) {
146       if (i == dim) {
147         to_concat.push_back(Reshape(index, index_shape.dimensions()));
148       } else {
149         to_concat.push_back(Iota(builder, index_shape, i));
150       }
151     }
152     XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank());
153     std::vector<int64> slice_sizes(input_shape.rank(), 1);
154     GatherDimensionNumbers gather_dnums;
155     gather_dnums.set_index_vector_dim(input_shape.rank());
156     for (int64 i = 0; i < input_shape.rank(); ++i) {
157       gather_dnums.add_collapsed_slice_dims(i);
158       gather_dnums.add_start_index_map(i);
159     }
160     return Gather(input, gather_indices, gather_dnums, slice_sizes);
161   });
162 }
163 
TorchIndexSelect(XlaOp input,XlaOp index,int64 dim)164 XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim) {
165   XlaBuilder* builder = input.builder();
166   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
167     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
168     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
169     std::vector<int64> slice_sizes = input_shape.dimensions();
170     slice_sizes[dim] = 1;
171     GatherDimensionNumbers gather_dnums;
172     for (int64 i = 0; i < input_shape.rank(); ++i) {
173       if (i != dim) {
174         gather_dnums.add_offset_dims(i);
175       }
176     }
177     gather_dnums.set_index_vector_dim(index_shape.rank());
178     gather_dnums.add_collapsed_slice_dims(dim);
179     gather_dnums.add_start_index_map(dim);
180     return Gather(input, index, gather_dnums, slice_sizes);
181   });
182 }
183 
184 }  // namespace xla
185