Home
last modified time | relevance | path

Searched refs:group_dim_sizes (Results 1 – 4 of 4) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner_util.h300 std::vector<int64> group_dim_sizes, int64 data_rank, in GroupedSharding()
304 group_dim_sizes(std::move(group_dim_sizes)), in GroupedSharding()
309 std::vector<int64> group_dim_sizes; member
355 absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes,
Dspmd_partitioner_util.cc1390 std::vector<int64> group_dim_sizes(group_dims.size()); in GroupShardingOnDims() local
1393 group_dim_sizes[i] = in GroupShardingOnDims()
1398 std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes)); in GroupShardingOnDims()
1412 std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(), in GroupShardingOnDims()
1460 tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i]; in UngroupSharding()
1469 int64 groups_in_this_dim = grouped_sharding.group_dim_sizes[i]; in UngroupSharding()
1564 int64 groups = grouped_sharding.group_dim_sizes[i]; in GetPerGroupBaseShape()
1696 absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes, in PerGroupSliceFromReplicated() argument
1713 group_level_tile_dims[group_dims[i]] = group_dim_sizes[i]; in PerGroupSliceFromReplicated()
Ddot_handler.cc1757 output_grouped.group_dim_sizes, b); in PartitionDotGroupOnBatch()
Dspmd_partitioner.cc462 target_grouped.group_dim_sizes, state_.b); in ReshardNoCache()