/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for split.

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class SplitOp : public XlaOpKernel {
 public:
  explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const int32 num_split = num_outputs();
    const TensorShape split_dim_shape = ctx->InputShape("split_dim");
    const TensorShape input_shape = ctx->InputShape(1);

    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(split_dim_shape),
        errors::InvalidArgument("split_dim must be a scalar but has rank ",
                                split_dim_shape.dims()));
    int64 split_dim_orig;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &split_dim_orig));

    int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
                                         : split_dim_orig;
    OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
                errors::InvalidArgument("-input rank (-", input_shape.dims(),
                                        ") <= split_dim < input rank (",
                                        input_shape.dims(), "), but got ",
                                        split_dim_orig));

    OP_REQUIRES(
        ctx, num_split > 0,
        errors::InvalidArgument(
            "Number of ways to split should be > 0, but got ", num_split));

    OP_REQUIRES(
        ctx, input_shape.dim_size(split_dim) % num_split == 0,
        errors::InvalidArgument(
            "Number of ways to split should evenly divide the split "
            "dimension, but got split_dim ",
            split_dim_orig, " (size = ", input_shape.dim_size(split_dim), ") ",
            "and num_split ", num_split));

    // All the slices are the same size: this is the size along the
    // split dimension.
    const int32 slice_size = input_shape.dim_size(split_dim) / num_split;

    // The vectors we will use to define the slice. The entry for the
    // split dimension varies for each output.
    std::vector<int64> begin(input_shape.dims(), 0);
    std::vector<int64> limits(input_shape.dims());
    std::vector<int64> strides(input_shape.dims(), 1);
    for (int i = 0; i < input_shape.dims(); ++i) {
      // Initially set up the limits to be the full size of the input:
      // the split dimension is filled in below.
      int64 dim = input_shape.dim_size(i);
      limits[i] = dim;
    }

    auto input = ctx->Input(1);

    // Create each of the outputs.
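    // For example (shapes here are illustrative, not from the source): an
    // input of shape [4, 6] with split_dim = 1 and num_split = 3 gives
    // slice_size = 2, and the loop below emits three [4, 2] outputs
    // covering columns [0, 2), [2, 4), and [4, 6).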
    for (int i = 0; i < num_split; ++i) {
      // Slice out the ith split from the split dimension.
      begin[split_dim] = i * slice_size;
      limits[split_dim] = (i + 1) * slice_size;
      ctx->SetOutput(i, xla::Slice(input, begin, limits, strides));
    }
  }
};

REGISTER_XLA_OP(Name("Split").CompileTimeConstantInput("split_dim"), SplitOp);

class SplitVOp : public XlaOpKernel {
 public:
  explicit SplitVOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const int32 num_split = num_outputs();
    const TensorShape input_shape = ctx->InputShape(0);
    const TensorShape index_shape = ctx->InputShape(2);

    int64 split_dim_orig;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig));
    int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
                                         : split_dim_orig;
    OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
                errors::InvalidArgument("-input rank (-", input_shape.dims(),
                                        ") <= split_dim < input rank (",
                                        input_shape.dims(), "), but got ",
                                        split_dim_orig));

    xla::XlaOp input = ctx->Input(0);

    OP_REQUIRES(ctx, input_shape.dims() > 0,
                errors::InvalidArgument("Can't split a 0-dimensional input"));

    OP_REQUIRES(
        ctx, num_split > 0,
        errors::InvalidArgument(
            "Number of ways to split should be > 0, but got ", num_split));

    // Check that sizes are correct.
    int total_split_size = 0;
    int neg_one_dim = -1;
    const TensorShape split_size_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx,
                split_size_shape.dims() == 1 &&
                    split_size_shape.num_elements() == num_split,
                errors::InvalidArgument(
                    "shape of tensor describing the output must have "
                    "dimension 1 and the same number of elements as the "
                    "output. Got ",
                    split_size_shape.dims(), "-D and ",
                    split_size_shape.num_elements(), " elements"));
    // Get the dimension of this split.
    std::vector<int64> split_sizes;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &split_sizes));

    for (int i = 0; i < num_split; ++i) {
      int64 slice_size = split_sizes[i];
      if (slice_size == -1) {
        OP_REQUIRES(
            ctx, neg_one_dim == -1,
            errors::InvalidArgument("Only one dimension can have a value "
                                    "of -1. Second one found at dimension ",
                                    i));
        neg_one_dim = i;
      } else {
        total_split_size += slice_size;
      }
    }

    OP_REQUIRES(
        ctx,
        (neg_one_dim == -1 &&
         total_split_size == input_shape.dim_size(split_dim)) ||
            (neg_one_dim >= 0 &&
             total_split_size <= input_shape.dim_size(split_dim)),
        errors::InvalidArgument("Determined shape must either match "
                                "input shape along split_dim exactly if "
                                "fully specified, or be less than the size of "
                                "the input along split_dim if not fully "
                                "specified. Got: ",
                                total_split_size));

    if (neg_one_dim >= 0) {
      split_sizes[neg_one_dim] =
          input_shape.dim_size(split_dim) - total_split_size;
    }
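    // For example (sizes here are illustrative, not from the source): if
    // the input has size 10 along split_dim and size_splits is {2, -1, 3},
    // total_split_size is 5, so the -1 entry is replaced by 10 - 5 = 5 and
    // the outputs have sizes 2, 5, and 3 along split_dim.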
    // The vectors we will use to define the slice. The entry for the
    // split dimension varies for each output.
    std::vector<int64> begin(input_shape.dims(), 0);
    auto dim_sizes = input_shape.dim_sizes();
    std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
    std::vector<int64> strides(input_shape.dims(), 1);
    for (int i = 0; i < num_split; ++i) {
      TensorShape output_shape(input_shape);
      int slice_size = split_sizes[i];
      output_shape.set_dim(split_dim, slice_size);

      // Slice out the ith split from the split dimension.
      limits[split_dim] = begin[split_dim] + slice_size;
      ctx->SetOutput(i, xla::Slice(input, begin, limits, strides));
      begin[split_dim] = limits[split_dim];
    }
  }
};

REGISTER_XLA_OP(Name("SplitV")
                    .CompileTimeConstantInput("split_dim")
                    .CompileTimeConstantInput("size_splits"),
                SplitVOp);

}  // namespace
}  // namespace tensorflow