1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
17 
18 #include <numeric>
19 
20 #include "tensorflow/compiler/tf2xla/literal_util.h"
21 #include "tensorflow/compiler/tf2xla/shape_util.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
24 #include "tensorflow/compiler/tf2xla/xla_context.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/core/common_runtime/dma_helper.h"
29 
30 namespace tensorflow {
31 
XlaOpKernelContext(OpKernelContext * context)32 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
33     : context_(context) {}
34 
ValidateInputsAreSameShape(OpKernel * op)35 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
36   return context_->ValidateInputsAreSameShape(op);
37 }
38 
xla_context() const39 XlaContext* XlaOpKernelContext::xla_context() const {
40   return &XlaContext::Get(context_);
41 }
42 
builder() const43 xla::XlaBuilder* XlaOpKernelContext::builder() const {
44   return xla_context()->builder();
45 }
46 
compiler() const47 XlaCompiler* XlaOpKernelContext::compiler() const {
48   return xla_context()->compiler();
49 }
50 
51 // Retrieves an XlaExpression that was allocated by a previous Op.
CastExpressionFromTensor(const Tensor & tensor)52 static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
53   const XlaExpression* expression =
54       reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
55   CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
56       << expression->HumanString();
57   return expression;
58 }
59 
60 // Assigns an XlaExpression to a tensor on an XLA compilation device.
AssignExpressionToTensor(Tensor * tensor,const XlaExpression & value)61 static void AssignExpressionToTensor(Tensor* tensor,
62                                      const XlaExpression& value) {
63   const XlaExpression* expression =
64       reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
65   CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
66       << expression->HumanString();
67   *const_cast<XlaExpression*>(expression) = value;
68 }
69 
InputExpression(int index)70 const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
71   return *CastExpressionFromTensor(context_->input(index));
72 }
73 
InputExpression(absl::string_view name)74 const XlaExpression& XlaOpKernelContext::InputExpression(
75     absl::string_view name) {
76   return *CastExpressionFromTensor(GetInputTensorByName(name));
77 }
78 
Input(int index)79 xla::XlaOp XlaOpKernelContext::Input(int index) {
80   return InputExpression(index).AsXlaOp(builder());
81 }
82 
Input(absl::string_view name)83 xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
84   return InputExpression(name).AsXlaOp(builder());
85 }
86 
InputShape(int index)87 TensorShape XlaOpKernelContext::InputShape(int index) {
88   return context_->input(index).shape();
89 }
90 
InputShape(absl::string_view name)91 TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
92   return GetInputTensorByName(name).shape();
93 }
94 
input_type(int index) const95 DataType XlaOpKernelContext::input_type(int index) const {
96   return context_->input_dtype(index);
97 }
98 
InputType(absl::string_view name)99 DataType XlaOpKernelContext::InputType(absl::string_view name) {
100   return GetInputTensorByName(name).dtype();
101 }
102 
input_xla_type(int index)103 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
104   xla::PrimitiveType type;
105   Status status = DataTypeToPrimitiveType(input_type(index), &type);
106   if (!status.ok()) {
107     SetStatus(status);
108     return xla::PRIMITIVE_TYPE_INVALID;
109   }
110   return type;
111 }
112 
ConstantInput(int index,xla::Literal * constant_literal)113 Status XlaOpKernelContext::ConstantInput(int index,
114                                          xla::Literal* constant_literal) {
115   return ConstantInputReshaped(
116       index, context_->input(index).shape().dim_sizes(), constant_literal);
117 }
118 
InputIndex(XlaOpKernelContext * context,absl::string_view name)119 static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
120                                      absl::string_view name) {
121   int start, stop;
122   TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
123   if (stop != start + 1) {
124     return errors::InvalidArgument("OpKernel used list-valued input name '",
125                                    name,
126                                    "' when single-valued input was "
127                                    "expected");
128   }
129   return start;
130 }
131 
ConstantInput(absl::string_view name,xla::Literal * constant_literal)132 Status XlaOpKernelContext::ConstantInput(absl::string_view name,
133                                          xla::Literal* constant_literal) {
134   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
135   return ConstantInput(index, constant_literal);
136 }
137 
ConstantInputReshaped(int index,absl::Span<const int64> new_dims,xla::Literal * constant_literal)138 Status XlaOpKernelContext::ConstantInputReshaped(
139     int index, absl::Span<const int64> new_dims,
140     xla::Literal* constant_literal) {
141   XlaExpression e = InputExpression(index);
142   xla::StatusOr<absl::optional<Tensor>> constant_or_status =
143       e.ResolveConstant(compiler()->client());
144   if (!constant_or_status.ok()) {
145     Status status = constant_or_status.status();
146     errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
147                             context_->op_kernel().type_string(),
148                             " operator as a compile-time constant.");
149     return status;
150   }
151   absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
152   if (!constant.has_value()) {
153     return errors::InvalidArgument(
154         "Input ", index, " to ", context_->op_kernel().type_string(),
155         " operator must be a compile-time constant.\n"
156         "\n"
157         "XLA compilation requires that operator arguments that represent "
158         "shapes or dimensions be evaluated to concrete values at compile time. "
159         "This error means that a shape or dimension argument could not be "
160         "evaluated at compile time, usually because the value of the argument "
161         "depends on a parameter to the computation, on a variable, or on a "
162         "stateful operation such as a random number generator.");
163   }
164 
165   Tensor temp(constant->dtype());
166   if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
167     return errors::InvalidArgument(
168         context_->op_kernel().name(), " input ", index, " has shape ",
169         constant->shape().DebugString(),
170         " but was asked to be reshaped to incompatible shape ",
171         TensorShape(new_dims).DebugString());
172   }
173 
174   TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
175   return Status::OK();
176 }
177 
178 // Converts an int32 or int64 scalar literal to an int64.
LiteralToInt64Scalar(const xla::LiteralSlice & literal,int64 * out)179 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
180                                    int64* out) {
181   if (literal.shape().rank() != 0) {
182     return errors::InvalidArgument("value is not a scalar");
183   }
184   if (literal.shape().element_type() == xla::S32) {
185     *out = literal.Get<int32>({});
186   } else if (literal.shape().element_type() == xla::S64) {
187     *out = literal.Get<int64>({});
188   } else {
189     return errors::InvalidArgument("value must be either int32 or int64");
190   }
191   return Status::OK();
192 }
193 
194 // Converts an float32 or float64 scalar literal to a float64.
LiteralToFloat64Scalar(const xla::LiteralSlice & literal,double * out)195 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
196                                      double* out) {
197   if (literal.shape().rank() != 0) {
198     return errors::InvalidArgument("value is not a scalar");
199   }
200   if (literal.shape().element_type() == xla::F32) {
201     *out = literal.Get<float>({});
202   } else if (literal.shape().element_type() == xla::F64) {
203     *out = literal.Get<double>({});
204   } else {
205     return errors::InvalidArgument("value must be either float32 or float64");
206   }
207   return Status::OK();
208 }
209 
ConstantInputAsIntScalar(int index,int64 * out)210 Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
211   xla::Literal literal;
212   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
213   return LiteralToInt64Scalar(literal, out);
214 }
215 
ConstantInputAsIntScalar(absl::string_view name,int64 * out)216 Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
217                                                     int64* out) {
218   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
219   return ConstantInputAsIntScalar(index, out);
220 }
221 
ConstantInputAsFloatScalar(int index,double * out)222 Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
223   xla::Literal literal;
224   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
225   return LiteralToFloat64Scalar(literal, out);
226 }
227 
228 // Converts an int32 or int64 1D literal to an int64 vector.
LiteralToInt64Vector(const xla::LiteralSlice & literal,std::vector<int64> * out)229 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
230                                    std::vector<int64>* out) {
231   if (literal.shape().rank() != 1) {
232     return errors::InvalidArgument("value is not 1D, rank: ",
233                                    literal.shape().rank());
234   }
235   int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
236   if (literal.shape().element_type() == xla::S32) {
237     for (int64 i = 0; i < size; ++i) {
238       out->push_back(literal.Get<int32>({i}));
239     }
240   } else if (literal.shape().element_type() == xla::S64) {
241     for (int64 i = 0; i < size; ++i) {
242       out->push_back(literal.Get<int64>({i}));
243     }
244   } else {
245     return errors::InvalidArgument("value must be either int32 or int64");
246   }
247   return Status::OK();
248 }
249 
ConstantInputAsIntVector(int index,std::vector<int64> * out)250 Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
251                                                     std::vector<int64>* out) {
252   xla::Literal literal;
253   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
254   return LiteralToInt64Vector(literal, out);
255 }
256 
ConstantInputAsIntVector(absl::string_view name,std::vector<int64> * out)257 Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
258                                                     std::vector<int64>* out) {
259   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
260   return ConstantInputAsIntVector(index, out);
261 }
262 
ConstantInputReshapedToIntVector(int index,std::vector<int64> * out)263 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
264     int index, std::vector<int64>* out) {
265   xla::Literal literal;
266   TF_RETURN_IF_ERROR(ConstantInputReshaped(
267       index, {InputShape(index).num_elements()}, &literal));
268   return LiteralToInt64Vector(literal, out);
269 }
270 
ConstantInputReshapedToIntVector(absl::string_view name,std::vector<int64> * out)271 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
272     absl::string_view name, std::vector<int64>* out) {
273   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
274   xla::Literal literal;
275   TF_RETURN_IF_ERROR(ConstantInputReshaped(
276       index, {InputShape(index).num_elements()}, &literal));
277   return LiteralToInt64Vector(literal, out);
278 }
279 
ConstantInputAsInt64Literal(int index,xla::Literal * out)280 Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
281                                                        xla::Literal* out) {
282   xla::Literal literal;
283   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
284   switch (literal.shape().element_type()) {
285     case xla::S32: {
286       *out = xla::Literal(
287           xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
288       auto src_data = literal.data<int32>();
289       for (int64 i = 0; i < src_data.size(); ++i) {
290         out->data<int64>()[i] = src_data[i];
291       }
292       return Status::OK();
293     }
294     case xla::S64:
295       *out = std::move(literal);
296       return Status::OK();
297 
298     default:
299       return errors::InvalidArgument(
300           "Invalid argument to ConstantInputAsInt64Literal: ",
301           xla::ShapeUtil::HumanString(literal.shape()));
302   }
303 }
304 
ConstantInputAsInt64Literal(absl::string_view name,xla::Literal * out)305 Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
306                                                        xla::Literal* out) {
307   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
308   return ConstantInputAsInt64Literal(index, out);
309 }
310 
311 // TODO(phawkins): validate that the dimensions form a valid shape, fail
312 // gracefully if they do not.
ConstantInputAsShape(int index,TensorShape * shape)313 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
314   xla::Literal literal;
315   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
316   std::vector<int64> dims;
317   TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
318   *shape = TensorShape(dims);
319   return Status::OK();
320 }
321 
ConstantInputAsPartialShape(int index,PartialTensorShape * shape)322 Status XlaOpKernelContext::ConstantInputAsPartialShape(
323     int index, PartialTensorShape* shape) {
324   xla::Literal literal;
325   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
326   // If `literal` is a scalar it's value must be -1.
327   if (literal.shape().rank() == 0) {
328     int64 shape_val;
329     TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
330     if (shape_val != -1) {
331       return errors::InvalidArgument(
332           "Cannot convert value to PartialTensorShape: ", shape_val);
333     }
334     *shape = PartialTensorShape();  // Shape with unknown rank.
335     return Status::OK();
336   }
337   std::vector<int64> dims;
338   TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
339   *shape = PartialTensorShape(dims);
340   return Status::OK();
341 }
342 
InputList(absl::string_view name,std::vector<xla::XlaOp> * handles,std::vector<TensorShape> * shapes)343 Status XlaOpKernelContext::InputList(absl::string_view name,
344                                      std::vector<xla::XlaOp>* handles,
345                                      std::vector<TensorShape>* shapes) {
346   OpInputList inputs;
347   TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
348   handles->clear();
349   shapes->clear();
350   for (const Tensor& input : inputs) {
351     handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
352     shapes->push_back(input.shape());
353   }
354   return Status::OK();
355 }
356 
ConstantInputList(absl::string_view name,std::vector<xla::Literal> * outputs)357 Status XlaOpKernelContext::ConstantInputList(
358     absl::string_view name, std::vector<xla::Literal>* outputs) {
359   int start, stop;
360   TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
361   outputs->resize(stop - start);
362   for (int i = start; i < stop; ++i) {
363     TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
364   }
365   return Status::OK();
366 }
367 
368 namespace {
369 
ReadVariableInputTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,TensorShape * shape,xla::XlaOp * value)370 Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
371                                const XlaOpKernelContext* ctx,
372                                TensorShape* shape, xla::XlaOp* value) {
373   const XlaExpression* expression = CastExpressionFromTensor(tensor);
374   XlaResource* variable = expression->resource();
375   TF_RET_CHECK(variable != nullptr);
376   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
377   if (!variable->initialized()) {
378     return errors::FailedPrecondition("Read of uninitialized variable ",
379                                       variable->name());
380   }
381   if (variable->type() != type) {
382     return errors::InvalidArgument(
383         "Type mismatch for read of variable ", variable->name(), ". Expected ",
384         DataTypeString(type), "; got ", DataTypeString(variable->type()));
385   }
386   if (shape) {
387     *shape = variable->shape();
388   }
389 
390   TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
391                       ctx->compiler()->options().shape_representation_fn(
392                           variable->shape(), variable->type()));
393   xla::Shape xla_shape;
394   TF_RETURN_IF_ERROR(
395       TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
396   if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
397     *value = variable->value();
398   } else {
399     *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
400   }
401   return Status::OK();
402 }
403 
404 }  // namespace
405 
ReadVariableInput(int index,DataType type,TensorShape * shape,xla::XlaOp * value)406 Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
407                                              TensorShape* shape,
408                                              xla::XlaOp* value) {
409   return ReadVariableInputTensor(context_->input(index), type, this, shape,
410                                  value);
411 }
412 
ReadVariableInput(absl::string_view name,DataType type,TensorShape * shape,xla::XlaOp * value)413 Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
414                                              DataType type, TensorShape* shape,
415                                              xla::XlaOp* value) {
416   return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
417                                  value);
418 }
419 
GetVariableTypeAndShape(int index,DataType * type,TensorShape * shape) const420 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
421                                                    TensorShape* shape) const {
422   const Tensor& tensor = context_->input(index);
423   const XlaExpression* expression = CastExpressionFromTensor(tensor);
424   XlaResource* variable = expression->resource();
425   TF_RET_CHECK(variable != nullptr);
426   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
427   if (!variable->initialized()) {
428     return errors::InvalidArgument("Read of uninitialized variable ",
429                                    variable->name());
430   }
431   *type = variable->type();
432   *shape = variable->shape();
433   return Status::OK();
434 }
435 
SetOutputExpression(int index,const XlaExpression & expression)436 void XlaOpKernelContext::SetOutputExpression(int index,
437                                              const XlaExpression& expression) {
438   Status status = [&] {
439     // The step's default allocator is the dummy XlaCompilationAllocator which
440     // simply allocates a metadata buffer to hold the expression to which it
441     // corresponds.
442     Tensor* output = nullptr;
443     // Provides a special behavior for DT_VARIANT: a variant is treated as
444     // DT_UINT8 scalar as the type to allow mapping for variant to more generic
445     // types.
446     if (expression.dtype() == DT_VARIANT) {
447       // tensor_data() is not supported for variant Tensor (i.e.,
448       // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
449       // XlaExpression inside the Tensor's tensor_data() does not work for
450       // variant. Instead construct a uint8 tensor and store the expression in
451       // its value.
452       // TODO(jpienaar): This should be refactored to stop masquerading
453       // XlaExpressions as Tensors.
454       output = new Tensor();
455       TensorShape tensor_shape;
456       TF_RETURN_IF_ERROR(
457           context_->allocate_temp(DT_UINT8, tensor_shape, output));
458       context_->set_output(index, *output);
459     } else {
460       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
461       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
462     }
463     AssignExpressionToTensor(output, expression);
464     return Status::OK();
465   }();
466   if (!status.ok()) {
467     SetStatus(status);
468   }
469 }
470 
output_xla_type(int index)471 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
472   xla::PrimitiveType type;
473   Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
474   if (!status.ok()) {
475     SetStatus(status);
476     return xla::PRIMITIVE_TYPE_INVALID;
477   }
478   return type;
479 }
480 
SetOutput(int index,const xla::XlaOp & handle)481 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
482   SetOutputExpression(
483       index,
484       XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
485 }
486 
SetConstantOutput(int index,const Tensor & constant)487 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
488   SetOutputExpression(index, XlaExpression::Constant(constant));
489 }
490 
SetTensorListOutput(int index,const xla::XlaOp & handle)491 void XlaOpKernelContext::SetTensorListOutput(int index,
492                                              const xla::XlaOp& handle) {
493   SetOutputExpression(index, XlaExpression::TensorList(handle));
494 }
495 
SetResourceOutput(int index,XlaResource * resource)496 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
497   SetOutputExpression(index, XlaExpression::Resource(resource));
498 }
499 
GetResourceInput(int index,XlaResource ** resource)500 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
501   const XlaExpression* expression =
502       CastExpressionFromTensor(context_->input(index));
503   TF_RET_CHECK(expression->resource() != nullptr);
504   *resource = expression->resource();
505   return Status::OK();
506 }
507 
508 namespace {
509 
AssignVariableTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,xla::XlaOp handle,xla::XlaBuilder * builder)510 Status AssignVariableTensor(const Tensor& tensor, DataType type,
511                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
512                             xla::XlaBuilder* builder) {
513   const XlaExpression* expression = CastExpressionFromTensor(tensor);
514   XlaResource* variable = expression->resource();
515   TF_RET_CHECK(variable != nullptr);
516   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
517 
518   auto shape_or_status = builder->GetShape(handle);
519   if (!shape_or_status.ok()) {
520     return shape_or_status.status();
521   }
522   TensorShape shape;
523   TF_RETURN_IF_ERROR(
524       XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
525 
526   TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
527 
528   TF_ASSIGN_OR_RETURN(
529       xla::Shape representation_shape,
530       ctx->compiler()->options().shape_representation_fn(shape, type));
531   xla::Shape xla_shape;
532   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
533   if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
534     handle = xla::Reshape(handle,
535                           xla::AsInt64Slice(representation_shape.dimensions()));
536   }
537   variable->SetRepresentationShape(representation_shape);
538   return variable->SetValue(handle);
539 }
540 
541 }  // namespace
542 
AssignVariable(int input_index,DataType type,xla::XlaOp handle)543 Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
544                                           xla::XlaOp handle) {
545   TF_RET_CHECK(handle.valid());
546   return AssignVariableTensor(context_->input(input_index), type, this, handle,
547                               builder());
548 }
549 
AssignVariable(absl::string_view name,DataType type,xla::XlaOp handle)550 Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
551                                           xla::XlaOp handle) {
552   TF_RET_CHECK(handle.valid());
553   return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
554                               builder());
555 }
556 
CtxFailure(const Status & s)557 void XlaOpKernelContext::CtxFailure(const Status& s) {
558   context_->CtxFailure(s);
559 }
CtxFailureWithWarning(const Status & s)560 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
561   context_->CtxFailureWithWarning(s);
562 }
CtxFailure(const char * file,int line,const Status & s)563 void XlaOpKernelContext::CtxFailure(const char* file, int line,
564                                     const Status& s) {
565   context_->CtxFailure(file, line, s);
566 }
CtxFailureWithWarning(const char * file,int line,const Status & s)567 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
568                                                const Status& s) {
569   context_->CtxFailureWithWarning(file, line, s);
570 }
571 
GetOrCreateMax(const DataType type)572 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
573     const DataType type) {
574   return xla_context()->GetOrCreateMax(type);
575 }
576 
GetOrCreateMin(const DataType type)577 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
578     const DataType type) {
579   return xla_context()->GetOrCreateMin(type);
580 }
581 
GetOrCreateAdd(const DataType type)582 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
583     const DataType type) {
584   return xla_context()->GetOrCreateAdd(type);
585 }
586 
GetOrCreateMul(const DataType type)587 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
588     const DataType type) {
589   return xla_context()->GetOrCreateMul(type);
590 }
591 
GetInputTensorByName(absl::string_view name)592 const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
593   const Tensor* tensor;
594   CHECK(context_->input(name, &tensor).ok());
595   return *tensor;
596 }
597 
XlaOpKernel(OpKernelConstruction * context)598 XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
599 
Compute(OpKernelContext * context)600 void XlaOpKernel::Compute(OpKernelContext* context) {
601   XlaOpKernelContext xla_context(context);
602   Compile(&xla_context);
603 }
604 
605 }  // namespace tensorflow
606