Home
last modified time | relevance | path

Searched refs:diag_shape (Results 1 – 10 of 10) sorted by relevance

/external/tensorflow/tensorflow/lite/testing/op_tests/
Dmatrix_set_diag.py41 diag_shape = parameters["input_diag_shapes"][1]
45 dtype=parameters["input_dtype"], name="diagonal", shape=diag_shape)
51 diag_shape = parameters["input_diag_shapes"][1]
53 diag_values = create_tensor_data(parameters["input_dtype"], diag_shape)
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dmatrix_diag_ops.cc244 const TensorShape diag_shape = context->InputShape(0); in Compile() local
245 OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), in Compile()
248 diag_shape.DebugString())); in Compile()
273 const int64 diag_rank = diag_shape.dims(); in Compile()
274 const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1); in Compile()
278 num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2), in Compile()
311 TensorShape output_shape = diag_shape; in Compile()
467 const TensorShape diag_shape = context->InputShape(1); in Compile() local
469 const int diag_rank = diag_shape.dims(); in Compile()
476 OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), in Compile()
[all …]
/external/tensorflow/tensorflow/python/ops/
Dlinalg_ops_impl.py65 diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0)
71 diag_shape = batch_shape + [diag_size]
75 diag_ones = array_ops.ones(diag_shape, dtype=dtype)
Darray_grad.py432 diag_shape = op.inputs[1].get_shape()
433 batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
436 diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
444 diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
446 grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
454 diag_shape = op.inputs[1].get_shape()
455 if not diag_shape.is_fully_defined():
479 diag_shape = array_ops.concat([batch_shape, postfix], 0)
482 grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
490 diag_shape = op.inputs[1].get_shape()
[all …]
/external/tensorflow/tensorflow/core/kernels/linalg/
Dmatrix_set_diag_op.cc89 const TensorShape& diag_shape = diag.shape(); in Compute() local
97 OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), in Compute()
100 diag_shape.DebugString())); in Compute()
129 (diag_shape.dim_size(input_rank - 2) == num_diags), in Compute()
142 context, expected_diag_shape == diag_shape, in Compute()
148 "\nDiagonal shape: ", diag_shape.DebugString(), in Compute()
/external/tensorflow/tensorflow/python/kernel_tests/linalg/
Dlinear_operator_low_rank_update_test.py65 def _gen_positive_diag(self, dtype, diag_shape): argument
68 diag_shape, minval=1e-4, maxval=1., dtype=dtypes.float32)
72 diag_shape, minval=1e-4, maxval=1., dtype=dtype)
78 diag_shape = shape[:-1]
85 base_diag = self._gen_positive_diag(dtype, diag_shape)
/external/tensorflow/tensorflow/python/ops/linalg/
Dlinear_operator_circulant.py459 diag_shape = self.shape[:-1]
462 diag_shape = self.shape_tensor()[:-1]
464 ones_diag = array_ops.ones(diag_shape, dtype=self.dtype)
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dmatrix.cc176 TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag)); in SetMatrixDiagonal()
181 const int64 d = diag_shape.dimensions(n_dims - 2); in SetMatrixDiagonal()
192 for (xla::int64 i = 0; i < diag_shape.rank() - 1; ++i) { in SetMatrixDiagonal()
/external/tensorflow/tensorflow/core/framework/
Dcommon_shape_fns.cc1423 ShapeHandle input_shape, diag_shape, diag_index_shape; in MatrixSetDiagV2Shape() local
1425 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape)); in MatrixSetDiagV2Shape()
1451 &diag_shape)); in MatrixSetDiagV2Shape()
1454 c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape)); in MatrixSetDiagV2Shape()
1456 c->WithRankAtMost(c->input(1), input_rank, &diag_shape)); in MatrixSetDiagV2Shape()
1476 if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) { in MatrixSetDiagV2Shape()
1480 diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2, in MatrixSetDiagV2Shape()
1487 c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape)); in MatrixSetDiagV2Shape()
1488 TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape)); in MatrixSetDiagV2Shape()
/external/tensorflow/tensorflow/python/kernel_tests/
Ddiag_op_test.py719 def _testGrad(self, input_shape, diag_shape, diags, align): argument
724 np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
749 diag_shape = input_shape[:-2] + num_diags_dim + (min(input_shape[-2:]),)
750 self._testGrad(input_shape, diag_shape, diags, align)