Home
last modified time | relevance | path

Searched refs:batch_shape (Results 1 – 25 of 139) sorted by relevance

123456

/external/tensorflow/tensorflow/python/ops/
Dlinalg_ops_impl.py35 batch_shape=None, argument
43 name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
45 batch_shape = [] if batch_shape is None else batch_shape
62 if isinstance(batch_shape, ops.Tensor) or isinstance(diag_size, ops.Tensor):
63 batch_shape = ops.convert_to_tensor(
64 batch_shape, name='shape', dtype=dtypes.int32)
65 diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0)
67 shape = array_ops.concat((batch_shape, [num_rows, num_columns]), axis=0)
70 batch_shape = list(batch_shape)
71 diag_shape = batch_shape + [diag_size]
[all …]
Dlinalg_ops.py71 batch_shape = matrix_shape[:-2]
76 identity = eye(small_dim, batch_shape=batch_shape, dtype=matrix.dtype)
200 batch_shape=None, argument
239 batch_shape=batch_shape,
326 batch_shape = tensor_shape[:-2]
331 is_io_bound = batch_shape.num_elements() > np.min(matrix_shape)
/external/tensorflow/tensorflow/python/kernel_tests/linalg/
Dlinear_operator_zeros_test.py67 batch_shape = shape[:-2]
71 num_rows, batch_shape=batch_shape, dtype=dtype)
111 linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=2)
115 linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[2.])
119 linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2])
139 batch_shape = array_ops.placeholder_with_default(2, shape=None)
142 num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
147 batch_shape = array_ops.placeholder_with_default([-2], shape=None)
150 num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
197 num_rows=2, batch_shape=variables_module.Variable([2]))
[all …]
Dlinear_operator_identity_test.py57 batch_shape = shape[:-2]
61 num_rows, batch_shape=batch_shape, dtype=dtype)
62 mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)
105 linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=2)
109 linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[2.])
113 linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[-2])
134 batch_shape = array_ops.placeholder_with_default(2, shape=None)
137 num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
142 batch_shape = array_ops.placeholder_with_default([-2], shape=None)
145 num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
[all …]
/external/tensorflow/tensorflow/python/ops/linalg/
Dlinear_operator_low_rank_update.py283 batch_shape = array_ops.broadcast_static_shape(
284 self.base_operator.batch_shape, uv_shape[:-2])
294 batch_shape, self._diag_update.shape[:-1])
340 batch_shape = array_ops.broadcast_static_shape(
341 self.base_operator.batch_shape,
342 self.diag_operator.batch_shape)
343 batch_shape = array_ops.broadcast_static_shape(
344 batch_shape,
346 batch_shape = array_ops.broadcast_static_shape(
347 batch_shape,
[all …]
Dlinear_operator_zeros.py129 batch_shape=None, argument
182 batch_shape=batch_shape,
215 linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")
239 if batch_shape is None:
243 batch_shape, name="batch_shape_arg")
254 batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
255 return batch_shape.concatenate(matrix_shape)
293 special_shape = self.batch_shape.concatenate([1, 1])
337 if self.batch_shape.is_fully_defined():
338 return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype)
[all …]
Dlinear_operator_kronecker.py261 batch_shape = self.operators[0].batch_shape
263 batch_shape = common_shapes.broadcast_shape(
264 batch_shape, operator.batch_shape)
266 return batch_shape.concatenate(matrix_shape)
281 batch_shape = self.operators[0].batch_shape_tensor()
283 batch_shape = array_ops.broadcast_dynamic_shape(
284 batch_shape, operator.batch_shape_tensor())
286 return array_ops.concat((batch_shape, matrix_shape), 0)
330 batch_shape = array_ops.concat(
332 x += array_ops.zeros(batch_shape, dtype=x.dtype.base_dtype)
[all …]
Dlinear_operator_identity.py92 d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
209 batch_shape=None, argument
257 batch_shape=batch_shape,
290 linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")
297 if batch_shape is None:
301 batch_shape, name="batch_shape_arg")
312 batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
313 return batch_shape.concatenate(matrix_shape)
347 special_shape = self.batch_shape.concatenate([1, 1])
383 if self.batch_shape.is_fully_defined():
[all …]
Dlinear_operator_composition.py214 batch_shape = self.operators[0].batch_shape
216 batch_shape = common_shapes.broadcast_shape(
217 batch_shape, operator.batch_shape)
219 return batch_shape.concatenate(matrix_shape)
239 batch_shape = array_ops.shape(zeros)
241 return array_ops.concat((batch_shape, matrix_shape), 0)
Dlinear_operator_block_lower_triangular.py368 batch_shape = self.operators[0][0].batch_shape
371 batch_shape = common_shapes.broadcast_shape(
372 batch_shape, operator.batch_shape)
374 return batch_shape.concatenate(matrix_shape)
386 batch_shape = self.operators[0][0].batch_shape_tensor()
389 batch_shape = array_ops.broadcast_dynamic_shape(
390 batch_shape, operator.batch_shape_tensor())
392 return array_ops.concat((batch_shape, matrix_shape), 0)
Dlinear_operator_block_diag.py265 batch_shape = self.operators[0].batch_shape
267 batch_shape = common_shapes.broadcast_shape(
268 batch_shape, operator.batch_shape)
270 return batch_shape.concatenate(matrix_shape)
286 batch_shape = array_ops.shape(zeros)
288 return array_ops.concat((batch_shape, matrix_shape), 0)
Dlinear_operator_addition.py221 batch_shape = operators[0].batch_shape
223 batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
/external/tensorflow/tensorflow/python/kernel_tests/
Dmatrix_solve_ls_op_test.py47 batch_shape = matrix_shape[:-2]
56 np.tile(matrix, batch_shape + (1, 1)), trainable=False)
57 rhs = variables.Variable(np.tile(rhs, batch_shape + (1, 1)), trainable=False)
89 batch_shape=()): argument
110 if batch_shape != ():
111 a = np.tile(a, batch_shape + (1, 1))
112 b = np.tile(b, batch_shape + (1, 1))
113 np_ans = np.tile(np_ans, batch_shape + (1, 1))
114 np_r_norm = np.tile(np_r_norm, batch_shape)
192 for batch_shape in (), (2, 3):
[all …]
Dmatrix_band_part_op_test.py148 for batch_shape in ((), (2,), (1, 3, 2)):
153 "_".join(map(str, batch_shape + shape)))
155 _GetMatrixBandPartTest(dtype, batch_shape, shape))
158 for batch_shape in ((), (2,)):
163 "_".join(map(str, batch_shape + shape)))
165 _GetMatrixBandPartGradTest(dtype, batch_shape, shape))
Dlinalg_ops_test.py161 batch_shape = (2, 3)
164 linalg_ops.eye(num_rows=2, batch_shape=batch_shape).shape)
168 num_rows=2, num_columns=3, batch_shape=batch_shape).shape)
186 batch_shape = (2, 3)
190 batch_shape=batch_shape)
217 def test_eye_no_placeholder(self, num_rows, num_columns, batch_shape, dtype): argument
219 if batch_shape is not None:
220 eye_np = np.tile(eye_np, batch_shape + [1, 1])
224 batch_shape=batch_shape,
248 self, num_rows, num_columns, batch_shape, dtype): argument
[all …]
Dlu_op_test.py73 batch_shape = lu_shape[:-2]
81 num_rows, batch_shape=batch_shape, dtype=lower.dtype)
88 np.append(batch_shape, num_rows), dtype=lower.dtype)
256 batch_shape = shape[:-2]
262 return np.tile(matrix, batch_shape + (1, 1))
/external/tensorflow/tensorflow/compiler/tests/
Dmatrix_band_part_test.py169 def testMatrixBandPart(self, batch_shape, rows, cols): argument
171 if self.device == 'XLA_CPU' and cols == 7 and rows == 1 and batch_shape == [
177 mat = np.ones(batch_shape + [rows, cols]).astype(dtype)
178 batch_mat = np.tile(mat, batch_shape + [1, 1])
186 if batch_shape:
187 band_np = np.tile(band_np, batch_shape + [1, 1])
/external/tensorflow/tensorflow/core/ops/
Dlinalg_ops.cc37 ShapeHandle batch_shape; in MakeBatchSquareMatrix() local
38 TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape)); in MakeBatchSquareMatrix()
39 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out)); in MakeBatchSquareMatrix()
166 ShapeHandle batch_shape; in SelfAdjointEigV2ShapeFn() local
167 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); in SelfAdjointEigV2ShapeFn()
169 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape)); in SelfAdjointEigV2ShapeFn()
175 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape)); in SelfAdjointEigV2ShapeFn()
193 ShapeHandle batch_shape; in LuShapeFn() local
194 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); in LuShapeFn()
199 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape)); in LuShapeFn()
[all …]
/external/tensorflow/tensorflow/python/kernel_tests/distributions/
Dnormal_test.py124 self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
125 self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
133 self.assertAllEqual(normal.batch_shape, pdf.get_shape())
134 self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
160 self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
161 self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
170 self.assertAllEqual(normal.batch_shape, pdf.get_shape())
171 self.assertAllEqual(normal.batch_shape, pdf_values.shape)
194 self.assertAllEqual(normal.batch_shape, cdf.get_shape())
195 self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
[all …]
Dcategorical_test.py40 def make_categorical(batch_shape, num_classes, dtype=dtypes.int32): argument
42 list(batch_shape) + [num_classes], -10, 10, dtype=dtypes.float32) - 50.
70 for batch_shape in ([], [1], [2, 3, 4]):
71 dist = make_categorical(batch_shape, 10)
72 self.assertAllEqual(batch_shape, dist.batch_shape)
73 self.assertAllEqual(batch_shape, dist.batch_shape_tensor())
81 for batch_shape in ([], [1], [2, 3, 4]):
83 batch_shape, constant_op.constant(
85 self.assertAllEqual(len(batch_shape), dist.batch_shape.ndims)
86 self.assertAllEqual(batch_shape, dist.batch_shape_tensor())
/external/tensorflow/tensorflow/core/kernels/linalg/
Dlinalg_ops_common.cc93 TensorShape batch_shape; in Compute() local
94 AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape); in Compute()
99 PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs, in Compute()
113 batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard); in Compute()
119 TensorShapes* input_matrix_shapes, TensorShape* batch_shape) { in AnalyzeInputs() argument
133 batch_shape->AddDim(in.dim_size(dim)); in AnalyzeInputs()
142 context, in.dim_size(dim) == batch_shape->dim_size(dim), in AnalyzeInputs()
163 const TensorShape& batch_shape, TensorOutputs* outputs, in PrepareOutputs() argument
195 output_tensor_shape = batch_shape; in PrepareOutputs()
Dlu_op.cc80 TensorShape batch_shape; in Compute() local
82 batch_shape.AddDim(input.dim_size(dim)); in Compute()
93 TensorShape permutation_shape = batch_shape; in Compute()
124 batch_shape.num_elements(), GetCostPerUnit(input_matrix_shape), in Compute()
/external/tensorflow/tensorflow/python/kernel_tests/proto/
Ddecode_proto_op_test_base.py74 def _compareProtos(self, batch_shape, sizes, fields, field_dict): argument
96 self.assertEqual(list(values.shape)[:-1], batch_shape)
149 def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, argument
179 batch = np.reshape(batch, batch_shape)
200 np.all(np.array(sizes.shape) == batch_shape + [len(field_names)]))
210 self._compareProtos(batch_shape, sizes, fields, field_dict)
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dmatrix_band_part_op.cc57 TensorShape batch_shape = input_shape; in Compile() local
58 batch_shape.RemoveLastDims(2); in Compile()
82 indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); in Compile()
/external/tensorflow/tensorflow/python/ops/distributions/
Ddistribution.py650 if self.batch_shape.is_fully_defined():
651 return ops.convert_to_tensor(self.batch_shape.as_list(),
660 def batch_shape(self): member in Distribution
732 self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor),
1211 maybe_batch_shape=(", batch_shape={}".format(self.batch_shape)
1212 if self.batch_shape.ndims is not None
1227 batch_shape=self.batch_shape,
1276 batch_ndims = self.batch_shape.ndims
1307 self.batch_shape).concatenate([None]*event_ndims)

123456