Home
last modified time | relevance | path

Searched refs:tile_assignment (Results 1 – 19 of 19) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding_util.cc83 old.tile_assignment().num_elements() != in MergeSharding()
84 to_merge->tile_assignment().num_elements()) { in MergeSharding()
88 int64 num_devices = old.tile_assignment().num_elements(); in MergeSharding()
91 new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions()); in MergeSharding()
92 for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) { in MergeSharding()
93 int64 new_dim = to_merge->tile_assignment().dim(i); in MergeSharding()
94 int64 old_dim = old.tile_assignment().dim(i); in MergeSharding()
108 replication >= old.tile_assignment().dimensions().back()) { in MergeSharding()
120 group_id *= to_merge->tile_assignment().dim(i); in MergeSharding()
125 old.tile_assignment().Each( in MergeSharding()
[all …]
Dsharding_propagation.cc97 if (instruction->sharding().tile_assignment().dim(i) == in MaybeImproveInstructionSharding()
98 sharding.tile_assignment().dim(i)) { in MaybeImproveInstructionSharding()
101 if (instruction->sharding().tile_assignment().dim(i) != 1) { in MaybeImproveInstructionSharding()
431 operand_sharding.tile_assignment().dim(indices_idx); in InferGatherParallelShardingFromOperands()
441 replicate_non_parallel_dims.tile_assignment().dimensions().back()); in InferGatherParallelShardingFromOperands()
443 auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment(); in InferGatherParallelShardingFromOperands()
496 partitions *= sharding.tile_assignment().dim(dim.lhs); in InferConvolutionShardingFromOperands()
499 partitions *= sharding.tile_assignment().dim(dim.rhs); in InferConvolutionShardingFromOperands()
545 const auto& tile_assignment = lhs->sharding().tile_assignment(); in InferConvolutionShardingFromOperands() local
546 if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) { in InferConvolutionShardingFromOperands()
[all …]
Dhlo_sharding_util_test.cc165 EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0}, {1}, {2}, {3}})); in TEST()
172 EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0, 2, 1, 3}})); in TEST()
181 result.tile_assignment(), in TEST()
190 EXPECT_EQ(result.tile_assignment(), in TEST()
199 EXPECT_EQ(result.tile_assignment(), in TEST()
214 EXPECT_EQ(result.tile_assignment(), in TEST()
Dhlo_sharding.h61 static HloSharding Tile(const Array<int64>& tile_assignment,
63 return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
262 const Array<int64>& tile_assignment() const { return tile_assignment_; } in tile_assignment() function
315 explicit HloSharding(const Array<int64>& tile_assignment,
322 tile_assignment_(tile_assignment), in replicated_()
Dhlo_sharding.cc255 index.size() < tile_assignment().num_dimensions()) { in DeviceForTileIndex()
518 Array<int64> tile_assignment( in FromProto() local
522 proto.tile_assignment_devices().end(), tile_assignment.begin()); in FromProto()
524 ? PartialTile(tile_assignment, metadata) in FromProto()
525 : HloSharding(tile_assignment, in FromProto()
604 return tile_assignment().num_elements() / in NumTiles()
605 tile_assignment().dimensions().back(); in NumTiles()
607 return tile_assignment().num_elements(); in NumTiles()
616 !absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1)); in NumTiles()
619 CHECK(d < tile_assignment().num_dimensions()); in NumTiles()
[all …]
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
Dxla_sharding.py75 def tile(cls, tile_assignment): argument
92 if not isinstance(tile_assignment, _np.ndarray):
94 dims = list(tile_assignment.shape)
95 flattened_devices = tile_assignment.reshape(-1, order='C')
103 def partial_tile(cls, tile_assignment): argument
117 if not isinstance(tile_assignment, _np.ndarray):
119 dims = list(tile_assignment.shape)
120 flattened_devices = tile_assignment.reshape(-1, order='C')
280 tile_assignment, argument
292 return Sharding.tile(tile_assignment).apply_to_tensor(
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner_util.cc141 if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { in EvenlyPartitions()
212 offset_arrays[i].resize(sharding.tile_assignment().num_elements()); in MakePartitionOffsets()
215 sharding.tile_assignment().Each( in MakePartitionOffsets()
223 if (sharding.tile_assignment().dim(i) == 1 || in MakePartitionOffsets()
242 auto dimensions = sharding.tile_assignment().dimensions(); in MakeTiledPartitionOrdinals()
288 i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i)); in GetPaddedShapeForUnevenPartitioning()
308 int64 rank = partial_sharding.tile_assignment().num_dimensions() - 1; in PartialReplicateReshardCompatibleSharding()
309 int64 target_rank = target_sharding.tile_assignment().num_dimensions() - in PartialReplicateReshardCompatibleSharding()
316 partial_sharding.tile_assignment().Each( in PartialReplicateReshardCompatibleSharding()
320 gid *= partial_sharding.tile_assignment().dim(i); in PartialReplicateReshardCompatibleSharding()
[all …]
Dspmd_partitioner.cc196 group_size *= sharding.tile_assignment().dim(i); in GetPartitionGroupsForReplication()
199 sharding.tile_assignment().num_elements() / group_size); in GetPartitionGroupsForReplication()
200 sharding.tile_assignment().Each( in GetPartitionGroupsForReplication()
205 group_id *= sharding.tile_assignment().dim(i); in GetPartitionGroupsForReplication()
456 std::vector<int64> group_dims(target.tile_assignment().num_dimensions() - in ReshardNoCache()
505 index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) - in PadWithValue()
520 if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 || in PadWithValue()
584 int64 shard_count = target.tile_assignment().dim(i); in ReshardAsWindowedInput()
745 if (target.tile_assignment().dim(i) == 1) { in ReshardAsWindowedInput()
783 int64 shard_count = target.tile_assignment().dim(dim); in ReshardAsWindowedInput()
[all …]
Dconvolution_handler.cc262 return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
264 rhs_sharding.tile_assignment().dim( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
285 (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) > in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
287 rhs.sharding().tile_assignment().dim( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
293 (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) > in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
295 rhs.sharding().tile_assignment().dim( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
325 int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
571 return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
573 rhs_sharding.tile_assignment().dim( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
595 (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) > in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
[all …]
Dgather_scatter_handler.cc44 operand.sharding().tile_assignment().dim(dim); in GatherScatterOperandPartitionedOnlyOnTrivialSliceDims()
64 int64 partitions = operand.sharding().tile_assignment().dim(dim); in IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims()
143 indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) == in PartitionIndexOnlyPartition()
195 if (operand.sharding().tile_assignment().dim(i) > 1) { in ParititonPassthroughOperand()
289 operand.sharding().tile_assignment().num_dimensions() - 1); in ParititonTrivialIndexedOperandDimension()
354 indices_sharding.tile_assignment().dim(indices_idx); in PartitionIndexParallelDimensions()
360 indices_sharding.tile_assignment().dimensions().back()); in PartitionIndexParallelDimensions()
362 Array<int64> output_tile_assignment = indices_sharding.tile_assignment(); in PartitionIndexParallelDimensions()
534 indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) == in HandleScatter()
583 {indices.sharding().tile_assignment().num_dimensions() - 1}); in HandleScatter()
Ddot_handler.cc330 for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { in FirstShardingDimWithPartitionOfSize()
331 if (sharding.tile_assignment().dim(i) == num_partitions) { in FirstShardingDimWithPartitionOfSize()
717 CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()), in PartitionBaseCase()
720 for (int64 i = 0; i < slice_sharding->tile_assignment().num_dimensions(); in PartitionBaseCase()
722 if (slice_sharding->tile_assignment().dim(i) > 1) { in PartitionBaseCase()
1675 : lhs.sharding().tile_assignment().dimensions(); in PartitionDotGroupOnBatch()
1679 : rhs.sharding().tile_assignment().dimensions(); in PartitionDotGroupOnBatch()
1681 output_sharding.tile_assignment().dimensions(); in PartitionDotGroupOnBatch()
1692 output_sharding.tile_assignment().dim(dim.output); in PartitionDotGroupOnBatch()
1694 output_sharding.tile_assignment().dim(dim.output); in PartitionDotGroupOnBatch()
[all …]
Dfft_handler.cc285 sharding.tile_assignment().Each( in GetFinalFftUsingCollectivePermute()
289 int64 dst_device = sharding.tile_assignment()(target_indices); in GetFinalFftUsingCollectivePermute()
365 hlo->sharding().tile_assignment().dimensions().back() != in HandleFft()
/external/tensorflow/tensorflow/compiler/xla/client/
Dsharding_builder.cc42 const TileAssignment& tile_assignment) { in Tile() argument
46 for (int64 dim : tile_assignment.dimensions()) { in Tile()
49 for (uint32 device : tile_assignment) { in Tile()
Dsharding_builder.h48 OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment);
/external/tensorflow/tensorflow/python/tpu/
Dtpu_feed.py103 tile_assignment = np.arange(np.prod(dims)).reshape(dims)
106 tile_assignment=tile_assignment,
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_helpers.cc155 int64 device = *sharding->tile_assignment().begin(); in RewriteLayoutWithShardedShape()
Dxla_compiler_test.cc1765 xla::Array<int64> tile_assignment({2}); in TEST_F() local
1766 tile_assignment.FillIota(0); in TEST_F()
1767 xla::HloSharding sharding = xla::HloSharding::Tile(tile_assignment); in TEST_F()
/external/tensorflow/tensorflow/python/distribute/
Dtpu_strategy.py546 tile_assignment = np.arange(num_partition_splits).reshape(
548 return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
/external/tensorflow/tensorflow/compiler/xla/
Dxla_data.proto638 // None of the above; tile_shape and tile_assignment are both used.
663 // dimensions of tile_assignment(), but replicated across devices along the