/external/tensorflow/tensorflow/compiler/xla/service/
hlo_sharding_util.cc
      83  old.tile_assignment().num_elements() !=  in MergeSharding()
      84  to_merge->tile_assignment().num_elements()) {  in MergeSharding()
      88  int64 num_devices = old.tile_assignment().num_elements();  in MergeSharding()
      91  new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());  in MergeSharding()
      92  for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {  in MergeSharding()
      93  int64 new_dim = to_merge->tile_assignment().dim(i);  in MergeSharding()
      94  int64 old_dim = old.tile_assignment().dim(i);  in MergeSharding()
     108  replication >= old.tile_assignment().dimensions().back()) {  in MergeSharding()
     120  group_id *= to_merge->tile_assignment().dim(i);  in MergeSharding()
     125  old.tile_assignment().Each(  in MergeSharding()
     [all …]

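The MergeSharding() hits above walk two tile assignments element by element (lines 120 and 125). A tile assignment is an N-D array whose entry at each tile coordinate is the id of the device owning that tile, and a merge is only attempted when both assignments cover the same device count (lines 83-84). The sketch below is a standalone model of that structure; TileAssignment and its Each() visitor are hypothetical stand-ins for xla::Array<int64>, assuming row-major storage.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <vector>

    struct TileAssignment {
      std::vector<int64_t> dims;     // tiles per dimension
      std::vector<int64_t> devices;  // row-major: device id per tile

      int64_t num_elements() const { return static_cast<int64_t>(devices.size()); }
      int64_t num_dimensions() const { return static_cast<int64_t>(dims.size()); }
      int64_t dim(int64_t i) const { return dims[i]; }

      // Visit every (tile index, device) pair, like Array<int64>::Each.
      void Each(const std::function<void(const std::vector<int64_t>&, int64_t)>& fn) const {
        std::vector<int64_t> index(dims.size(), 0);
        for (int64_t flat = 0; flat < num_elements(); ++flat) {
          fn(index, devices[flat]);
          for (int64_t d = num_dimensions() - 1; d >= 0; --d) {
            if (++index[d] < dims[d]) break;  // odometer step over coordinates
            index[d] = 0;
          }
        }
      }
    };

    int main() {
      TileAssignment t{{2, 2}, {0, 1, 2, 3}};
      // A merge like MergeSharding() first requires both shardings to cover
      // the same device set, i.e. equal num_elements().
      t.Each([](const std::vector<int64_t>& idx, int64_t device) {
        std::cout << "tile (" << idx[0] << "," << idx[1] << ") -> device " << device << "\n";
      });
    }
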
sharding_propagation.cc
      97  if (instruction->sharding().tile_assignment().dim(i) ==  in MaybeImproveInstructionSharding()
      98  sharding.tile_assignment().dim(i)) {  in MaybeImproveInstructionSharding()
     101  if (instruction->sharding().tile_assignment().dim(i) != 1) {  in MaybeImproveInstructionSharding()
     431  operand_sharding.tile_assignment().dim(indices_idx);  in InferGatherParallelShardingFromOperands()
     441  replicate_non_parallel_dims.tile_assignment().dimensions().back());  in InferGatherParallelShardingFromOperands()
     443  auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment();  in InferGatherParallelShardingFromOperands()
     496  partitions *= sharding.tile_assignment().dim(dim.lhs);  in InferConvolutionShardingFromOperands()
     499  partitions *= sharding.tile_assignment().dim(dim.rhs);  in InferConvolutionShardingFromOperands()
     545  const auto& tile_assignment = lhs->sharding().tile_assignment();  in InferConvolutionShardingFromOperands() local
     546  if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {  in InferConvolutionShardingFromOperands()
     [all …]

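From the visible lines, MaybeImproveInstructionSharding() keeps a candidate sharding only where it agrees with the instruction's current tiling per dimension, or where the current sharding leaves the dimension unsplit (dim == 1). A minimal standalone predicate in that spirit; the name and signature are illustrative, not the XLA API.

    #include <cstdint>
    #include <vector>

    bool CandidateRefinesCurrent(const std::vector<int64_t>& current_dims,
                                 const std::vector<int64_t>& candidate_dims) {
      if (current_dims.size() != candidate_dims.size()) return false;
      for (size_t i = 0; i < current_dims.size(); ++i) {
        if (current_dims[i] == candidate_dims[i]) continue;  // same tiling: ok
        if (current_dims[i] != 1) return false;  // conflicting non-trivial tiling
      }
      return true;
    }
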
hlo_sharding_util_test.cc
     165  EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0}, {1}, {2}, {3}}));  in TEST()
     172  EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0, 2, 1, 3}}));  in TEST()
     181  result.tile_assignment(),  in TEST()
     190  EXPECT_EQ(result.tile_assignment(),  in TEST()
     199  EXPECT_EQ(result.tile_assignment(),  in TEST()
     214  EXPECT_EQ(result.tile_assignment(),  in TEST()

hlo_sharding.h
      61  static HloSharding Tile(const Array<int64>& tile_assignment,
      63  return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
     262  const Array<int64>& tile_assignment() const { return tile_assignment_; }  in tile_assignment() function
     315  explicit HloSharding(const Array<int64>& tile_assignment,
     322  tile_assignment_(tile_assignment),  in replicated_()

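The header lines sketch HloSharding's tiled form: a static Tile() factory that wraps an Array<int64> with replicate_on_last_tile_dim=false, a private constructor, and a tile_assignment() accessor. Below is a compressed standalone mock of that shape, with Array64, Sharding, and PartialTile() as assumed stand-ins; the real class carries many more variants and invariants.

    #include <cstdint>
    #include <utility>
    #include <vector>

    struct Array64 {  // stand-in for xla::Array<int64>
      std::vector<int64_t> dims;
      std::vector<int64_t> data;
    };

    class Sharding {
     public:
      // Factory for a fully tiled sharding: each tile maps to exactly one device.
      static Sharding Tile(Array64 tile_assignment) {
        return Sharding(std::move(tile_assignment),
                        /*replicate_on_last_tile_dim=*/false);
      }
      // Factory for partial tiling: the last dimension of the assignment groups
      // devices that hold replicas of the same tile.
      static Sharding PartialTile(Array64 tile_assignment) {
        return Sharding(std::move(tile_assignment),
                        /*replicate_on_last_tile_dim=*/true);
      }
      const Array64& tile_assignment() const { return tile_assignment_; }
      bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; }

     private:
      Sharding(Array64 tile_assignment, bool replicate_on_last_tile_dim)
          : tile_assignment_(std::move(tile_assignment)),
            replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {}
      Array64 tile_assignment_;
      bool replicate_on_last_tile_dim_;
    };
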
hlo_sharding.cc
     255  index.size() < tile_assignment().num_dimensions()) {  in DeviceForTileIndex()
     518  Array<int64> tile_assignment(  in FromProto() local
     522  proto.tile_assignment_devices().end(), tile_assignment.begin());  in FromProto()
     524  ? PartialTile(tile_assignment, metadata)  in FromProto()
     525  : HloSharding(tile_assignment,  in FromProto()
     604  return tile_assignment().num_elements() /  in NumTiles()
     605  tile_assignment().dimensions().back();  in NumTiles()
     607  return tile_assignment().num_elements();  in NumTiles()
     616  !absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1));  in NumTiles()
     619  CHECK(d < tile_assignment().num_dimensions());  in NumTiles()
     [all …]

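Lines 604-607 show the NumTiles() arithmetic: under partial replication the trailing axis of the tile assignment counts replicas, not tiles, so it is divided out. A runnable restatement, assuming the device count equals the product of the assignment's dimensions:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int64_t NumTiles(const std::vector<int64_t>& assignment_dims,
                     bool replicate_on_last_tile_dim) {
      int64_t num_elements =
          std::accumulate(assignment_dims.begin(), assignment_dims.end(),
                          int64_t{1}, std::multiplies<int64_t>());
      if (replicate_on_last_tile_dim) {
        return num_elements / assignment_dims.back();  // drop the replica axis
      }
      return num_elements;
    }

    int main() {
      // A [2,2] assignment over 4 devices: 4 tiles if fully tiled, 2 tiles if
      // the trailing axis of size 2 is a replication group.
      std::cout << NumTiles({2, 2}, false) << " " << NumTiles({2, 2}, true) << "\n";
    }
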
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
xla_sharding.py
      75  def tile(cls, tile_assignment):  argument
      92  if not isinstance(tile_assignment, _np.ndarray):
      94  dims = list(tile_assignment.shape)
      95  flattened_devices = tile_assignment.reshape(-1, order='C')
     103  def partial_tile(cls, tile_assignment):  argument
     117  if not isinstance(tile_assignment, _np.ndarray):
     119  dims = list(tile_assignment.shape)
     120  flattened_devices = tile_assignment.reshape(-1, order='C')
     280  tile_assignment,  argument
     292  return Sharding.tile(tile_assignment).apply_to_tensor(
     [all …]

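Both tile() and partial_tile() record the ndarray's shape and flatten the device ids in C (row-major) order via reshape(-1, order='C'). The helper below shows the same row-major linearization of a tile coordinate; the function name is illustrative only.

    #include <cstdint>
    #include <vector>

    int64_t FlatIndexRowMajor(const std::vector<int64_t>& dims,
                              const std::vector<int64_t>& index) {
      int64_t flat = 0;
      for (size_t d = 0; d < dims.size(); ++d) {
        flat = flat * dims[d] + index[d];  // last dimension varies fastest
      }
      return flat;
    }
    // e.g. dims {2, 3}: index {1, 2} -> 1 * 3 + 2 = 5, the last flat position.
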
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner_util.cc
     141  if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {  in EvenlyPartitions()
     212  offset_arrays[i].resize(sharding.tile_assignment().num_elements());  in MakePartitionOffsets()
     215  sharding.tile_assignment().Each(  in MakePartitionOffsets()
     223  if (sharding.tile_assignment().dim(i) == 1 ||  in MakePartitionOffsets()
     242  auto dimensions = sharding.tile_assignment().dimensions();  in MakeTiledPartitionOrdinals()
     288  i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));  in GetPaddedShapeForUnevenPartitioning()
     308  int64 rank = partial_sharding.tile_assignment().num_dimensions() - 1;  in PartialReplicateReshardCompatibleSharding()
     309  int64 target_rank = target_sharding.tile_assignment().num_dimensions() -  in PartialReplicateReshardCompatibleSharding()
     316  partial_sharding.tile_assignment().Each(  in PartialReplicateReshardCompatibleSharding()
     320  gid *= partial_sharding.tile_assignment().dim(i);  in PartialReplicateReshardCompatibleSharding()
     [all …]

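The MakePartitionOffsets() and GetPaddedShapeForUnevenPartitioning() hits suggest the usual slicing arithmetic: along a dimension split into k tiles, each shard spans ceil(size / k) elements and starts at tile_index times that span, with trailing shards padded when the split is uneven (EvenlyPartitions() at line 141 tests exactly that divisibility). A simplified standalone sketch; the real code emits HLO that evaluates this per partition id at runtime.

    #include <cstdint>
    #include <iostream>

    int64_t ShardSize(int64_t dim_size, int64_t tiles) {
      return (dim_size + tiles - 1) / tiles;  // ceil division
    }

    int64_t ShardOffset(int64_t dim_size, int64_t tiles, int64_t tile_index) {
      return tile_index * ShardSize(dim_size, tiles);
    }

    int main() {
      // Dimension of size 10 split into 4 tiles: shards of 3, offsets 0 3 6 9;
      // EvenlyPartitions() would be false here since 10 % 4 != 0.
      for (int64_t t = 0; t < 4; ++t) std::cout << ShardOffset(10, 4, t) << " ";
      std::cout << "\n";
    }
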
spmd_partitioner.cc
     196  group_size *= sharding.tile_assignment().dim(i);  in GetPartitionGroupsForReplication()
     199  sharding.tile_assignment().num_elements() / group_size);  in GetPartitionGroupsForReplication()
     200  sharding.tile_assignment().Each(  in GetPartitionGroupsForReplication()
     205  group_id *= sharding.tile_assignment().dim(i);  in GetPartitionGroupsForReplication()
     456  std::vector<int64> group_dims(target.tile_assignment().num_dimensions() -  in ReshardNoCache()
     505  index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -  in PadWithValue()
     520  if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||  in PadWithValue()
     584  int64 shard_count = target.tile_assignment().dim(i);  in ReshardAsWindowedInput()
     745  if (target.tile_assignment().dim(i) == 1) {  in ReshardAsWindowedInput()
     783  int64 shard_count = target.tile_assignment().dim(dim);  in ReshardAsWindowedInput()
     [all …]

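Lines 196-205 outline GetPartitionGroupsForReplication(): the product of the chosen replication dims gives the group size, and each device's group id is the row-major index over the remaining dims, so devices differing only along replication dims land in the same group. A standalone version under those assumptions (names hypothetical):

    #include <cstdint>
    #include <set>
    #include <vector>

    std::vector<std::vector<int64_t>> PartitionGroups(
        const std::vector<int64_t>& dims,     // tile assignment dimensions
        const std::vector<int64_t>& devices,  // row-major device ids
        const std::set<int64_t>& replication_dims) {
      int64_t group_size = 1;
      for (int64_t d : replication_dims) group_size *= dims[d];
      std::vector<std::vector<int64_t>> groups(devices.size() / group_size);

      std::vector<int64_t> index(dims.size(), 0);
      for (int64_t device : devices) {
        int64_t group_id = 0;
        for (size_t i = 0; i < dims.size(); ++i) {
          if (replication_dims.count(static_cast<int64_t>(i))) continue;
          group_id = group_id * dims[i] + index[i];  // row-major over kept dims
        }
        groups[group_id].push_back(device);
        for (int64_t d = static_cast<int64_t>(dims.size()) - 1; d >= 0; --d) {
          if (++index[d] < dims[d]) break;  // odometer over tile coordinates
          index[d] = 0;
        }
      }
      return groups;
    }
    // dims {2,2}, devices {0,1,2,3}, replication_dims {1} -> groups {0,1}, {2,3}.
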
convolution_handler.cc
     262  return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
     264  rhs_sharding.tile_assignment().dim(  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
     285  (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
     287  rhs.sharding().tile_assignment().dim(  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
     293  (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
     295  rhs.sharding().tile_assignment().dim(  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
     325  int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
     571  return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
     573  rhs_sharding.tile_assignment().dim(  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
     595  (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
     [all …]

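The checks at lines 262 and 571 compare the partition count of an lhs dimension against that of an rhs dimension whose identity is truncated above, so the exact pairing is not recoverable from this listing. A deliberately generic predicate capturing the visible shape of the check; which dimensions to pair is left to the caller:

    #include <cstdint>
    #include <vector>

    // True when the lhs and rhs shardings place the same number of partitions
    // on a given pair of dimensions (e.g. a batch dim against its counterpart).
    bool SamePartitionCount(const std::vector<int64_t>& lhs_tiles, int64_t lhs_dim,
                            const std::vector<int64_t>& rhs_tiles, int64_t rhs_dim) {
      return lhs_tiles[lhs_dim] == rhs_tiles[rhs_dim];
    }
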
gather_scatter_handler.cc
      44  operand.sharding().tile_assignment().dim(dim);  in GatherScatterOperandPartitionedOnlyOnTrivialSliceDims()
      64  int64 partitions = operand.sharding().tile_assignment().dim(dim);  in IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims()
     143  indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) ==  in PartitionIndexOnlyPartition()
     195  if (operand.sharding().tile_assignment().dim(i) > 1) {  in ParititonPassthroughOperand()
     289  operand.sharding().tile_assignment().num_dimensions() - 1);  in ParititonTrivialIndexedOperandDimension()
     354  indices_sharding.tile_assignment().dim(indices_idx);  in PartitionIndexParallelDimensions()
     360  indices_sharding.tile_assignment().dimensions().back());  in PartitionIndexParallelDimensions()
     362  Array<int64> output_tile_assignment = indices_sharding.tile_assignment();  in PartitionIndexParallelDimensions()
     534  indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) ==  in HandleScatter()
     583  {indices.sharding().tile_assignment().num_dimensions() - 1});  in HandleScatter()

dot_handler.cc
     330  for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {  in FirstShardingDimWithPartitionOfSize()
     331  if (sharding.tile_assignment().dim(i) == num_partitions) {  in FirstShardingDimWithPartitionOfSize()
     717  CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),  in PartitionBaseCase()
     720  for (int64 i = 0; i < slice_sharding->tile_assignment().num_dimensions();  in PartitionBaseCase()
     722  if (slice_sharding->tile_assignment().dim(i) > 1) {  in PartitionBaseCase()
    1675  : lhs.sharding().tile_assignment().dimensions();  in PartitionDotGroupOnBatch()
    1679  : rhs.sharding().tile_assignment().dimensions();  in PartitionDotGroupOnBatch()
    1681  output_sharding.tile_assignment().dimensions();  in PartitionDotGroupOnBatch()
    1692  output_sharding.tile_assignment().dim(dim.output);  in PartitionDotGroupOnBatch()
    1694  output_sharding.tile_assignment().dim(dim.output);  in PartitionDotGroupOnBatch()
     [all …]

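Lines 330-331 are essentially the whole of FirstShardingDimWithPartitionOfSize(): scan the tile assignment for the first dimension split exactly num_partitions ways. A standalone equivalent, returning -1 where the real helper presumably signals absence differently:

    #include <cstdint>
    #include <vector>

    int64_t FirstDimWithPartitionOfSize(const std::vector<int64_t>& tile_dims,
                                        int64_t num_partitions) {
      for (size_t i = 0; i < tile_dims.size(); ++i) {
        if (tile_dims[i] == num_partitions) return static_cast<int64_t>(i);
      }
      return -1;  // no dimension is split num_partitions ways
    }
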
fft_handler.cc
     285  sharding.tile_assignment().Each(  in GetFinalFftUsingCollectivePermute()
     289  int64 dst_device = sharding.tile_assignment()(target_indices);  in GetFinalFftUsingCollectivePermute()
     365  hlo->sharding().tile_assignment().dimensions().back() !=  in HandleFft()

/external/tensorflow/tensorflow/compiler/xla/client/
sharding_builder.cc
      42  const TileAssignment& tile_assignment) {  in Tile() argument
      46  for (int64 dim : tile_assignment.dimensions()) {  in Tile()
      49  for (uint32 device : tile_assignment) {  in Tile()

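The builder's Tile() serializes a tile assignment into the OpSharding proto in two passes: dimensions first (line 46), then the flattened device ids (line 49). A sketch with a plain struct standing in for the proto message; the field names mirror the real proto, the rest is illustrative:

    #include <cstdint>
    #include <vector>

    struct OpShardingLike {  // stand-in for the xla.OpSharding proto message
      std::vector<int64_t> tile_assignment_dimensions;
      std::vector<int64_t> tile_assignment_devices;
    };

    OpShardingLike BuildTileSharding(const std::vector<int64_t>& dims,
                                     const std::vector<int64_t>& flat_devices) {
      OpShardingLike result;
      for (int64_t dim : dims) result.tile_assignment_dimensions.push_back(dim);
      for (int64_t device : flat_devices) result.tile_assignment_devices.push_back(device);
      return result;
    }
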
sharding_builder.h
      48  OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment);

/external/tensorflow/tensorflow/python/tpu/
tpu_feed.py
     103  tile_assignment = np.arange(np.prod(dims)).reshape(dims)
     106  tile_assignment=tile_assignment,

/external/tensorflow/tensorflow/compiler/tf2xla/
xla_helpers.cc
     155  int64 device = *sharding->tile_assignment().begin();  in RewriteLayoutWithShardedShape()

xla_compiler_test.cc
    1765  xla::Array<int64> tile_assignment({2});  in TEST_F() local
    1766  tile_assignment.FillIota(0);  in TEST_F()
    1767  xla::HloSharding sharding = xla::HloSharding::Tile(tile_assignment);  in TEST_F()

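This test builds the canonical assignment the same way the Python helpers above do with np.arange(np.prod(dims)).reshape(dims): device ids 0..N-1 in row-major order, via Array<int64>::FillIota here. A standalone restatement; the function name is illustrative:

    #include <cstdint>
    #include <numeric>
    #include <vector>

    std::vector<int64_t> IotaTileAssignment(const std::vector<int64_t>& dims) {
      int64_t n = 1;
      for (int64_t d : dims) n *= d;
      std::vector<int64_t> devices(n);
      std::iota(devices.begin(), devices.end(), int64_t{0});  // 0, 1, ..., n-1
      return devices;  // interpret with `dims` as the row-major shape
    }
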
/external/tensorflow/tensorflow/python/distribute/
tpu_strategy.py
     546  tile_assignment = np.arange(num_partition_splits).reshape(
     548  return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)

/external/tensorflow/tensorflow/compiler/xla/
xla_data.proto
     638  // None of the above; tile_shape and tile_assignment are both used.
     663  // dimensions of tile_assignment(), but replicated across devices along the

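The comment at line 663 describes partial replication: data is tiled along the leading dimensions of tile_assignment(), and the trailing dimension enumerates the devices holding replicas of the same tile. Under the row-major layout assumed throughout this listing, replicas of one tile are consecutive in the flat device list, which the hypothetical helper below exploits:

    #include <cstdint>
    #include <vector>

    std::vector<std::vector<int64_t>> ReplicaGroups(
        const std::vector<int64_t>& dims,      // last entry = replicas per tile
        const std::vector<int64_t>& devices) { // row-major device ids
      int64_t replicas = dims.back();
      std::vector<std::vector<int64_t>> groups;
      for (size_t i = 0; i < devices.size(); i += replicas) {
        // Consecutive row-major entries share the same leading tile coordinate.
        groups.emplace_back(devices.begin() + i, devices.begin() + i + replicas);
      }
      return groups;
    }
    // dims {2, 2}, devices {0, 1, 2, 3} -> replica groups {0, 1} and {2, 3}.
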