Searched refs:operand_sharding (Results 1 – 6 of 6) sorted by relevance
151 if (auto operand_sharding = GetXlaShardingFromOperand(operand)) { in IdentifyXlaShardingForComputationInputs() local152 sharding_for_args.push_back(operand_sharding.getValue()); in IdentifyXlaShardingForComputationInputs()154 builder->getStringAttr(operand_sharding.getValue())); in IdentifyXlaShardingForComputationInputs()
340 HloSharding operand_sharding = gather_sharding->operand_sharding; in PartitionIndexParallelDimensions() local344 GroupShardingOnDims(operand_sharding, operand_parallel_dims); in PartitionIndexParallelDimensions()356 operand = operand.Reshard(operand_sharding); in PartitionIndexParallelDimensions()376 operand.base_shape(), operand_sharding, operand.state().partition_id, in PartitionIndexParallelDimensions()425 if (operand_sharding.NumTiles() == in PartitionIndexParallelDimensions()426 operand_sharding.NumTiles(operand_parallel_dims) && in PartitionIndexParallelDimensions()
36 HloSharding operand_sharding; member
800 const Shape& operand_shape, const HloSharding& operand_sharding, in PassthroughOperandToGatherOutputOrScatterUpdate() argument806 if (operand_sharding.IsTileMaximal()) { in PassthroughOperandToGatherOutputOrScatterUpdate()807 return operand_sharding; in PassthroughOperandToGatherOutputOrScatterUpdate()812 int64 dim_partitions = operand_sharding.tile_assignment().dim(i); in PassthroughOperandToGatherOutputOrScatterUpdate()832 if (operand_sharding.ReplicateOnLastTileDim()) { in PassthroughOperandToGatherOutputOrScatterUpdate()834 operand_sharding.tile_assignment().dimensions().back()); in PassthroughOperandToGatherOutputOrScatterUpdate()836 Array<int64> tile_assignment = operand_sharding.tile_assignment(); in PassthroughOperandToGatherOutputOrScatterUpdate()838 return operand_sharding.ReplicateOnLastTileDim() in PassthroughOperandToGatherOutputOrScatterUpdate()840 operand_sharding.metadata()) in PassthroughOperandToGatherOutputOrScatterUpdate()841 : HloSharding::Tile(tile_assignment, operand_sharding.metadata()); in PassthroughOperandToGatherOutputOrScatterUpdate()
342 const HloSharding& operand_sharding = operand->sharding(); in InferDotShardingFromOperands() local343 if (operand_sharding.IsTileMaximal()) { in InferDotShardingFromOperands()344 return operand_sharding; in InferDotShardingFromOperands()363 operand_sharding, contracting_dims); in InferDotShardingFromOperands()410 const HloSharding& operand_sharding = operand->sharding(); in InferGatherParallelShardingFromOperands() local411 if (operand_sharding.IsTileMaximal()) { in InferGatherParallelShardingFromOperands()412 return operand_sharding; in InferGatherParallelShardingFromOperands()431 operand_sharding.tile_assignment().dim(indices_idx); in InferGatherParallelShardingFromOperands()435 operand_sharding, index_non_parallel_dims); in InferGatherParallelShardingFromOperands()
273 HloSharding operand_sharding = in HandleBatchNormTraining() local282 inst->set_sharding(operand_sharding); in HandleBatchNormTraining()