Home
last modified time | relevance | path

Searched refs:HloSharding (Results 1 – 25 of 43) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding_test.cc57 HloSharding sharding = HloSharding::Replicate(); in TEST_F()
63 HloSharding other = HloSharding::Replicate(); in TEST_F()
72 HloSharding sharding = HloSharding::AssignDevice(5); in TEST_F()
79 HloSharding other = HloSharding::Replicate(); in TEST_F()
87 ShapeTree<HloSharding> shape_tree = in TEST_F()
109 HloSharding sharding = HloSharding::FromProto(proto).ConsumeValueOrDie(); in TEST_F()
116 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); in TEST_F()
123 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); in TEST_F()
131 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); in TEST_F()
155 HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), in TEST_F()
[all …]
Dhlo_sharding_util_test.cc25 EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}), in TEST()
26 HloSharding::Replicate()); in TEST()
30 HloSharding input = HloSharding::Tile(Array4D<int64>({{{{0, 1}}, {{2, 3}}}})); in TEST()
31 HloSharding output = in TEST()
32 HloSharding::Tile(Array4D<int64>({{{{0}, {2}}}, {{{1}, {3}}}})); in TEST()
39 HloSharding sharding = HloSharding::AssignDevice(7); in TEST()
40 absl::optional<HloSharding> result = in TEST()
49 HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}})); in TEST()
50 absl::optional<HloSharding> result = in TEST()
58 HloSharding input_sharding = in TEST()
[all …]
Dhlo_sharding_util.h45 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs);
49 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
81 HloSharding TransposeSharding(const HloSharding& sharding,
89 absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
91 const HloSharding& sharding);
96 HloSharding ReverseSharding(const HloSharding& sharding,
103 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
112 HloSharding GatherOutputSharding(const HloSharding& index_sharding,
117 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
124 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);
[all …]
Dhlo_sharding.h41 class HloSharding {
45 static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {
46 return HloSharding(/*manual=*/false, /*replicated=*/true, metadata);
50 static HloSharding Manual(absl::Span<const OpMetadata> metadata = {}) {
51 return HloSharding(/*manual=*/true, /*replicated=*/false, metadata);
56 static HloSharding AssignDevice(int64 device_id,
61 static HloSharding Tile(const Array<int64>& tile_assignment,
63 return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
70 static HloSharding PartialTile(
78 static HloSharding PartialTile(
[all …]
Dhlo_sharding.cc31 HloSharding HloSharding::AssignDevice(int64 device_id, in AssignDevice()
33 return HloSharding(device_id, metadata); in AssignDevice()
36 HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles, in Tile1D()
43 return HloSharding(assignment, /*replicate_on_last_tile_dim=*/false, in Tile1D()
47 HloSharding HloSharding::PartialTile( in PartialTile()
67 HloSharding HloSharding::PartialTile( in PartialTile()
80 return HloSharding(fully_tiled, /*replicate_on_last_tile_dim=*/false, in PartialTile()
104 return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true, in PartialTile()
108 HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) { in Tuple()
109 std::vector<HloSharding> flattened_list; in Tuple()
[all …]
Dhlo_sharding_metadata.cc53 const HloSharding& sharding) { in SetSingleSharding()
58 bool ShardingMatches(const HloSharding& sharding1, in ShardingMatches()
59 const HloSharding& sharding2) { in ShardingMatches()
120 const HloSharding& sharding) { in FixupPassThroughDomainLinks()
142 std::shared_ptr<const HloSharding> CloneShardingForDomain( in CloneShardingForDomain()
143 std::shared_ptr<const HloSharding> sharding) { in CloneShardingForDomain()
148 return std::make_shared<const HloSharding>(*single_sharding); in CloneShardingForDomain()
152 const HloSharding& sharding) { in ApplyDomainSingleSharding()
172 StatusOr<ShapeTree<HloSharding>> GetShardingTreeFromUser( in GetShardingTreeFromUser()
186 StatusOr<AssignmentKind> AssignLeafSharding(HloSharding* lhs, in AssignLeafSharding()
[all …]
Dhlo_sharding_util.cc37 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { in IsShardingMoreSpecific()
69 bool MergeSharding(const HloSharding& old, HloSharding* to_merge, in MergeSharding()
117 const HloSharding& sharding) { in MergeSharding()
188 *to_merge = HloSharding::Tile(new_tile, merged_metadata); in MergeSharding()
190 *to_merge = HloSharding::PartialTile(new_tile, merged_metadata); in MergeSharding()
268 HloSharding TransposeSharding(const HloSharding& sharding, in TransposeSharding()
293 ? HloSharding::PartialTile(tile_assignment, sharding.metadata()) in TransposeSharding()
294 : HloSharding::Tile(tile_assignment, sharding.metadata()); in TransposeSharding()
297 absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape, in ReshapeSharding()
299 const HloSharding& sharding) { in ReshapeSharding()
[all …]
Dhlo_domain_test.cc432 HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1), in TEST_F()
433 HloSharding::AssignDevice(0)})); in TEST_F()
472 EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1)); in TEST_F()
527 EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1), in TEST_F()
528 HloSharding::AssignDevice(0)}), in TEST_F()
665 EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1), in TEST_F()
666 HloSharding::AssignDevice(1), in TEST_F()
667 HloSharding::AssignDevice(0)}), in TEST_F()
671 EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1), in TEST_F()
672 HloSharding::AssignDevice(0)}), in TEST_F()
[all …]
Dsharding_propagation.cc58 bool IsSpatiallyPartitioned(const HloSharding& sharding) { in IsSpatiallyPartitioned()
74 bool MaybeImproveInstructionSharding(HloSharding sharding, in MaybeImproveInstructionSharding()
122 HloSharding::SingleTuple(instruction->shape(), HloSharding::Replicate())); in SetDefaultTupleSharding()
342 const HloSharding& operand_sharding = operand->sharding(); in InferDotShardingFromOperands()
410 const HloSharding& operand_sharding = operand->sharding(); in InferGatherParallelShardingFromOperands()
433 HloSharding replicate_non_parallel_dims = in InferGatherParallelShardingFromOperands()
446 ? HloSharding::PartialTile( in InferGatherParallelShardingFromOperands()
449 : HloSharding::Tile(output_tile_assignment, in InferGatherParallelShardingFromOperands()
538 HloSharding::Replicate(lhs->sharding().metadata()), instruction, in InferConvolutionShardingFromOperands()
555 return MaybeImproveInstructionSharding(HloSharding::Replicate(), instruction, in InferConvolutionShardingFromOperands()
[all …]
Dhlo_sharding_metadata.h30 explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding) in ShardingMetadata()
43 const HloSharding* sharding() const { return sharding_.get(); } in sharding()
60 std::shared_ptr<const HloSharding> sharding_;
77 std::shared_ptr<const HloSharding> sharding;
Dhlo_matchers_test.cc166 p1->set_sharding(HloSharding::AssignDevice(1)); in TEST_F()
174 auto sharding = HloSharding::Tuple( in TEST_F()
175 tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1), in TEST_F()
176 HloSharding::Replicate()}); in TEST_F()
181 ::testing::Not(op::Sharding(HloSharding::AssignDevice(1)))); in TEST_F()
184 ::testing::Not(op::Sharding(HloSharding::AssignDevice(0)))); in TEST_F()
185 EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1))); in TEST_F()
191 EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), in TEST_F()
197 EXPECT_THAT(Explain(p1.get(), op::Sharding(HloSharding::AssignDevice(0))), in TEST_F()
Dbatchnorm_expander.cc272 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormTraining()
273 HloSharding operand_sharding = in HandleBatchNormTraining()
276 HloSharding default_sharding = in HandleBatchNormTraining()
278 ? HloSharding::AssignDevice(unique_device.value()) in HandleBatchNormTraining()
279 : HloSharding::Replicate(); in HandleBatchNormTraining()
361 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormInference()
363 HloSharding default_sharding = in HandleBatchNormInference()
365 ? HloSharding::AssignDevice(unique_device.value()) in HandleBatchNormInference()
366 : HloSharding::Replicate(); in HandleBatchNormInference()
533 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormGrad()
[all …]
Dhlo_module.h336 const std::vector<HloSharding>& spmd_parameters_shardings() const { in spmd_parameters_shardings()
341 const std::vector<HloSharding>& shardings) { in set_spmd_parameters_shardings()
353 const HloSharding& spmd_output_sharding() const { in spmd_output_sharding()
357 void set_spmd_output_sharding(const HloSharding& sharding) { in set_spmd_output_sharding()
424 absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;
428 absl::optional<HloSharding> spmd_output_sharding_;
Dhlo_instruction.h1382 const HloSharding& sharding() const { in sharding()
1386 std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; } in sharding_ptr()
1389 const HloSharding& sharding_or_default(const HloSharding& default_) const { in sharding_or_default()
1401 void set_sharding(const HloSharding& sharding) { in set_sharding()
1402 sharding_ = std::make_shared<const HloSharding>(sharding); in set_sharding()
1404 void set_sharding(std::shared_ptr<const HloSharding> sharding) { in set_sharding()
1407 void set_single_sharding(const HloSharding& sharding);
1410 set_single_sharding(HloSharding::AssignDevice(device)); in set_device_sharding()
2098 std::shared_ptr<const HloSharding> sharding_;
Dhlo_parser.h53 StatusOr<HloSharding> ParseSharding(absl::string_view str);
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner_util.h35 HloSharding indices_sharding;
36 HloSharding operand_sharding;
40 bool HasReplicatedSharding(const HloSharding& sharding);
85 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);
89 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
98 const HloSharding& sharding,
106 const Shape& shape, const HloSharding& sharding,
112 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);
122 const HloSharding& sharding);
127 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b);
[all …]
Dspmd_partitioner.h199 const HloSharding& root_sharding,
211 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
218 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
233 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
237 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
271 std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
273 std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
307 PartitionedHlo Reshard(const HloSharding& target);
321 const HloSharding& sharding() const { return hlo_->sharding(); } in sharding()
331 const Window& window, const HloSharding& target,
[all …]
Dspmd_partitioner_util.cc47 bool HasReplicatedSharding(const HloSharding& sharding) { in HasReplicatedSharding()
127 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { in EvenlyPartitions()
148 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { in MakePartitionedShape()
167 const HloSharding& sharding, in MakeNonPaddedShapeForGivenPartition()
205 const Shape& shape, const HloSharding& sharding, in MakePartitionOffsets()
240 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { in MakeTiledPartitionOrdinals()
277 const HloSharding& sharding) { in GetPaddedShapeForUnevenPartitioning()
294 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { in PadBaseShapeBeforeUnevenTiledSharding()
303 absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding( in PartialReplicateReshardCompatibleSharding()
304 const HloSharding& partial_sharding, const HloSharding& target_sharding) { in PartialReplicateReshardCompatibleSharding()
[all …]
Dgather_scatter_handler.cc130 const HloSharding& output_sharding,
181 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in ParititonPassthroughOperand()
190 indices = indices.Reshard(HloSharding::Replicate()); in ParititonPassthroughOperand()
216 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in ParititonTrivialIndexedOperandDimension()
226 indices = indices.Reshard(HloSharding::Replicate()); in ParititonTrivialIndexedOperandDimension()
302 ar->set_sharding(HloSharding::Replicate()); in ParititonTrivialIndexedOperandDimension()
316 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in PartitionIndexParallelDimensions()
319 absl::InlinedVector<std::pair<HloInstruction*, HloSharding>, 2> in PartitionIndexParallelDimensions()
339 HloSharding indices_sharding = gather_sharding->indices_sharding; in PartitionIndexParallelDimensions()
340 HloSharding operand_sharding = gather_sharding->operand_sharding; in PartitionIndexParallelDimensions()
[all …]
Dspmd_partitioner.cc193 const HloSharding& sharding, absl::Span<const int64> replication_dims) { in GetPartitionGroupsForReplication()
345 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { in Reshard()
369 PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { in ReshardNoCache()
482 const HloSharding& sharding = hlo_->sharding(); in PadWithValue()
547 const HloSharding& target, in ReshardAsWindowedInput()
832 const HloSharding& sharding = hlo_->sharding(); in Replicate()
848 cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); in Replicate()
862 result->set_sharding(HloSharding::Replicate()); in Replicate()
911 const HloSharding& target) { in ReshardToPartialReplicateWithAllGather()
977 const HloSharding& target) { in ReshardFromPartialReplicateWithDynamicSlice()
[all …]
Ddot_handler.cc328 const HloSharding& sharding) { in FirstShardingDimWithPartitionOfSize()
391 const absl::optional<HloSharding>& output_sharding_transposed_to_match_lhs, in GetWindowedEinsumConfiguration()
392 const absl::optional<HloSharding>& output_sharding_transposed_to_match_rhs, in GetWindowedEinsumConfiguration()
393 const HloSharding& lhs_sharding, const HloSharding& rhs_sharding) { in GetWindowedEinsumConfiguration()
468 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, in PartitionBaseCase()
484 const HloSharding& lhs_sharding = lhs.sharding(); in PartitionBaseCase()
485 const HloSharding& rhs_sharding = rhs.sharding(); in PartitionBaseCase()
703 const HloSharding* slice_sharding; in PartitionBaseCase()
826 padded_slice_operand->set_sharding(HloSharding::Replicate()); in PartitionBaseCase()
1078 slice_operand->set_sharding(HloSharding::Replicate()); in PartitionBaseCase()
[all …]
Dconvolution_handler.cc43 const HloSharding& output_sharding, in PartitionConvolutionWithBatchGroupCount()
135 const HloSharding& output_sharding, in PartitionConvolutionWithFeatureGroupCount()
229 const HloSharding& output_sharding, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
258 auto unsupported_sharding = [&](const HloSharding& lhs_sharding, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
259 const HloSharding& rhs_sharding) { in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
505 ar->set_sharding(HloSharding::Replicate()); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
516 const HloSharding& output_sharding, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
569 auto unsupported_sharding = [&](const HloSharding& lhs_sharding, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
570 const HloSharding& rhs_sharding) { in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
732 ar->set_sharding(HloSharding::Replicate()); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
[all …]
Dconvolution_handler.h31 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
/external/tensorflow/tensorflow/core/tpu/kernels/
Dtpu_compile_op_support.cc35 using ::xla::HloSharding;
149 ShapeTree<HloSharding> GetSubtree( in GetSubtree()
150 const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) { in GetSubtree()
151 ShapeTree<HloSharding> element_shape_tree( in GetSubtree()
154 HloSharding::Replicate()); in GetSubtree()
162 Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding, in GetPerDeviceShape()
165 ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape); in GetPerDeviceShape()
169 HloSharding element_sharding = tuple_shape_tree.element({i}); in GetPerDeviceShape()
171 element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i)); in GetPerDeviceShape()
231 xla::HloSharding::FromProto(proto_arg.sharding()); in AddVariableUpdatesToCores()
[all …]
Dtpu_compile_op_support.h91 const xla::HloSharding& sharding,
115 xla::ShapeTree<xla::HloSharding> GetSubtree(
116 const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
120 const xla::HloSharding& sharding,

12