Searched refs:tile_assignment_dims (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
xla_sharding.py
155 tile_assignment_dims = [1] * len(shape)
156 tile_assignment_dims[split_dimension] = num_devices
161 tile_assignment_dimensions=tile_assignment_dims,
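The hits at lines 155-161 build a tile assignment that splits a single tensor dimension across devices. Below is a minimal Python sketch of that computation, under stated assumptions. `split_tile_assignment` is a hypothetical helper, and `rank` stands in for `len(shape)`. The device order 0..num_devices-1 is also an assumption, because the line that sets the devices is not among the hits.

```python
def split_tile_assignment(rank, split_dimension, num_devices):
    """Hypothetical helper: tile counts that split one dimension across devices.

    Every dimension gets a single tile except split_dimension, which gets
    num_devices tiles (mirrors lines 155-156). Device ids 0..num_devices-1 in
    tile order are an assumption, not taken from the search hits.
    """
    if not 0 <= split_dimension < rank:
        raise ValueError("split_dimension out of range")
    tile_assignment_dims = [1] * rank
    tile_assignment_dims[split_dimension] = num_devices
    tile_assignment_devices = list(range(num_devices))
    return tile_assignment_dims, tile_assignment_devices


# Split dimension 1 of a rank-3 tensor four ways.
print(split_tile_assignment(rank=3, split_dimension=1, num_devices=4))
# ([1, 4, 1], [0, 1, 2, 3])
```

Python list indexing would silently accept a negative `split_dimension`, so the sketch checks the range explicitly.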
/external/tensorflow/tensorflow/compiler/xla/service/
hlo_sharding_util.cc
594 std::vector<int64> tile_assignment_dims(hlo.shape().rank()); in GatherEffectiveOutputSharding() local
598 tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i); in GatherEffectiveOutputSharding()
601 tile_assignment_dims[i] = 1; in GatherEffectiveOutputSharding()
758 std::vector<int64> tile_assignment_dims(data_rank, 1LL); in ScatterEffectiveDataSharding() local
763 tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i); in ScatterEffectiveDataSharding()
789 data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims); in ScatterEffectiveDataSharding()
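The hits in GatherEffectiveOutputSharding() (lines 594-601) and ScatterEffectiveDataSharding() (lines 758-789) share a pattern. Each builds a per-dimension tile count that copies some dimensions from an existing sharding and sets the rest to 1. Line 789 then passes such counts to `Slice(slice_starts, tile_assignment_dims)` to take a sub-block of the device grid.

The Python sketch below models that pattern and rests on several assumptions:
- `effective_tile_dims`, `slice_tile_assignment` and `kept_dims` are hypothetical names.
- Which dimensions are kept is not visible in the hits, so `kept_dims` is supplied by the caller.
- `slice_starts` is assumed to be all zeros, with the second argument acting as exclusive limits.
- The tile assignment is assumed to be stored row-major.

```python
from itertools import product
from math import prod


def effective_tile_dims(tile_dims, kept_dims):
    """Keep the tile count for dimensions in kept_dims and use 1 elsewhere.

    kept_dims is a hypothetical stand-in for the per-dimension test applied
    by the XLA functions; that test is not shown in the search hits.
    """
    return [d if i in kept_dims else 1 for i, d in enumerate(tile_dims)]


def slice_tile_assignment(devices, dims, limits):
    """Return the device ids in the block [0, limits) of a row-major tile grid.

    devices: flat row-major list of device ids for a grid of shape dims.
    limits: exclusive upper bound per dimension (all starts assumed to be 0).
    """
    if len(dims) != len(limits) or len(devices) != prod(dims):
        raise ValueError("shape mismatch")
    if any(not 0 < lim <= d for lim, d in zip(limits, dims)):
        raise ValueError("limits must lie within the tile grid")
    # Row-major strides: the last dimension varies fastest.
    strides = [1] * len(dims)
    for i in range(len(dims) - 2, -1, -1):
        strides[i] = strides[i + 1] * dims[i + 1]
    # Visit every index in the block and map it back to the flat device list.
    return [
        devices[sum(idx * s for idx, s in zip(index, strides))]
        for index in product(*(range(lim) for lim in limits))
    ]


# A 2x2 tile grid over devices 0..3; only dimension 0 keeps its tiling.
dims = [2, 2]
limits = effective_tile_dims(dims, kept_dims={0})
print(limits)                                             # [2, 1]
print(slice_tile_assignment([0, 1, 2, 3], dims, limits))  # [0, 2]
```

Because the starts are zero, every collapsed dimension keeps only its first device index. In the example, the column of devices 0 and 2 survives and devices 1 and 3 drop out.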