1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
17 
18 #include <numeric>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/tf2xla/literal_util.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/type_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
25 #include "tensorflow/compiler/tf2xla/xla_context.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/core/common_runtime/dma_helper.h"
30 
31 namespace tensorflow {
32 
XlaOpKernelContext(OpKernelContext * context)33 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
34     : context_(context), dynamic_dimension_is_minus_one_(false) {}
35 
ValidateInputsAreSameShape(OpKernel * op)36 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
37   return context_->ValidateInputsAreSameShape(op);
38 }
39 
xla_context() const40 XlaContext* XlaOpKernelContext::xla_context() const {
41   return &XlaContext::Get(context_);
42 }
43 
builder() const44 xla::XlaBuilder* XlaOpKernelContext::builder() const {
45   return xla_context()->builder();
46 }
47 
compiler() const48 XlaCompiler* XlaOpKernelContext::compiler() const {
49   return xla_context()->compiler();
50 }
51 
InputExpression(int index)52 const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
53   return *XlaExpression::CastExpressionFromTensor(context_->input(index));
54 }
55 
InputExpression(absl::string_view name)56 const XlaExpression& XlaOpKernelContext::InputExpression(
57     absl::string_view name) {
58   return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
59 }
60 
Input(int index)61 xla::XlaOp XlaOpKernelContext::Input(int index) {
62   return InputExpression(index).AsXlaOp(builder());
63 }
64 
Input(absl::string_view name)65 xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
66   return InputExpression(name).AsXlaOp(builder());
67 }
68 
InputShape(int index)69 TensorShape XlaOpKernelContext::InputShape(int index) {
70   return context_->input(index).shape();
71 }
72 
InputShape(absl::string_view name)73 TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
74   return GetInputTensorByName(name).shape();
75 }
76 
InputXlaShape(int index)77 xla::StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(int index) {
78   return builder()->GetShape(Input(index));
79 }
80 
InputXlaShape(absl::string_view name)81 xla::StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(
82     absl::string_view name) {
83   return builder()->GetShape(Input(name));
84 }
85 
input_type(int index) const86 DataType XlaOpKernelContext::input_type(int index) const {
87   DataType type = context_->input_dtype(index);
88   if (type == DT_UINT8) {
89     // Masqueraded XlaExpression could have different type. See
90     // XlaOpKernelContext::SetOutputExpression for details.
91     auto expression =
92         XlaExpression::CastExpressionFromTensor(context_->input(index));
93     type = expression->dtype();
94   }
95   return type;
96 }
97 
InputType(absl::string_view name)98 DataType XlaOpKernelContext::InputType(absl::string_view name) {
99   const Tensor& tensor = GetInputTensorByName(name);
100   DataType type = tensor.dtype();
101   if (type == DT_UINT8) {
102     // Masqueraded XlaExpression could have different type. See
103     // XlaOpKernelContext::SetOutputExpression for details.
104     auto expression = XlaExpression::CastExpressionFromTensor(tensor);
105     type = expression->dtype();
106   }
107   return type;
108 }
109 
input_xla_type(int index)110 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
111   xla::PrimitiveType type;
112   Status status = DataTypeToPrimitiveType(input_type(index), &type);
113   if (!status.ok()) {
114     SetStatus(status);
115     return xla::PRIMITIVE_TYPE_INVALID;
116   }
117   return type;
118 }
119 
InputXlaType(absl::string_view name)120 xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) {
121   xla::PrimitiveType type;
122   Status status = DataTypeToPrimitiveType(InputType(name), &type);
123   if (!status.ok()) {
124     SetStatus(status);
125     return xla::PRIMITIVE_TYPE_INVALID;
126   }
127   return type;
128 }
129 
ConstantInput(int index,xla::Literal * constant_literal)130 Status XlaOpKernelContext::ConstantInput(int index,
131                                          xla::Literal* constant_literal) {
132   return ConstantInputReshaped(
133       index, context_->input(index).shape().dim_sizes(), constant_literal);
134 }
135 
InputIndex(XlaOpKernelContext * context,absl::string_view name)136 static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
137                                      absl::string_view name) {
138   int start, stop;
139   TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
140   if (stop != start + 1) {
141     return errors::InvalidArgument("OpKernel used list-valued input name '",
142                                    name,
143                                    "' when single-valued input was "
144                                    "expected");
145   }
146   return start;
147 }
148 
ConstantInput(absl::string_view name,xla::Literal * constant_literal)149 Status XlaOpKernelContext::ConstantInput(absl::string_view name,
150                                          xla::Literal* constant_literal) {
151   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
152   return ConstantInput(index, constant_literal);
153 }
154 
ConstantInputReshaped(int index,absl::Span<const int64> new_dims,xla::Literal * constant_literal)155 Status XlaOpKernelContext::ConstantInputReshaped(
156     int index, absl::Span<const int64> new_dims,
157     xla::Literal* constant_literal) {
158   XlaExpression e = InputExpression(index);
159   auto* client = compiler() ? compiler()->client() : nullptr;
160   xla::StatusOr<absl::optional<Tensor>> constant_or_status =
161       e.ResolveConstant(client, dynamic_dimension_is_minus_one_);
162   if (!constant_or_status.ok()) {
163     Status status = constant_or_status.status();
164     errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
165                             context_->op_kernel().type_string(),
166                             " operator as a compile-time constant.");
167     return status;
168   }
169   absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
170   if (!constant.has_value()) {
171     return errors::InvalidArgument(
172         "Input ", index, " to node `", context_->op_kernel().name(),
173         "` with op ", context_->op_kernel().type_string(),
174         " must be a compile-time constant.\n\n"
175         "XLA compilation requires that operator arguments that represent "
176         "shapes or dimensions be evaluated to concrete values at compile time. "
177         "This error means that a shape or dimension argument could not be "
178         "evaluated at compile time, usually because the value of the argument "
179         "depends on a parameter to the computation, on a variable, or on a "
180         "stateful operation such as a random number generator.");
181   }
182 
183   Tensor temp(constant->dtype());
184   if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
185     return errors::InvalidArgument(
186         context_->op_kernel().name(), " input ", index, " has shape ",
187         constant->shape().DebugString(),
188         " but was asked to be reshaped to incompatible shape ",
189         TensorShape(new_dims).DebugString());
190   }
191 
192   TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
193   return Status::OK();
194 }
195 
196 // Converts an int32 or int64 scalar literal to an int64.
LiteralToInt64Scalar(const xla::LiteralSlice & literal,int64 * out)197 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
198                                    int64* out) {
199   if (literal.shape().rank() != 0) {
200     return errors::InvalidArgument("value is not a scalar");
201   }
202   if (literal.shape().element_type() == xla::S32) {
203     *out = literal.Get<int32>({});
204   } else if (literal.shape().element_type() == xla::S64) {
205     *out = literal.Get<int64>({});
206   } else {
207     return errors::InvalidArgument("value must be either int32 or int64");
208   }
209   return Status::OK();
210 }
211 
212 // Converts an float32 or float64 scalar literal to a float64.
LiteralToFloat64Scalar(const xla::LiteralSlice & literal,double * out)213 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
214                                      double* out) {
215   if (literal.shape().rank() != 0) {
216     return errors::InvalidArgument("value is not a scalar");
217   }
218   if (literal.shape().element_type() == xla::F32) {
219     *out = literal.Get<float>({});
220   } else if (literal.shape().element_type() == xla::F64) {
221     *out = literal.Get<double>({});
222   } else {
223     return errors::InvalidArgument("value must be either float32 or float64");
224   }
225   return Status::OK();
226 }
227 
ConstantInputAsIntScalar(int index,int64 * out)228 Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
229   xla::Literal literal;
230   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
231   return LiteralToInt64Scalar(literal, out);
232 }
233 
ConstantInputAsIntScalar(absl::string_view name,int64 * out)234 Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
235                                                     int64* out) {
236   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
237   return ConstantInputAsIntScalar(index, out);
238 }
239 
ConstantInputAsFloatScalar(int index,double * out)240 Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
241   xla::Literal literal;
242   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
243   return LiteralToFloat64Scalar(literal, out);
244 }
245 
LiteralToPredVector(const xla::LiteralSlice & literal,std::vector<bool> * out)246 static Status LiteralToPredVector(const xla::LiteralSlice& literal,
247                                   std::vector<bool>* out) {
248   if (literal.shape().rank() != 1) {
249     return errors::InvalidArgument("value is not 1D, rank: ",
250                                    literal.shape().rank());
251   }
252   int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
253   if (literal.shape().element_type() != xla::PRED) {
254     return errors::InvalidArgument("value is not PRED");
255   }
256   for (int64 i = 0; i < size; ++i) {
257     out->push_back(literal.Get<bool>({i}));
258   }
259   return Status::OK();
260 }
261 
ResolveInputDynamismIntoPred(int index,bool * out)262 Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
263   xla::Literal literal;
264   XlaExpression e = InputExpression(index);
265   auto* client = compiler() ? compiler()->client() : nullptr;
266   xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
267   if (!dynamism_or_status.ok()) {
268     // When failed to resolve dynamism, conservatively consider the value
269     // dynamic.
270     //
271     // TODO(b/176993339): Support resolving dynamism across computations so
272     // resolving dynamism will not fail.
273     *out = true;
274     return Status::OK();
275   }
276   Tensor dynamism = dynamism_or_status.ValueOrDie();
277 
278   Tensor temp(dynamism.dtype());
279   TensorShape tensor_shape({});
280   if (!temp.CopyFrom(dynamism, tensor_shape)) {
281     return errors::InvalidArgument(
282         context_->op_kernel().name(), " input ", index, " has shape ",
283         dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
284   }
285 
286   TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
287   *out = literal.Get<bool>({});
288   return Status::OK();
289 }
290 
ResolveInputDynamismIntoPredVector(int index,std::vector<bool> * out)291 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
292     int index, std::vector<bool>* out) {
293   xla::Literal literal;
294   XlaExpression e = InputExpression(index);
295   auto* client = compiler() ? compiler()->client() : nullptr;
296   xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
297   if (!dynamism_or_status.ok()) {
298     // When failed to resolve dynamism, conservatively consider the value
299     // dynamic.
300     //
301     // TODO(b/176993339): Support resolving dynamism across computations so
302     // resolving dynamism will not fail.
303     out->resize(InputShape(index).num_elements(), false);
304     return Status::OK();
305   }
306   Tensor dynamism = dynamism_or_status.ValueOrDie();
307 
308   Tensor temp(dynamism.dtype());
309   TensorShape tensor_shape({InputShape(index).num_elements()});
310   if (!temp.CopyFrom(dynamism, tensor_shape)) {
311     return errors::InvalidArgument(
312         context_->op_kernel().name(), " input ", index, " has shape ",
313         dynamism.shape().DebugString(), " which is not a R1 ", tensor_shape);
314   }
315 
316   TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
317   return LiteralToPredVector(literal, out);
318 }
319 
320 // Converts an int32 or int64 1D literal to an int64 vector.
LiteralToInt64Vector(const xla::LiteralSlice & literal,std::vector<int64> * out)321 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
322                                    std::vector<int64>* out) {
323   if (literal.shape().rank() != 1) {
324     return errors::InvalidArgument("value is not 1D, rank: ",
325                                    literal.shape().rank());
326   }
327   int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
328   if (literal.shape().element_type() == xla::S32) {
329     for (int64 i = 0; i < size; ++i) {
330       out->push_back(literal.Get<int32>({i}));
331     }
332   } else if (literal.shape().element_type() == xla::S64) {
333     for (int64 i = 0; i < size; ++i) {
334       out->push_back(literal.Get<int64>({i}));
335     }
336   } else {
337     return errors::InvalidArgument("value must be either int32 or int64");
338   }
339   return Status::OK();
340 }
341 
ConstantInputAsIntVector(int index,std::vector<int64> * out)342 Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
343                                                     std::vector<int64>* out) {
344   xla::Literal literal;
345   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
346   return LiteralToInt64Vector(literal, out);
347 }
348 
ConstantInputAsIntVector(absl::string_view name,std::vector<int64> * out)349 Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
350                                                     std::vector<int64>* out) {
351   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
352   return ConstantInputAsIntVector(index, out);
353 }
354 
ConstantInputReshapedToIntVector(int index,std::vector<int64> * out)355 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
356     int index, std::vector<int64>* out) {
357   xla::Literal literal;
358   TF_RETURN_IF_ERROR(ConstantInputReshaped(
359       index, {InputShape(index).num_elements()}, &literal));
360   return LiteralToInt64Vector(literal, out);
361 }
362 
ConstantInputReshapedToIntVector(absl::string_view name,std::vector<int64> * out)363 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
364     absl::string_view name, std::vector<int64>* out) {
365   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
366   xla::Literal literal;
367   TF_RETURN_IF_ERROR(ConstantInputReshaped(
368       index, {InputShape(index).num_elements()}, &literal));
369   return LiteralToInt64Vector(literal, out);
370 }
371 
ConstantInputAsInt64Literal(int index,xla::Literal * out)372 Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
373                                                        xla::Literal* out) {
374   xla::Literal literal;
375   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
376   switch (literal.shape().element_type()) {
377     case xla::S32: {
378       *out = xla::Literal(
379           xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
380       auto src_data = literal.data<int32>();
381       for (int64 i = 0; i < src_data.size(); ++i) {
382         out->data<int64>()[i] = src_data[i];
383       }
384       return Status::OK();
385     }
386     case xla::S64:
387       *out = std::move(literal);
388       return Status::OK();
389 
390     default:
391       return errors::InvalidArgument(
392           "Invalid argument to ConstantInputAsInt64Literal: ",
393           xla::ShapeUtil::HumanString(literal.shape()));
394   }
395 }
396 
ConstantInputAsInt64Literal(absl::string_view name,xla::Literal * out)397 Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
398                                                        xla::Literal* out) {
399   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
400   return ConstantInputAsInt64Literal(index, out);
401 }
402 
403 // TODO(phawkins): validate that the dimensions form a valid shape, fail
404 // gracefully if they do not.
ConstantInputAsShape(int index,TensorShape * shape)405 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
406   xla::Literal literal;
407   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
408   std::vector<int64> dims;
409   TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
410   *shape = TensorShape(dims);
411   return Status::OK();
412 }
413 
ConstantInputAsPartialShape(int index,PartialTensorShape * shape)414 Status XlaOpKernelContext::ConstantInputAsPartialShape(
415     int index, PartialTensorShape* shape) {
416   xla::Literal literal;
417   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
418   // If `literal` is a scalar it's value must be -1.
419   if (literal.shape().rank() == 0) {
420     int64 shape_val;
421     TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
422     if (shape_val != -1) {
423       return errors::InvalidArgument(
424           "Cannot convert value to PartialTensorShape: ", shape_val);
425     }
426     *shape = PartialTensorShape();  // Shape with unknown rank.
427     return Status::OK();
428   }
429   std::vector<int64> dims;
430   TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
431   *shape = PartialTensorShape(dims);
432   return Status::OK();
433 }
434 
InputList(absl::string_view name,std::vector<xla::XlaOp> * handles,std::vector<TensorShape> * shapes)435 Status XlaOpKernelContext::InputList(absl::string_view name,
436                                      std::vector<xla::XlaOp>* handles,
437                                      std::vector<TensorShape>* shapes) {
438   OpInputList inputs;
439   TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
440   handles->clear();
441   shapes->clear();
442   for (const Tensor& input : inputs) {
443     handles->push_back(
444         XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
445     shapes->push_back(input.shape());
446   }
447   return Status::OK();
448 }
449 
ConstantInputList(absl::string_view name,std::vector<xla::Literal> * outputs)450 Status XlaOpKernelContext::ConstantInputList(
451     absl::string_view name, std::vector<xla::Literal>* outputs) {
452   int start, stop;
453   TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
454   outputs->resize(stop - start);
455   for (int i = start; i < stop; ++i) {
456     TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
457   }
458   return Status::OK();
459 }
460 
461 namespace {
462 
ReadVariableInputTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,TensorShape * shape,xla::XlaOp * value)463 Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
464                                const XlaOpKernelContext* ctx,
465                                TensorShape* shape, xla::XlaOp* value) {
466   const XlaExpression* expression =
467       XlaExpression::CastExpressionFromTensor(tensor);
468   XlaResource* variable = expression->resource();
469   TF_RET_CHECK(variable != nullptr);
470   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
471   if (!variable->initialized()) {
472     return errors::FailedPrecondition(
473         "Read variable failure ", variable->name(),
474         ". It could mean the variable is uninitialized or the variable is on "
475         "another device ");
476   }
477   if (variable->type() != type) {
478     return errors::InvalidArgument(
479         "Type mismatch for read of variable ", variable->name(), ". Expected ",
480         DataTypeString(type), "; got ", DataTypeString(variable->type()));
481   }
482   if (shape) {
483     *shape = variable->shape();
484   }
485 
486   if (!variable->IsOverwritten() && expression->constant_value()) {
487     TF_ASSIGN_OR_RETURN(xla::Literal literal,
488                         HostTensorToLiteral(*expression->constant_value()));
489     *value = xla::ConstantLiteral(ctx->builder(), literal);
490     return Status::OK();
491   }
492 
493   TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
494                       ctx->compiler()->options().shape_representation_fn(
495                           variable->shape(), variable->type(),
496                           /*use_fast_memory=*/false));
497   xla::Shape xla_shape;
498   TF_RETURN_IF_ERROR(
499       TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
500   if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
501     *value = variable->value();
502   } else {
503     *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
504   }
505   return Status::OK();
506 }
507 
508 }  // namespace
509 
ReadVariableInput(int index,DataType type,TensorShape * shape,xla::XlaOp * value)510 Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
511                                              TensorShape* shape,
512                                              xla::XlaOp* value) {
513   return ReadVariableInputTensor(context_->input(index), type, this, shape,
514                                  value);
515 }
516 
ReadVariableInput(absl::string_view name,DataType type,TensorShape * shape,xla::XlaOp * value)517 Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
518                                              DataType type, TensorShape* shape,
519                                              xla::XlaOp* value) {
520   return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
521                                  value);
522 }
523 
GetVariableTypeAndShape(int index,DataType * type,TensorShape * shape) const524 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
525                                                    TensorShape* shape) const {
526   const Tensor& tensor = context_->input(index);
527   const XlaExpression* expression =
528       XlaExpression::CastExpressionFromTensor(tensor);
529   XlaResource* variable = expression->resource();
530   TF_RET_CHECK(variable != nullptr);
531   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
532   if (!variable->initialized()) {
533     return errors::InvalidArgument(
534         "Read variable failure ", variable->name(),
535         ". It could mean the variable is uninitialized or the variable is on "
536         "another device ");
537   }
538   *type = variable->type();
539   *shape = variable->shape();
540   return Status::OK();
541 }
542 
SetOutputExpression(int index,const XlaExpression & expression)543 void XlaOpKernelContext::SetOutputExpression(int index,
544                                              const XlaExpression& expression) {
545   Status status = [&] {
546     // The step's default allocator is the dummy XlaCompilationAllocator which
547     // simply allocates a metadata buffer to hold the expression to which it
548     // corresponds.
549     // Provides a special behavior for DT_VARIANT and other types that are not
550     // trivially copyable. In those cases, allocate a tensor of type DT_UINT8.
551     if (!DataTypeCanUseMemcpy(expression.dtype())) {
552       // tensor_data() is not supported for tensors that cannot be copied via
553       // memcpy, as the copy logic might try to inspect the stored data (e.g.
554       // a std::string). This is likely to fail, as the data is invalid given
555       // that it actually encodes an XlaExpression. Using a uint8 tensor is
556       // always safe, so simply do that.
557       // TODO(jpienaar): This should be refactored to stop masquerading
558       // XlaExpressions as Tensors.
559       Tensor output;
560       TensorShape tensor_shape;
561       TF_RETURN_IF_ERROR(
562           context_->allocate_temp(DT_UINT8, tensor_shape, &output));
563       context_->set_output(index, output);
564     } else {
565       Tensor* output = nullptr;
566       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
567       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
568     }
569     XlaExpression::AssignExpressionToTensor(expression,
570                                             context_->mutable_output(index));
571     return Status::OK();
572   }();
573   if (!status.ok()) {
574     SetStatus(status);
575   }
576 }
577 
output_xla_type(int index)578 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
579   xla::PrimitiveType type;
580   Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
581   if (!status.ok()) {
582     SetStatus(status);
583     return xla::PRIMITIVE_TYPE_INVALID;
584   }
585   return type;
586 }
587 
SetOutput(int index,const xla::XlaOp & handle)588 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
589   SetOutputExpression(
590       index,
591       XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
592 }
593 
SetConstantOutput(int index,const Tensor & constant)594 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
595   SetOutputExpression(index, XlaExpression::Constant(constant));
596 }
597 
SetTensorListOutput(int index,const xla::XlaOp & handle)598 void XlaOpKernelContext::SetTensorListOutput(int index,
599                                              const xla::XlaOp& handle) {
600   SetOutputExpression(index, XlaExpression::TensorList(handle));
601 }
602 
SetResourceOutput(int index,XlaResource * resource)603 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
604   SetOutputExpression(index, XlaExpression::Resource(resource));
605 }
606 
GetResourceInput(int index,XlaResource ** resource)607 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
608   const XlaExpression* expression =
609       XlaExpression::CastExpressionFromTensor(context_->input(index));
610   TF_RET_CHECK(expression->resource() != nullptr);
611   *resource = expression->resource();
612   return Status::OK();
613 }
614 
615 namespace {
616 
AssignVariableTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,xla::XlaOp handle,xla::XlaBuilder * builder)617 Status AssignVariableTensor(const Tensor& tensor, DataType type,
618                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
619                             xla::XlaBuilder* builder) {
620   const XlaExpression* expression =
621       XlaExpression::CastExpressionFromTensor(tensor);
622   XlaResource* variable = expression->resource();
623   TF_RET_CHECK(variable != nullptr);
624   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
625 
626   auto shape_or_status = builder->GetShape(handle);
627   if (!shape_or_status.ok()) {
628     return shape_or_status.status();
629   }
630   TensorShape shape;
631   TF_RETURN_IF_ERROR(
632       XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
633 
634   TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
635 
636   TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
637                       ctx->compiler()->options().shape_representation_fn(
638                           shape, type,
639                           /*use_fast_memory=*/false));
640   xla::Shape xla_shape;
641   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
642   if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
643     handle = xla::Reshape(handle,
644                           xla::AsInt64Slice(representation_shape.dimensions()));
645   }
646   variable->SetRepresentationShape(representation_shape);
647   return variable->SetValue(handle);
648 }
649 
650 }  // namespace
651 
AssignVariable(int input_index,DataType type,xla::XlaOp handle)652 Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
653                                           xla::XlaOp handle) {
654   TF_RET_CHECK(handle.valid());
655   return AssignVariableTensor(context_->input(input_index), type, this, handle,
656                               builder());
657 }
658 
AssignVariable(absl::string_view name,DataType type,xla::XlaOp handle)659 Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
660                                           xla::XlaOp handle) {
661   TF_RET_CHECK(handle.valid());
662   return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
663                               builder());
664 }
665 
// For INVALID_ARGUMENT errors, appends the op's definition stack trace (if
// any) to the message; every other status is returned unchanged.
static Status GetStatusWithStackTrace(const Status& s,
                                      const XlaOpKernelContext* ctx) {
  if (s.code() != error::INVALID_ARGUMENT) {
    return s;
  }
  std::string annotated_message =
      absl::StrCat(s.error_message(), "\n", ctx->StackTrace());
  return Status{s.code(), annotated_message};
}
674 
CtxFailure(const Status & s)675 void XlaOpKernelContext::CtxFailure(const Status& s) {
676   context_->CtxFailure(GetStatusWithStackTrace(s, this));
677 }
CtxFailureWithWarning(const Status & s)678 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
679   context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
680 }
681 
CtxFailure(const char * file,int line,const Status & s)682 void XlaOpKernelContext::CtxFailure(const char* file, int line,
683                                     const Status& s) {
684   context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
685 }
CtxFailureWithWarning(const char * file,int line,const Status & s)686 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
687                                                const Status& s) {
688   context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
689 }
690 
GetOrCreateMax(const DataType type)691 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
692     const DataType type) {
693   return xla_context()->GetOrCreateMax(type);
694 }
695 
GetOrCreateMin(const DataType type)696 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
697     const DataType type) {
698   return xla_context()->GetOrCreateMin(type);
699 }
700 
GetOrCreateAdd(const DataType type)701 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
702     const DataType type) {
703   return xla_context()->GetOrCreateAdd(type);
704 }
705 
GetOrCreateMul(const DataType type)706 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
707     const DataType type) {
708   return xla_context()->GetOrCreateMul(type);
709 }
710 
GetInputTensorByName(absl::string_view name)711 const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
712   const Tensor* tensor;
713   CHECK(context_->input(name, &tensor).ok());
714   return *tensor;
715 }
716 
XlaOpKernel(OpKernelConstruction * context)717 XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
718 
Compute(OpKernelContext * context)719 void XlaOpKernel::Compute(OpKernelContext* context) {
720   XlaOpKernelContext xla_context(context);
721   Compile(&xla_context);
722 }
723 
StackTrace() const724 std::string XlaOpKernelContext::StackTrace() const {
725   if (const AbstractStackTrace* stack_trace =
726           xla_context()->StackTraceForNodeName(op_kernel().name())) {
727     AbstractStackTrace::TracePrintingOptions opts;
728     opts.show_line_contents = true;
729     opts.filter_common_prefix = true;
730     opts.drop_internal_frames = true;
731     return absl::StrCat("\nStack trace for op definition: \n",
732                         stack_trace->ToString(opts), "\n");
733   } else {
734     return "";
735   }
736 }
737 
738 }  // namespace tensorflow
739