Home
last modified time | relevance | path

Searched refs:mask_shape (Results 1 – 7 of 7) sorted by relevance

/external/tensorflow/tensorflow/python/keras/layers/
Dmulti_head_attention_test.py206 mask_shape = [batch_size] + mask_dims
212 mask_data = np.random.randint(2, size=mask_shape).astype("bool")
214 null_mask_data = np.ones(mask_shape)
219 mask_tensor = keras.Input(mask_shape[1:], name="mask")
/external/tensorflow/tensorflow/python/keras/utils/
Dconv_utils.py280 mask_shape = input_shape + output_shape
281 mask = np.zeros(mask_shape, np.bool)
/external/tensorflow/tensorflow/python/ops/ragged/
Dragged_array_ops.py197 mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
198 split_size = math_ops.cumprod(mask_shape) + 1
200 elt_size = mask_shape[dim + 1]
/external/tensorflow/tensorflow/compiler/xla/service/
Dconvolution_group_converter.cc169 Shape mask_shape = ShapeUtil::MakeShape( in GetExpandedFilterMask() local
183 mask_shape, mask1, {kernel_input_feature_dim})); in GetExpandedFilterMask()
187 mask_shape, mask2, {kernel_output_feature_dim})); in GetExpandedFilterMask()
Ddynamic_padder.cc253 const Shape mask_shape = in PadWithScalar() local
259 computation->AddInstruction(HloInstruction::CreateIota(mask_shape, dim)); in PadWithScalar()
262 HloInstruction::CreateBroadcast(mask_shape, dynamic_size, {})); in PadWithScalar()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner_util.cc1029 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); in ExchangeHaloAndGetValidData() local
1040 mask_shape, index_in_padded_shape, valid_index_start, in ExchangeHaloAndGetValidData()
1053 mask_shape, index_in_padded_shape, valid_index_limit, in ExchangeHaloAndGetValidData()
1060 mask_shape, HloOpcode::kAnd, predicates[0], predicates[1])) in ExchangeHaloAndGetValidData()
Dspmd_partitioner.cc490 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); in PadWithValue() local
513 mask_shape, index_in_full_shape, broadcast_limit, direction)); in PadWithValue()