Home
last modified time | relevance | path

Searched refs:buffer_shape (Results 1 – 15 of 15) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dscatter_nd_op.cc34 Status ValidateUpdateShape(const TensorShape& buffer_shape, in ValidateUpdateShape() argument
52 ", buffer_shape: ", buffer_shape.DebugString(), in ValidateUpdateShape()
57 if (buffer_shape.dims() < in ValidateUpdateShape()
62 batch_dim + buffer_shape.dims() - num_index_dims) { in ValidateUpdateShape()
72 buffer_shape.dim_size(d + num_index_dims)) { in ValidateUpdateShape()
89 TensorShape buffer_shape; in Compile() local
90 OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape)); in Compile()
93 context, TensorShapeUtils::IsVectorOrHigher(buffer_shape), in Compile()
95 "got shape: ", buffer_shape.DebugString())); in Compile()
99 buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 && in Compile()
[all …]
Dsegment_reduction_ops.cc81 TensorShape buffer_shape = data_shape; in Compile() local
82 buffer_shape.RemoveDimRange(0, indices_shape.dims()); in Compile()
83 buffer_shape.InsertDim(0, num_segments); in Compile()
86 xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); in Compile()
Dtensor_list_utils.cc139 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { in GetTensorListBufferShape() argument
146 *buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); in GetTensorListBufferShape()
212 auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); in GetLeadingDimForTensorList() local
213 *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0); in GetLeadingDimForTensorList()
215 *leading_dim = buffer_shape.dimensions(0); in GetLeadingDimForTensorList()
515 const xla::Shape& buffer_shape = in ExecuteTensorListGetItem() local
517 std::vector<xla::XlaOp> start_indices(buffer_shape.dimensions_size(), in ExecuteTensorListGetItem()
521 std::vector<int64> slice_shape = xla::SpanToVector(buffer_shape.dimensions()); in ExecuteTensorListGetItem()
528 for (int64 i = 1; i < buffer_shape.dimensions_size(); ++i) { in ExecuteTensorListGetItem()
529 if (buffer_shape.is_dynamic_dimension(i)) { in ExecuteTensorListGetItem()
Dtensor_list_utils.h43 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape);
Dtensor_list_ops.cc411 TensorShape buffer_shape; in Compile() local
412 OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(buffer_xla_shape, &buffer_shape)); in Compile()
416 ctx, XlaGather(buffer, buffer_shape, indices, indices_shape, /*axis=*/0, in Compile()
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
Dscatter.cc39 TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); in XlaScatter()
51 if (num_index_dims > buffer_shape.rank()) { in XlaScatter()
56 xla::ShapeUtil::HumanString(buffer_shape), ")"); in XlaScatter()
74 if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) { in XlaScatter()
77 xla::ShapeUtil::HumanString(buffer_shape)); in XlaScatter()
144 int64 buffer_rank = buffer_shape.rank(); in XlaScatter()
153 expected_updates_dims.push_back(buffer_shape.dimensions(dim)); in XlaScatter()
179 xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {}); in XlaScatter()
189 VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape); in XlaScatter()
/external/tensorflow/tensorflow/core/tpu/kernels/xla/
Dsegment_reduction_ops.cc71 TensorShape buffer_shape = data_shape; in Compile() local
72 buffer_shape.RemoveDimRange(0, indices_shape.dims()); in Compile()
73 buffer_shape.InsertDim(0, num_segments); in Compile()
76 buffer_shape.dim_sizes()); in Compile()
/external/tensorflow/tensorflow/compiler/xla/service/
Dtransfer_manager.cc213 const Shape& buffer_shape = in ReadDynamicShapes() local
215 if (buffer_shape.IsTuple()) { in ReadDynamicShapes()
226 Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape); in ReadDynamicShapes()
228 int64 metadata_size = shape_size_fn(buffer_shape) - offset; in ReadDynamicShapes()
239 ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}), in ReadDynamicShapes()
Dlayout_assignment.cc536 const Shape& buffer_shape = instruction->operand(0)->shape(); in AddMandatoryConstraints() local
537 TF_RET_CHECK(buffer_shape.IsArray()); in AddMandatoryConstraints()
540 ->LayoutShapeForChannel(buffer_shape, channel_id); in AddMandatoryConstraints()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dcollection_ops_util.cc153 llvm::SmallVector<int64_t, 8> buffer_shape; in CreateInitBufferValue() local
154 buffer_shape.push_back(max_size); in CreateInitBufferValue()
156 buffer_shape.push_back(dim); in CreateInitBufferValue()
164 auto buffer_type = RankedTensorType::get(buffer_shape, element_dtype); in CreateInitBufferValue()
167 ArrayRef<Value>{zero, GetR1Const(buffer_shape, builder, op->getLoc())}); in CreateInitBufferValue()
Dtensor_array_ops_decomposition.cc336 llvm::SmallVector<int64_t, 8> buffer_shape; in HandleTensorArraySplitV3Op() local
337 buffer_shape.push_back(count); in HandleTensorArraySplitV3Op()
338 for (int64_t dim : elem_type.getShape()) buffer_shape.push_back(dim); in HandleTensorArraySplitV3Op()
344 buffer_shape, elem_type.getElementType())}, in HandleTensorArraySplitV3Op()
346 cutil::GetR1Const(buffer_shape, builder, in HandleTensorArraySplitV3Op()
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dbuffer_comparator.cc564 const Shape& buffer_shape, in DeviceCompare() argument
619 CalculateLaunchDimensions(buffer_shape, gpu_device_info); in DeviceCompare()
Dgpu_conv_algorithm_picker.cc396 const Shape& buffer_shape) { in PickBestAlgorithmNoCacheCuda() argument
398 InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer); in PickBestAlgorithmNoCacheCuda()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_compiler.cc913 xla::Shape buffer_shape; in XLAShapeForArgument() local
915 TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); in XLAShapeForArgument()
917 {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); in XLAShapeForArgument()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dmhlo_to_lhlo_with_xla.cc1486 const xla::Shape& buffer_shape = xla::ShapeUtil::GetSubshape( in Initialize() local
1492 buffer_shape, builder_)); in Initialize()