
Searched refs:sharding (Results 1 – 25 of 94) sorted by relevance


/external/tensorflow/tensorflow/compiler/xla/service/
hlo_sharding_test.cc
57 HloSharding sharding = HloSharding::Replicate(); in TEST_F() local
58 EXPECT_TRUE(sharding.IsReplicated()); in TEST_F()
59 EXPECT_TRUE(sharding.IsTileMaximal()); in TEST_F()
60 EXPECT_TRUE(sharding.UsesDevice(0)); in TEST_F()
61 EXPECT_TRUE(sharding.UsesDevice(65535)); in TEST_F()
64 EXPECT_EQ(other, sharding); in TEST_F()
66 EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}), in TEST_F()
68 EXPECT_FALSE(sharding.HasUniqueDevice()); in TEST_F()
72 HloSharding sharding = HloSharding::AssignDevice(5); in TEST_F() local
73 EXPECT_FALSE(sharding.IsReplicated()); in TEST_F()
[all …]
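The two constructors exercised above are the degenerate shardings. A minimal standalone sketch of the same calls, assuming XLA's hlo_sharding.h and TensorFlow's CHECK macros; the device id is invented:

#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/platform/logging.h"

void ShardingBasicsSketch() {
  // Replicated: every device holds the full tensor. Replicated shardings
  // are also "tile maximal" (no spatial tiling) and report using every
  // device id, so no unique device exists.
  xla::HloSharding replicated = xla::HloSharding::Replicate();
  CHECK(replicated.IsReplicated());
  CHECK(replicated.IsTileMaximal());
  CHECK(replicated.UsesDevice(0));
  CHECK(!replicated.HasUniqueDevice());

  // Maximal: the whole tensor lives on a single device (id 5 here), so
  // the sharding is tile maximal but not replicated.
  xla::HloSharding maximal = xla::HloSharding::AssignDevice(5);
  CHECK(!maximal.IsReplicated());
  CHECK(maximal.HasUniqueDevice());
}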
hlo_sharding_metadata.cc
53 const HloSharding& sharding) { in SetSingleSharding() argument
54 VLOG(4) << " " << instruction->name() << " to " << sharding; in SetSingleSharding()
55 instruction->set_single_sharding(sharding); in SetSingleSharding()
120 const HloSharding& sharding) { in FixupPassThroughDomainLinks() argument
127 gte->set_sharding(sharding); in FixupPassThroughDomainLinks()
143 std::shared_ptr<const HloSharding> sharding) { in CloneShardingForDomain() argument
144 auto single_sharding = sharding->ExtractSingleSharding(); in CloneShardingForDomain()
146 return sharding; in CloneShardingForDomain()
152 const HloSharding& sharding) { in ApplyDomainSingleSharding() argument
153 VLOG(4) << "Applying " << sharding << " sharding"; in ApplyDomainSingleSharding()
[all …]
sharding_propagation.cc
58 bool IsSpatiallyPartitioned(const HloSharding& sharding) { in IsSpatiallyPartitioned() argument
59 if (sharding.IsTuple()) { in IsSpatiallyPartitioned()
60 return absl::c_any_of(sharding.tuple_elements(), IsSpatiallyPartitioned); in IsSpatiallyPartitioned()
62 return !sharding.IsTileMaximal() || sharding.IsReplicated(); in IsSpatiallyPartitioned()
66 return hlo->has_sharding() && IsSpatiallyPartitioned(hlo->sharding()); in IsSpatiallyPartitioned()
74 bool MaybeImproveInstructionSharding(HloSharding sharding, in MaybeImproveInstructionSharding() argument
78 if (!IsSpatiallyPartitioned(sharding)) { in MaybeImproveInstructionSharding()
83 instruction->set_sharding(std::move(sharding)); in MaybeImproveInstructionSharding()
86 int64 sharding_tiles = sharding.NumTiles(); in MaybeImproveInstructionSharding()
87 if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding, in MaybeImproveInstructionSharding()
[all …]
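A condensed, hypothetical restatement of the predicate above (the function name is made up): tuple shardings recurse over their leaves, and a leaf counts as spatially partitioned only if it actually carries tiling.

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

// Hypothetical condensed form: a tuple sharding is spatially
// partitioned if any leaf is; a leaf is, if it is tiled rather than
// replicated or pinned to a single device (i.e. not tile maximal).
bool AnyLeafTiled(const xla::HloSharding& sharding) {
  if (sharding.IsTuple()) {
    return absl::c_any_of(sharding.tuple_elements(), AnyLeafTiled);
  }
  return !sharding.IsTileMaximal();
}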
sharding_propagation_test.cc
44 instruction->set_sharding(instruction->sharding().WithoutMetadata()); in ClearMetadata()
79 const HloSharding& sharding, in MatchAndExplain() argument
81 if (sharding.metadata().size() != metadata_.size()) { in MatchAndExplain()
82 *listener << sharding.ToString(/*include_metadata=*/true) in MatchAndExplain()
89 if (!protobuf_util::ProtobufEquals(sharding.metadata()[i], in MatchAndExplain()
91 *listener << sharding.ToString(/*include_metadata=*/true) in MatchAndExplain()
164 EXPECT_THAT(instruction->sharding(), in TEST_P()
167 EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); in TEST_P()
191 EXPECT_THAT(instruction->sharding(), in TEST_F()
214 EXPECT_THAT(instruction->sharding(), in TEST_F()
[all …]
hlo_sharding_util_test.cc
39 HloSharding sharding = HloSharding::AssignDevice(7); in TEST() local
41 ReshapeSharding(input_shape, output_shape, sharding); in TEST()
43 EXPECT_EQ(result.value(), sharding); in TEST()
49 HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}})); in TEST() local
51 ReshapeSharding(input_shape, output_shape, sharding); in TEST()
112 HloSharding sharding = HloSharding::Tile(sharding_array); in TEST() local
114 ReshapeSharding(input_shape, output_shape, sharding); in TEST()
116 EXPECT_EQ(result.value(), sharding); in TEST()
146 HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}})); in TEST() local
147 absl::optional<HloSharding> result = ReshapeSharding(shape, shape, sharding); in TEST()
[all …]
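What the first test above checks, as a self-contained sketch (the shape dimensions are invented): a tile-maximal sharding is unaffected by any reshape, so ReshapeSharding returns it unchanged.

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

void ReshapeShardingSketch() {
  xla::Shape input = xla::ShapeUtil::MakeShape(xla::F32, {4, 8});
  xla::Shape output = xla::ShapeUtil::MakeShape(xla::F32, {8, 4});
  // A maximal sharding pins the whole tensor to device 7; a reshape
  // cannot change that, so the sharding propagates through untouched.
  xla::HloSharding maximal = xla::HloSharding::AssignDevice(7);
  absl::optional<xla::HloSharding> result =
      xla::hlo_sharding_util::ReshapeSharding(input, output, maximal);
  CHECK(result.has_value());
  CHECK(result.value() == maximal);
}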
hlo_sharding_util.cc
117 const HloSharding& sharding) { in MergeSharding() argument
230 for (auto& it : instruction->sharding().UsedDevices(nullptr)) { in GetMostOccurringDevice()
247 for (auto& it : instruction->sharding().UsedDevices(&count)) { in GetDominantDevice()
268 HloSharding TransposeSharding(const HloSharding& sharding, in TransposeSharding() argument
270 if (sharding.IsTileMaximal()) { in TransposeSharding()
271 return sharding; in TransposeSharding()
274 if (sharding.ReplicateOnLastTileDim() && in TransposeSharding()
275 dimensions.size() < sharding.tile_assignment().num_dimensions()) { in TransposeSharding()
281 tile_assignment_dim[i] = sharding.tile_assignment().dim(perm_dimensions[i]); in TransposeSharding()
283 Array<int64> tile_assignment = sharding.tile_assignment(); in TransposeSharding()
[all …]
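A sketch of TransposeSharding on a concrete tiling, assuming xla::Array's nested-initializer constructor and the era's int64 alias; the device layout is invented:

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

void TransposeShardingSketch() {
  // Tile a rank-2 tensor two ways along dim 0: the [2,1] assignment
  // gives device 0 the top half and device 1 the bottom half.
  xla::Array<xla::int64> assignment({{0}, {1}});
  xla::HloSharding tiled = xla::HloSharding::Tile(assignment);

  // Permuting the tensor dims with {1,0} permutes the tile assignment
  // the same way, yielding a [1,2] assignment; per the early-out above,
  // a tile-maximal sharding would instead be returned as-is.
  xla::HloSharding transposed =
      xla::hlo_sharding_util::TransposeSharding(tiled, {1, 0});
}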
hlo_sharding.cc
128 for (auto& sharding : shardings) { in Tuple() local
129 CHECK(!sharding.IsTuple()) << sharding.ToString(); in Tuple()
139 const HloSharding& sharding) { in SingleTuple() argument
141 CHECK(!sharding.IsTuple()) << sharding.ToString(); in SingleTuple()
144 flattened_list.resize(leaf_count, sharding); in SingleTuple()
149 const HloSharding& sharding) { in Single() argument
150 return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding; in Single()
483 TF_ASSIGN_OR_RETURN(HloSharding sharding, in FromProto()
485 tuple_shardings.push_back(sharding); in FromProto()
664 auto assign_metadata = [&](HloSharding& sharding) { in WithMetadata() argument
[all …]
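SingleTuple and Single, sketched on a two-leaf tuple (the shape is invented): SingleTuple stamps one leaf sharding onto every leaf, and Single dispatches on whether the shape is a tuple.

#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

void TupleShardingSketch() {
  xla::Shape leaf = xla::ShapeUtil::MakeShape(xla::F32, {8});
  xla::Shape tuple = xla::ShapeUtil::MakeTupleShape({leaf, leaf});
  // Replicate both leaves: the flattened leaf list is resized to the
  // tuple's leaf count with the single sharding, as in the code above.
  xla::HloSharding both =
      xla::HloSharding::SingleTuple(tuple, xla::HloSharding::Replicate());
  CHECK(both.IsTuple());
  CHECK_EQ(both.tuple_elements().size(), 2);
}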
hlo_sharding_util.h
81 HloSharding TransposeSharding(const HloSharding& sharding,
91 const HloSharding& sharding);
96 HloSharding ReverseSharding(const HloSharding& sharding,
103 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
185 const HloSharding& sharding, const std::vector<int64>& available_devices);
190 const HloSharding& sharding, absl::Span<const int64> dims_to_replicate);
195 HloSharding RemoveShapeDimensions(const HloSharding& sharding,
/external/tensorflow/tensorflow/compiler/tf2xla/
sharding_util_test.cc
28 [](absl::optional<xla::OpSharding> sharding) -> int64 { in TEST() argument
29 if (sharding.has_value() && in TEST()
30 sharding.value().type() == xla::OpSharding::MAXIMAL) { in TEST()
31 return sharding.value().tile_assignment_devices(0); in TEST()
77 auto check_metadata = [](const xla::OpSharding& sharding) { in TEST_P() argument
78 ASSERT_EQ(sharding.metadata_size(), 1); in TEST_P()
79 const auto& metadata = sharding.metadata(0); in TEST_P()
91 auto& sharding = status_or_sharding.ValueOrDie(); in TEST_P() local
92 ASSERT_TRUE(sharding.has_value()); in TEST_P()
93 if (sharding->type() == xla::OpSharding::TUPLE) { in TEST_P()
[all …]
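The lambda above pulls the core id out of a MAXIMAL OpSharding. A sketch of the same check, built directly with xla::sharding_builder (the core id is invented):

#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/core/platform/logging.h"

void MaximalOpShardingSketch() {
  // AssignDevice produces a MAXIMAL OpSharding whose single entry in
  // tile_assignment_devices names the core.
  xla::OpSharding op = xla::sharding_builder::AssignDevice(3);
  CHECK(op.type() == xla::OpSharding::MAXIMAL);
  CHECK_EQ(op.tile_assignment_devices(0), 3);
}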
sharding_util.cc
37 void AssignOpMetadataToSharding(xla::OpSharding& sharding, in AssignOpMetadataToSharding() argument
40 if (sharding.type() == xla::OpSharding::TUPLE) { in AssignOpMetadataToSharding()
41 for (auto& sharding_element : *sharding.mutable_tuple_shardings()) { in AssignOpMetadataToSharding()
45 *sharding.add_metadata() = metadata; in AssignOpMetadataToSharding()
80 auto sharding = xla::sharding_builder::AssignDevice(core); in ParseShardingFromDevice() local
82 *sharding.add_metadata() = metadata.value(); in ParseShardingFromDevice()
84 return absl::optional<xla::OpSharding>(sharding); in ParseShardingFromDevice()
91 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseShardingFromDevice()
94 device_name, num_cores_per_replica, sharding, in ParseShardingFromDevice()
106 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseShardingFromDevice()
[all …]
xla_helpers.cc
140 const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory, in RewriteLayoutWithShardedShape() argument
143 if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) { in RewriteLayoutWithShardedShape()
155 int64 device = *sharding->tile_assignment().begin(); in RewriteLayoutWithShardedShape()
157 sharding->TileOffsetForDevice(*xla_shape, device); in RewriteLayoutWithShardedShape()
158 std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device); in RewriteLayoutWithShardedShape()
183 absl::optional<xla::OpSharding> sharding, bool fast_mem) { in ReshapeWithCorrectRepresentationAndSharding() argument
187 auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; in ReshapeWithCorrectRepresentationAndSharding()
204 if (sharding) { in ReshapeWithCorrectRepresentationAndSharding()
206 xla::HloSharding::FromProto(*sharding)); in ReshapeWithCorrectRepresentationAndSharding()
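The layout rewrite above derives a per-device shape from one device's tile bounds. A hedged sketch of that bounding-box query:

#include <vector>
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"

// For a tiled (non-maximal) sharding, each device owns the half-open
// index box [offset, limit) of the full shape; the code above takes the
// first device in the tile assignment as representative.
void TileBoundsSketch(const xla::HloSharding& sharding,
                      const xla::Shape& shape) {
  xla::int64 device = *sharding.tile_assignment().begin();
  std::vector<xla::int64> offset =
      sharding.TileOffsetForDevice(shape, device);
  std::vector<xla::int64> limit =
      sharding.TileLimitForDevice(shape, device);
  // limit[i] - offset[i] is this device's extent in dimension i.
}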
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/
tpu_sharding_identification.mlir
1 // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-sharding-identification | FileCheck…
3 // Tests empty cluster func. Empty input/output sharding configuration
20 // Tests with a block argument inputs/outputs with no xla sharding op attached
21 // gets default maximal(0) sharding configuration.
33 // CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
34 // CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
41 // Tests with a inputs/outputs with no xla sharding op attached gets
42 // default maximal(0) sharding configuration.
54 // CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
55 // CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
[all …]
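The escaped mhlo.sharding strings in these tests are serialized xla::OpSharding protos. A decoding sketch for the recurring "\08\01\1A\01\01\22\01\00" attribute, which is the maximal(0) configuration the test comments refer to (the MLIR escapes are hex bytes):

#include <string>
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

void DecodeShardingAttrSketch() {
  // The 8 bytes decode as: type = MAXIMAL,
  // tile_assignment_dimensions = [1], tile_assignment_devices = [0].
  // The explicit length keeps the trailing NUL byte.
  std::string bytes("\x08\x01\x1A\x01\x01\x22\x01\x00", 8);
  xla::OpSharding sharding;
  CHECK(sharding.ParseFromString(bytes));
  CHECK(sharding.type() == xla::OpSharding::MAXIMAL);
  CHECK_EQ(sharding.tile_assignment_devices(0), 0);
}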
tpu_space_to_depth_pass.mlir
48 …sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {mhlo.sharding = …
49 …sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01…
116 …sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\2…
117 …sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\2…
190 …sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\2…
191 …sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\2…
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner.cc
106 hlo->sharding().IsReplicated(); in ReportBeforePartition()
193 const HloSharding& sharding, absl::Span<const int64> replication_dims) { in GetPartitionGroupsForReplication() argument
196 group_size *= sharding.tile_assignment().dim(i); in GetPartitionGroupsForReplication()
199 sharding.tile_assignment().num_elements() / group_size); in GetPartitionGroupsForReplication()
200 sharding.tile_assignment().Each( in GetPartitionGroupsForReplication()
205 group_id *= sharding.tile_assignment().dim(i); in GetPartitionGroupsForReplication()
346 if (sharding() == target) { in Reshard()
351 hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles(); in Reshard()
361 .reshard_cache.emplace_back(sharding(), *this); in Reshard()
371 << hlo_->sharding().ToString() << " to " << target.ToString(); in ReshardNoCache()
[all …]
spmd_partitioner_util.cc
47 bool HasReplicatedSharding(const HloSharding& sharding) { in HasReplicatedSharding() argument
48 if (sharding.IsTuple()) { in HasReplicatedSharding()
49 return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding); in HasReplicatedSharding()
51 return sharding.IsReplicated(); in HasReplicatedSharding()
127 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { in EvenlyPartitions() argument
128 if (sharding.IsTuple()) { in EvenlyPartitions()
131 sharding.GetSubSharding(shape, {i}))) { in EvenlyPartitions()
137 if (sharding.IsTileMaximal()) { in EvenlyPartitions()
138 return sharding.IsReplicated(); in EvenlyPartitions()
141 if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { in EvenlyPartitions()
[all …]
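EvenlyPartitions, sketched on a rank-1 shape (the dimensions are invented): the per-dimension test above is simple divisibility of the shape extent by the tile count.

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

void EvenlyPartitionsSketch() {
  // Tile dim 0 two ways across devices 0 and 1.
  xla::Array<xla::int64> two_way({0, 1});
  xla::HloSharding tiled = xla::HloSharding::Tile(two_way);
  xla::Shape even = xla::ShapeUtil::MakeShape(xla::F32, {8});
  xla::Shape odd = xla::ShapeUtil::MakeShape(xla::F32, {7});
  CHECK(xla::spmd::EvenlyPartitions(even, tiled));   // 8 % 2 == 0
  CHECK(!xla::spmd::EvenlyPartitions(odd, tiled));   // 7 % 2 == 1
}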
gather_scatter_handler.cc
37 if (operand.sharding().IsTileMaximal()) { in GatherScatterOperandPartitionedOnlyOnTrivialSliceDims()
44 operand.sharding().tile_assignment().dim(dim); in GatherScatterOperandPartitionedOnlyOnTrivialSliceDims()
47 return trivial_slice_dims_partitions == operand.sharding().NumTiles(); in GatherScatterOperandPartitionedOnlyOnTrivialSliceDims()
58 operand.base_shape(), operand.sharding(), partition_id, b); in IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims()
64 int64 partitions = operand.sharding().tile_assignment().dim(dim); in IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims()
140 if (operand.sharding().IsTileMaximal()) { in PartitionIndexOnlyPartition()
141 if (!indices.sharding().IsTileMaximal() && in PartitionIndexOnlyPartition()
143 indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) == in PartitionIndexOnlyPartition()
163 indices.sharding(), index_dim_to_output_dim, in PartitionIndexOnlyPartition()
169 .Reshard(gather->sharding()) in PartitionIndexOnlyPartition()
[all …]
spmd_partitioner_util.h
40 bool HasReplicatedSharding(const HloSharding& sharding);
85 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);
89 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
98 const HloSharding& sharding,
106 const Shape& shape, const HloSharding& sharding,
112 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);
122 const HloSharding& sharding);
127 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b);
131 absl::optional<int64> UniqueTiledDim(const HloSharding& sharding);
281 int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
[all …]
convolution_handler.cc
97 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithBatchGroupCount()
99 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithBatchGroupCount()
102 (ShardCountAtDim(lhs.sharding(), dnums.input_batch_dimension()) == in PartitionConvolutionWithBatchGroupCount()
105 (ShardCountAtDim(rhs.sharding(), in PartitionConvolutionWithBatchGroupCount()
120 lhs.sharding(), lhs_to_output_indices); in PartitionConvolutionWithBatchGroupCount()
188 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithFeatureGroupCount()
190 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithFeatureGroupCount()
193 (ShardCountAtDim(lhs.sharding(), dnums.input_feature_dimension()) == in PartitionConvolutionWithFeatureGroupCount()
196 (ShardCountAtDim(rhs.sharding(), in PartitionConvolutionWithFeatureGroupCount()
213 lhs.sharding(), lhs_to_output_indices); in PartitionConvolutionWithFeatureGroupCount()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
xla_sharding_util.cc
203 bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { in UnsupportedPartitionedShardingType() argument
204 return sharding != xla::OpSharding::REPLICATED && in UnsupportedPartitionedShardingType()
205 sharding != xla::OpSharding::OTHER; in UnsupportedPartitionedShardingType()
240 xla::OpSharding sharding; in ExtractInputsForLogicalDevices() local
241 sharding.ParseFromString( in ExtractInputsForLogicalDevices()
244 const auto input_sharding_type = sharding.type(); in ExtractInputsForLogicalDevices()
276 for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) { in ExtractInputsForLogicalDevices()
278 sharding.tile_assignment_devices(i); in ExtractInputsForLogicalDevices()
289 cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs); in ExtractInputsForLogicalDevices()
296 for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) { in ExtractInputsForLogicalDevices()
[all …]
/external/tensorflow/tensorflow/core/tpu/kernels/xla/
infeed_op.cc
52 absl::optional<xla::OpSharding> sharding) { in UpdateInfeedLayout() argument
53 if (sharding && sharding->type() == xla::OpSharding::OTHER) { in UpdateInfeedLayout()
55 xla::HloSharding::FromProto(*sharding)); in UpdateInfeedLayout()
56 for (int64 i = 0; i < sharding->tile_assignment_devices_size(); ++i) { in UpdateInfeedLayout()
57 auto device = sharding->tile_assignment_devices(i); in UpdateInfeedLayout()
98 OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shape_, b->sharding())); in Compile()
131 absl::optional<xla::OpSharding> sharding; in Compile() local
132 if (b->sharding()) { in Compile()
133 sharding = b->sharding()->type() == xla::OpSharding::TUPLE in Compile()
134 ? b->sharding()->tuple_shardings(i) in Compile()
[all …]
/external/tensorflow/tensorflow/core/tpu/kernels/
tpu_compile_op_support.cc
162 Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding, in GetPerDeviceShape() argument
165 ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape); in GetPerDeviceShape()
181 if (sharding.IsTileMaximal()) { in GetPerDeviceShape()
186 std::vector<int64> offset = sharding.TileOffsetForDevice(shape, device); in GetPerDeviceShape()
187 std::vector<int64> limit = sharding.TileLimitForDevice(shape, device); in GetPerDeviceShape()
212 const auto& sharding = proto_arg.sharding(); in AddVariableUpdatesToCores() local
227 if (sharding.type() == xla::OpSharding::MAXIMAL) { in AddVariableUpdatesToCores()
228 add_to_core(sharding.tile_assignment_devices(0), shape); in AddVariableUpdatesToCores()
229 } else if (sharding.type() == xla::OpSharding::OTHER) { in AddVariableUpdatesToCores()
231 xla::HloSharding::FromProto(proto_arg.sharding()); in AddVariableUpdatesToCores()
[all …]
tpu_compile_op_common.cc
99 (*arg_core_mapping)[arg_index].sharding = proto_arg.sharding(); in SetPerCoreArgShapes()
100 if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) { in SetPerCoreArgShapes()
101 const int core = proto_arg.sharding().tile_assignment_devices(0); in SetPerCoreArgShapes()
106 } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) { in SetPerCoreArgShapes()
108 xla::HloSharding::FromProto(proto_arg.sharding())); in SetPerCoreArgShapes()
109 for (int core : proto_arg.sharding().tile_assignment_devices()) { in SetPerCoreArgShapes()
120 TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED) in SetPerCoreArgShapes()
154 (*retval_core_mapping)[i].sharding = proto_retval.sharding(); in AssignReturnValueToCore()
155 if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) { in AssignReturnValueToCore()
156 int core = proto_retval.sharding().tile_assignment_devices(0); in AssignReturnValueToCore()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/
argument-sharding.mlir
4 …sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\…
34 // CHECK-SAME: sharding={
40 // CHECK-SAME: sharding={devices=[1,2]0,1}
42 // CHECK-SAME: sharding={maximal device=0}
44 // CHECK-SAME: sharding={replicated}
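The proto bytes in the argument attributes correspond one-to-one with the textual forms FileCheck matches above: "\08\03\1A\02\01\02\22\02\00\01" is an OTHER (tiled) sharding with tile dims [1,2] over devices 0,1, printed as {devices=[1,2]0,1}. A decoding sketch:

#include <string>
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

void PrintArgumentShardingSketch() {
  // type = OTHER (3), tile_assignment_dimensions = [1,2],
  // tile_assignment_devices = [0,1].
  std::string bytes("\x08\x03\x1A\x02\x01\x02\x22\x02\x00\x01", 10);
  xla::OpSharding proto;
  CHECK(proto.ParseFromString(bytes));
  auto sharding = xla::HloSharding::FromProto(proto);  // StatusOr
  CHECK(sharding.ok());
  LOG(INFO) << sharding.ValueOrDie().ToString();  // {devices=[1,2]0,1}
}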
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
spmd_manual_sharding_ops.cc
42 xla::OpSharding sharding; in Compile() local
43 if (!sharding.ParseFromString(manual_sharding_str_)) { in Compile()
51 if (sharding.type() == xla::OpSharding::OTHER) { in Compile()
53 int64 partitions_i = sharding.tile_assignment_dimensions(i); in Compile()
64 sharding); in Compile()
105 xla::OpSharding sharding; in Compile() local
106 if (!sharding.ParseFromString(manual_sharding_str_)) { in Compile()
127 sharding); in Compile()
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
xla_sharding.py
182 tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
184 tensor = tf2xla.sharding(
185 tensor, sharding=proto.SerializeToString())
245 sharding = get_tensor_sharding(from_tensor)
246 if sharding is None:
250 to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
251 attr_value = attr_value_pb2.AttrValue(s=sharding)
469 sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping)
470 return sharding.apply_to_tensor(tensor, use_sharding_op=use_sharding_op)
