Home
last modified time | relevance | path

Searched refs:rhs_shape (Results 1 – 25 of 51) sorted by relevance

123

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dxla_broadcast_helper_op.cc39 const TensorShape rhs_shape = context->InputShape(1); in Compile() local
41 const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims(); in Compile()
42 const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape; in Compile()
43 const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape; in Compile()
51 lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 || in Compile()
52 rhs_shape.dims() == 0, in Compile()
57 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
69 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
89 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
Dmatrix_triangular_solve_op.cc37 const TensorShape rhs_shape = ctx->InputShape(1); in Compile() local
45 MatMulBCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); in Compile()
49 rhs_shape.DebugString())); in Compile()
63 std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast); in Compile()
76 const TensorShape& rhs_shape, const MatMulBCast& broadcast_helper);
83 xla::XlaOp rhs, const TensorShape& rhs_shape, in Broadcast() argument
87 int64 n = rhs_shape.dim_size(rhs_shape.dims() - 1); in Broadcast()
Dcwise_ops.cc35 const TensorShape rhs_shape = ctx->InputShape(1); in Compile() local
43 BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), in Compile()
48 rhs_shape.DebugString())); in Compile()
68 int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims()); in Compile()
69 int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims()); in Compile()
81 rhs_shape.dim_sizes(), bcast, extend_dimension); in Compile()
/external/tensorflow/tensorflow/core/kernels/mlir_generated/
Dbase_binary_ops_test.h45 const TensorShape& rhs_shape, in SetOpKernel() argument
61 AddInputFromArray<T>(rhs_shape, rhs_input); in SetOpKernel()
70 const TensorShape& rhs_shape, in RunAndExpectResult() argument
75 SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input, in RunAndExpectResult()
95 const TensorShape& rhs_shape, in RunAndExpectInvalidArgument() argument
98 SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input, in RunAndExpectInvalidArgument()
114 TensorShape rhs_shape{2}; in TestIncompatibleShapes()
118 test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements()); in TestIncompatibleShapes()
121 rhs_shape, repeated_rhs_input, config); in TestIncompatibleShapes()
235 TensorShape rhs_shape{6}; in TestBroadcastingExpand()
[all …]
/external/tensorflow/tensorflow/python/kernel_tests/
Dtridiagonal_solve_op_test.py516 def test_raises(diags_shape, rhs_shape): argument
517 self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "compact")
527 def test_raises(diags_tuple_shapes, rhs_shape): argument
529 self._assertRaises(diagonals, _tf_ones(rhs_shape), "sequence")
541 def test_raises(diags_shape, rhs_shape): argument
542 self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "matrix")
552 rhs_shape, argument
560 rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
574 rhs_shape=[None],
583 rhs_shape=[4],
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dunroll_batch_matmul.cc215 auto rhs_shape = rhs_type.getShape(); in matchAndRewrite() local
221 const int dims_b = rhs_shape.size(); in matchAndRewrite()
240 rhs_shape = rhs_type.getShape(); in matchAndRewrite()
243 if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) { in matchAndRewrite()
251 RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, element_type); in matchAndRewrite()
268 for (auto dim : rhs_shape) { in matchAndRewrite()
277 rhs_shape.begin(), rhs_shape.end())); in matchAndRewrite()
294 rhs_shape[dims_b - 1], element_type, loc, rewriter); in matchAndRewrite()
301 result_shape.push_back(rhs_shape[dims_b - 1]); in matchAndRewrite()
Deinsum.cc245 std::vector<int64_t> rhs_shape; in reshapeForBatchMatmul() local
247 rhs_shape.reserve(dnums.lhs_rhs_out.size() + 2); in reshapeForBatchMatmul()
251 rhs_shape.push_back(b); in reshapeForBatchMatmul()
279 rhs_shape.push_back(lhs_rhs_size); in reshapeForBatchMatmul()
285 rhs_shape.push_back(rhs_size); in reshapeForBatchMatmul()
289 failed(VerifyShapeOfReshapeOp(rhs_shape))) in reshapeForBatchMatmul()
294 *rhs = createReshapeOp(*rhs, rhs_shape, rhs_type.getElementType(), loc, in reshapeForBatchMatmul()
Dbatchmatmul_to_einsum.cc62 auto rhs_shape = rhs_type.getShape(); in matchAndRewrite() local
66 const int dims_b = rhs_shape.size(); in matchAndRewrite()
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Ddot_op_emitter.cc70 Shape rhs_shape; member
79 rhs_shape = instr.operand(1)->shape(); in DotInfo()
255 Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape}; in EmitLinalgMatmul()
268 dot_info_.rhs_shape.ToString(true)); in EmitLinalgMatmul()
283 dot_info_.rhs_shape.rank()); in EmitLinalgMatmul()
518 const Shape& rhs_shape = rhs_array_.GetShape(); in Emit() local
520 if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { in Emit()
523 ShapeUtil::IsScalar(rhs_shape)); in Emit()
553 const Shape& rhs_shape = rhs_array_.GetShape(); in EmitNaiveLlvmIrGemm() local
564 rhs_shape.dimensions(rhs_reduction_dimension)); in EmitNaiveLlvmIrGemm()
[all …]
Dcpu_layout_assignment_test.cc67 Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24}); in TEST_F() local
72 HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); in TEST_F()
102 Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24}); in TEST_F() local
109 HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); in TEST_F()
145 Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1}); in TEST_F() local
153 HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); in TEST_F()
186 Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1}); in TEST_F() local
191 HloInstruction::CreateParameter(0, rhs_shape, "param0")); in TEST_F()
200 ShapeLayout(LayoutUtil::GetWithDefaultLayout(rhs_shape)); in TEST_F()
219 Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1}); in TEST_F() local
[all …]
Ddot_op_emitter_internal.h39 Shape rhs_shape; member
46 rhs_shape = instr.operand(1)->shape(); in DotInfo()
/external/tensorflow/tensorflow/compiler/xla/service/
Ddot_decomposer.cc96 const auto& rhs_shape = original_dot->operand(1)->shape(); in CanonicalizeDot() local
97 const int64 rhs_rank = rhs_shape.rank(); in CanonicalizeDot()
106 rhs_contracting_size *= rhs_shape.dimensions(i); in CanonicalizeDot()
110 rhs_non_contracting_size *= rhs_shape.dimensions(i); in CanonicalizeDot()
129 ShapeUtil::PermuteDimensions(rhs_transpose, rhs_shape), in CanonicalizeDot()
140 ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims), in CanonicalizeDot()
Dshape_inference_test.cc417 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); in TEST_F() local
439 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
462 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); in TEST_F() local
485 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
508 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4}); in TEST_F() local
531 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
542 Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); in TEST_F() local
570 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
592 Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4}); in TEST_F() local
607 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6, in TEST_F()
[all …]
/external/tensorflow/tensorflow/lite/kernels/
Dbatch_matmul.cc510 const RuntimeShape& rhs_shape, const TfLiteTensor* rhs, in EvalInt8() argument
529 op_params, rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape, in EvalInt8()
533 optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs), in EvalInt8()
545 const RuntimeShape& rhs_shape, const TfLiteTensor* rhs, in EvalInt16() argument
563 op_params, rhs_shape, GetTensorData<int16_t>(rhs), lhs_shape, in EvalInt16()
573 const RuntimeShape& rhs_shape, in EvalQuantized() argument
592 context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized, in EvalQuantized()
595 return EvalInt8<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs, in EvalQuantized()
598 return EvalInt16<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs, in EvalQuantized()
682 RuntimeShape rhs_shape = in Eval() local
[all …]
/external/tensorflow/tensorflow/lite/kernels/internal/reference/
Dbatch_matmul.h54 const RuntimeShape& rhs_shape, const float* rhs_data, in BatchMatMul() argument
59 RuntimeShape::ExtendedShape(5, rhs_shape); in BatchMatMul()
109 const RuntimeShape& rhs_shape, const int8_t* rhs_data, in BatchMatMul() argument
117 RuntimeShape::ExtendedShape(5, rhs_shape); in BatchMatMul()
201 const RuntimeShape& rhs_shape, const T* rhs_data, in BatchMatMul() argument
206 RuntimeShape::ExtendedShape(5, rhs_shape); in BatchMatMul()
/external/tensorflow/tensorflow/core/kernels/
Dcwise_ops_test.cc215 TensorShape rhs_shape; in BiasAdd() local
216 rhs_shape = TensorShape({cols}); in BiasAdd()
217 Tensor rhs(type, rhs_shape); in BiasAdd()
325 TensorShape lhs_shape, rhs_shape; in BcastAdd() local
328 rhs_shape = TensorShape({rows, 1}); in BcastAdd()
331 rhs_shape = TensorShape({cols}); in BcastAdd()
334 rhs_shape = TensorShape({1, cols}); in BcastAdd()
337 rhs_shape = TensorShape({rows, 1}); in BcastAdd()
341 Tensor rhs(DT_FLOAT, rhs_shape); in BcastAdd()
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dgemm_thunk.cc40 config.rhs_shape = gemm->operand(1)->shape(); in GetGpuGemmConfig()
183 const Shape &rhs_shape = gemm_config.rhs_shape; in RunGemm() local
204 for (const auto *shape : {&lhs_shape, &rhs_shape, &output_shape}) { in RunGemm()
246 rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim); in RunGemm()
Dgpu_layout_assignment.cc124 Shape rhs_shape = instr->operand(1)->shape(); in AddBackendConstraintsToDnnConvCustomCall() local
136 filter_shape = &rhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
141 filter_shape = &rhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
147 output_shape = &rhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
175 TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1)); in AddBackendConstraintsToDnnConvCustomCall()
/external/tensorflow/tensorflow/core/kernels/mkl/
Dmkl_batch_matmul_op.cc178 const TensorShape& lhs_shape, const TensorShape& rhs_shape, in CreateMatMulParams() argument
181 const auto ndims_rhs = rhs_shape.dims(); in CreateMatMulParams()
184 auto rhs_dims = TFShapeToMklDnnDims(rhs_shape); in CreateMatMulParams()
196 ExpandInputDimsToOutputShape(rhs_shape, out_shape, &rhs_dims); in CreateMatMulParams()
/external/tensorflow/tensorflow/lite/kernels/internal/optimized/
Dbatch_matmul.h29 const RuntimeShape& rhs_shape, const float* rhs_data, in BatchMatMul() argument
38 RuntimeShape::ExtendedShape(5, rhs_shape); in BatchMatMul()
116 const RuntimeShape& rhs_shape, const int8_t* rhs_data, in BatchMatMul() argument
129 RuntimeShape::ExtendedShape(5, rhs_shape); in BatchMatMul()
273 const RuntimeShape& rhs_shape, const int8_t* rhs_data, in BatchMatMul() argument
283 RuntimeShape::ExtendedShape(5, rhs_shape); in BatchMatMul()
/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/
Dlegalize_tf.cc599 auto rhs_shape = rhs.getType().cast<ShapedType>().getShape(); in matchAndRewrite() local
601 if (lhs_shape == rhs_shape) { in matchAndRewrite()
607 if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, in matchAndRewrite()
631 if (result_type.getShape() != rhs_shape) { in matchAndRewrite()
667 auto rhs_shape = rhs.getType().cast<ShapedType>().getShape(); in matchAndRewrite() local
670 if (lhs_shape == rhs_shape && cond_shape == lhs_shape) { in matchAndRewrite()
676 if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, in matchAndRewrite()
712 if (result_shape != rhs_shape) { in matchAndRewrite()
/external/tensorflow/tensorflow/compiler/tests/
Dtridiagonal_solve_ops_test.py477 def test_raises(diags_shape, rhs_shape): argument
478 self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "compact")
487 def test_raises(diags_tuple_shapes, rhs_shape): argument
489 self._assertRaises(diagonals, _tf_ones(rhs_shape), "sequence")
500 def test_raises(diags_shape, rhs_shape): argument
501 self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "matrix")
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dtridiagonal.cc62 TF_ASSIGN_OR_RETURN(Shape rhs_shape, builder->GetShape(rhs)); in CheckSystemAndReturnNumEquations()
67 const auto rhs_rank = rhs_shape.rank(); in CheckSystemAndReturnNumEquations()
88 const auto rhs_num_eqs = ShapeUtil::GetDimension(rhs_shape, rank - 1); in CheckSystemAndReturnNumEquations()
/external/tensorflow/tensorflow/python/ops/linalg/
Dlinear_operator_tridiag.py316 rhs_shape = array_ops.shape(rhs)
319 self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
322 [broadcast_shape, rhs_shape[-2:]], axis=-1))
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/
Dlower_general_dot.cc166 auto rhs_shape = rhs_shape_type.getShape(); in matchAndRewrite() local
168 RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); in matchAndRewrite()

123