Searched refs:ag_shape (Results 1 – 3 of 3) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner.h
141 SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
spmd_partitioner.cc
3472 SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape, in GetDefaultCollectiveOpsCreator()
3486 ag_shape, operand, all_gather_dimension, device_groups, in GetDefaultCollectiveOpsCreator()
3589 auto ag_shape = operand->shape(); in AllGatherShardsInternal() local
3591 ag_shape.set_dimensions( in AllGatherShardsInternal()
3592 i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i)); in AllGatherShardsInternal()
3594 result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result)); in AllGatherShardsInternal()
spmd_partitioner_util.cc
1657 SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape, in GetPerGroupCollectiveOpsCreator()
1661 b, operand, ag_shape, in GetPerGroupCollectiveOpsCreator()