Home
last modified time | relevance | path

Searched refs:operand_sharding (Results 1 – 6 of 6) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
tpu_sharding_identification_pass.cc
151 if (auto operand_sharding = GetXlaShardingFromOperand(operand)) { in IdentifyXlaShardingForComputationInputs() local
152 sharding_for_args.push_back(operand_sharding.getValue()); in IdentifyXlaShardingForComputationInputs()
154 builder->getStringAttr(operand_sharding.getValue())); in IdentifyXlaShardingForComputationInputs()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
gather_scatter_handler.cc
340 HloSharding operand_sharding = gather_sharding->operand_sharding; in PartitionIndexParallelDimensions() local
344 GroupShardingOnDims(operand_sharding, operand_parallel_dims); in PartitionIndexParallelDimensions()
356 operand = operand.Reshard(operand_sharding); in PartitionIndexParallelDimensions()
376 operand.base_shape(), operand_sharding, operand.state().partition_id, in PartitionIndexParallelDimensions()
425 if (operand_sharding.NumTiles() == in PartitionIndexParallelDimensions()
426 operand_sharding.NumTiles(operand_parallel_dims) && in PartitionIndexParallelDimensions()
spmd_partitioner_util.h
36 HloSharding operand_sharding; member
/external/tensorflow/tensorflow/compiler/xla/service/
hlo_sharding_util.cc
800 const Shape& operand_shape, const HloSharding& operand_sharding, in PassthroughOperandToGatherOutputOrScatterUpdate() argument
806 if (operand_sharding.IsTileMaximal()) { in PassthroughOperandToGatherOutputOrScatterUpdate()
807 return operand_sharding; in PassthroughOperandToGatherOutputOrScatterUpdate()
812 int64 dim_partitions = operand_sharding.tile_assignment().dim(i); in PassthroughOperandToGatherOutputOrScatterUpdate()
832 if (operand_sharding.ReplicateOnLastTileDim()) { in PassthroughOperandToGatherOutputOrScatterUpdate()
834 operand_sharding.tile_assignment().dimensions().back()); in PassthroughOperandToGatherOutputOrScatterUpdate()
836 Array<int64> tile_assignment = operand_sharding.tile_assignment(); in PassthroughOperandToGatherOutputOrScatterUpdate()
838 return operand_sharding.ReplicateOnLastTileDim() in PassthroughOperandToGatherOutputOrScatterUpdate()
840 operand_sharding.metadata()) in PassthroughOperandToGatherOutputOrScatterUpdate()
841 : HloSharding::Tile(tile_assignment, operand_sharding.metadata()); in PassthroughOperandToGatherOutputOrScatterUpdate()
sharding_propagation.cc
342 const HloSharding& operand_sharding = operand->sharding(); in InferDotShardingFromOperands() local
343 if (operand_sharding.IsTileMaximal()) { in InferDotShardingFromOperands()
344 return operand_sharding; in InferDotShardingFromOperands()
363 operand_sharding, contracting_dims); in InferDotShardingFromOperands()
410 const HloSharding& operand_sharding = operand->sharding(); in InferGatherParallelShardingFromOperands() local
411 if (operand_sharding.IsTileMaximal()) { in InferGatherParallelShardingFromOperands()
412 return operand_sharding; in InferGatherParallelShardingFromOperands()
431 operand_sharding.tile_assignment().dim(indices_idx); in InferGatherParallelShardingFromOperands()
435 operand_sharding, index_non_parallel_dims); in InferGatherParallelShardingFromOperands()
batchnorm_expander.cc
273 HloSharding operand_sharding = in HandleBatchNormTraining() local
282 inst->set_sharding(operand_sharding); in HandleBatchNormTraining()