Searched refs:group_dim_sizes (Results 1 – 4 of 4) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | spmd_partitioner_util.h | 300 std::vector<int64> group_dim_sizes, int64 data_rank, in GroupedSharding() 304 group_dim_sizes(std::move(group_dim_sizes)), in GroupedSharding() 309 std::vector<int64> group_dim_sizes; member 355 absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes,
|
D | spmd_partitioner_util.cc | 1390 std::vector<int64> group_dim_sizes(group_dims.size()); in GroupShardingOnDims() local 1393 group_dim_sizes[i] = in GroupShardingOnDims() 1398 std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes)); in GroupShardingOnDims() 1412 std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(), in GroupShardingOnDims() 1460 tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i]; in UngroupSharding() 1469 int64 groups_in_this_dim = grouped_sharding.group_dim_sizes[i]; in UngroupSharding() 1564 int64 groups = grouped_sharding.group_dim_sizes[i]; in GetPerGroupBaseShape() 1696 absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes, in PerGroupSliceFromReplicated() argument 1713 group_level_tile_dims[group_dims[i]] = group_dim_sizes[i]; in PerGroupSliceFromReplicated()
|
D | dot_handler.cc | 1757 output_grouped.group_dim_sizes, b); in PartitionDotGroupOnBatch()
|
D | spmd_partitioner.cc | 462 target_grouped.group_dim_sizes, state_.b); in ReshardNoCache()
|