Searched refs:mask_shape (Results 1 – 7 of 7) sorted by relevance
/external/tensorflow/tensorflow/python/keras/layers/ |
D | multi_head_attention_test.py | 206 mask_shape = [batch_size] + mask_dims 212 mask_data = np.random.randint(2, size=mask_shape).astype("bool") 214 null_mask_data = np.ones(mask_shape) 219 mask_tensor = keras.Input(mask_shape[1:], name="mask")
|
/external/tensorflow/tensorflow/python/keras/utils/ |
D | conv_utils.py | 280 mask_shape = input_shape + output_shape 281 mask = np.zeros(mask_shape, np.bool)
|
/external/tensorflow/tensorflow/python/ops/ragged/ |
D | ragged_array_ops.py | 197 mask_shape = array_ops.shape(mask, out_type=row_splits_dtype) 198 split_size = math_ops.cumprod(mask_shape) + 1 200 elt_size = mask_shape[dim + 1]
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | convolution_group_converter.cc | 169 Shape mask_shape = ShapeUtil::MakeShape( in GetExpandedFilterMask() local 183 mask_shape, mask1, {kernel_input_feature_dim})); in GetExpandedFilterMask() 187 mask_shape, mask2, {kernel_output_feature_dim})); in GetExpandedFilterMask()
|
D | dynamic_padder.cc | 253 const Shape mask_shape = in PadWithScalar() local 259 computation->AddInstruction(HloInstruction::CreateIota(mask_shape, dim)); in PadWithScalar() 262 HloInstruction::CreateBroadcast(mask_shape, dynamic_size, {})); in PadWithScalar()
|
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | spmd_partitioner_util.cc | 1029 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); in ExchangeHaloAndGetValidData() local 1040 mask_shape, index_in_padded_shape, valid_index_start, in ExchangeHaloAndGetValidData() 1053 mask_shape, index_in_padded_shape, valid_index_limit, in ExchangeHaloAndGetValidData() 1060 mask_shape, HloOpcode::kAnd, predicates[0], predicates[1])) in ExchangeHaloAndGetValidData()
|
D | spmd_partitioner.cc | 490 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); in PadWithValue() local 513 mask_shape, index_in_full_shape, broadcast_limit, direction)); in PadWithValue()
|