Home
last modified time | relevance | path

Searched refs:slice_offsets (Results 1 – 4 of 4) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dfft_handler.cc100 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(), zero_s32); in PadEachPartitionWithHaloExchange() local
103 slice_offsets[concat->shape().rank() - 1] = in PadEachPartitionWithHaloExchange()
107 slice_shape, concat, slice_offsets, slice_shape.dimensions())); in PadEachPartitionWithHaloExchange()
Dspmd_partitioner_util.cc492 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(), in TileToPartialReplicateHaloExchange() local
494 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( in TileToPartialReplicateHaloExchange()
497 slice_shape, concat, slice_offsets, slice_shape.dimensions())); in TileToPartialReplicateHaloExchange()
600 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(), in PadFromPartialReplicateShape() local
602 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( in PadFromPartialReplicateShape()
605 slice_shape, concat, slice_offsets, slice_shape.dimensions())); in PadFromPartialReplicateShape()
1004 std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero); in ExchangeHaloAndGetValidData() local
1005 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( in ExchangeHaloAndGetValidData()
1008 slice_shape, concat, slice_offsets, slice_shape.dimensions())); in ExchangeHaloAndGetValidData()
Ddot_handler.cc3218 std::vector<HloInstruction*> slice_offsets(padded_shape.rank()); in MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions() local
3219 for (int64 i = 0; i < slice_offsets.size(); ++i) { in MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions()
3220 slice_offsets[i] = dus->mutable_operand(i + 2); in MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions()
3226 padded, slice_offsets, dus->operand(1)->shape().dimensions())); in MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions()
3308 for (int64 i = 0; i < slice_offsets.size(); ++i) { in MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions()
3319 reduce_dus_offsets.push_back(slice_offsets[i]); in MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions()
3328 iota->shape(), slice_offsets[dim], {})))); in MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions()
Dspmd_partitioner.cc3339 std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero); in HandleSelectAndScatter() local
3348 slice_offsets[i] = left_halo_size; in HandleSelectAndScatter()
3355 slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary( in HandleSelectAndScatter()
3361 shard_shape, sharded_select_and_scatter, slice_offsets, in HandleSelectAndScatter()