/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_

#include <memory>
#include <string>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {
namespace spmd {

struct SpmdPartitionerOptions {
  // Always exchange halo on LHS for all convolutions. If false, backprop
  // filter convolution exchanges halo on RHS.
  bool conv_halo_exchange_always_on_lhs = true;

  // The number of top memory-consuming instructions to report in the memory
  // profile.
  int64 report_instruction_count = 5;

  // The minimum size in MiB of an einsum operand for the partitioner to
  // consider a windowed implementation in an HLO loop.
  int64 threshold_for_windowed_einsum_mib = 256;

  // Whether to unroll the windowed einsum loop by a factor of two.
  bool unroll_windowed_einsum = false;

  // Whether to use bidirectional collective permutes in the windowed einsum
  // loop.
  bool bidirectional_windowed_einsum = false;

  // Whether the entry computation's signature may change after partitioning.
  bool allow_module_signature_change = false;

  // Whether to use a cached all-gather to avoid repeatedly replicating a
  // tiled tensor. If set to false, the result tends to be more
  // memory-efficient, and the compiler can use the ScheduleAwareAllGatherCSE
  // pass to CSE all-gathers that are relatively close to each other.
  bool cache_all_gather = true;

  // When trading off between windowed einsum speed and memory usage, prefer
  // the former if true.
  bool choose_faster_windowed_einsum_over_mem = false;
};
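
// Example (illustrative sketch): overriding a few defaults. The values below
// are arbitrary; the struct is consumed by the SpmdPartitioner constructors
// declared later in this file.
//
//   SpmdPartitionerOptions options;
//   options.threshold_for_windowed_einsum_mib = 512;
//   options.unroll_windowed_einsum = true;
//   options.allow_module_signature_change = true;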

// Class to wrap the computation builder to capture information during SPMD
// transformation.
class SpmdBuilder : public HloComputation::Builder {
 public:
  SpmdBuilder(const std::string& name, HloInstruction* hlo)
      : HloComputation::Builder(name) {
    visiting_hlo_ = hlo;
  }

  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  const std::vector<HloInstruction*>& derived_instructions(
      HloInstruction* hlo) {
    return instructions_.at(hlo);
  }

  void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; }

  HloInstruction* visiting_hlo() const { return visiting_hlo_; }

  // Wrapper for queries to broadcast_dims_.
  absl::optional<const absl::flat_hash_set<int64>*> BroadcastDimsForCreatedHlo(
      const HloInstruction* hlo) {
    auto it = broadcast_dims_.find(hlo);
    if (it == broadcast_dims_.end()) {
      return absl::nullopt;
    }
    return &it->second;
  }

 private:
  // Currently visiting instruction.
  HloInstruction* visiting_hlo_;

  // Map from the currently visiting (old) instruction to new instructions
  // created during SPMD partitioning.
  HloInstructionMap<std::vector<HloInstruction*>> instructions_;

  // Maps from each created instruction to a set of dimensions that are from
  // broadcasts or elementwise ops over broadcasts. This means elements along
  // these dimensions have the same value.
  absl::flat_hash_map<const HloInstruction*, absl::flat_hash_set<int64>>
      broadcast_dims_;
};

// A set of functions that create the cross-partition collective ops.
struct SPMDCollectiveOpsCreator {
  // Function used to create a partition ID HLO.
  std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;

  // Function used to create a cross-partition all-reduce HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64 channel_id)>
      create_cross_partition_all_reduce;

  // Function used to create a cross-partition collective-permute HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand,
      std::vector<std::pair<int64, int64>>& src_dst_pairs,
      int64 next_channel_id)>
      create_cross_partition_collective_permute;

  // Function used to create a cross-partition all-to-all HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, absl::Span<HloInstruction* const> operands,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64 channel_id, absl::optional<int64> split_dimension)>
      create_cross_partition_all_to_all;

  // Function used to create a cross-partition all-gather HLO. This is
  // optional: if it is nullptr, the partitioner will use all-reduce instead.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64 channel_id, int64 all_gather_dimension)>
      create_cross_partition_all_gather;
};

// Creates a default SPMDCollectiveOpsCreator.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
                                                        int64 num_replicas);
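
// Example (illustrative sketch): start from the default creator and override
// a single hook. The helper in the lambda body is hypothetical; a real
// override would emit a backend-specific HLO through the builder.
//
//   SPMDCollectiveOpsCreator creator =
//       GetDefaultCollectiveOpsCreator(/*num_partitions=*/8,
//                                      /*num_replicas=*/1);
//   creator.create_partition_id = [](SpmdBuilder* b) {
//     return BuildCustomPartitionId(b);  // hypothetical helper
//   };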

// Logger to report memory usage during SPMD partitioning.
class SpmdLogger {
 public:
  explicit SpmdLogger(int64 report_instruction_count)
      : report_instruction_count_(report_instruction_count) {}

  static std::string ReportBeforePartition(const HloModule& module,
                                           int64 report_instruction_count);
  static std::string ReportAfterPartition(const HloModule& module,
                                          int64 report_instruction_count);

  // Registers, for logging, the group of instructions created to transform
  // the given hlo.
  void RegisterLogEntry(HloInstruction* hlo,
                        const std::vector<HloInstruction*>& group);

  std::string MakeReport();

 private:
  template <typename F>
  static std::string ReportMemoryUsage(const HloModule& module,
                                       const F& filter,
                                       int64 report_instruction_count);

  // A vector of logging messages (one for each original HLO instruction),
  // where the integer of each pair represents the size of the HBM used.
  std::vector<std::pair<int64, std::string>> entries_;

  int64 report_instruction_count_;
};

class SpmdPartitioningVisitor;

class SpmdPartitioner : public HloModulePass {
 public:
  SpmdPartitioner(int64 num_partitions, int64 num_replicas,
                  SpmdPartitionerOptions options);
  SpmdPartitioner(int64 num_partitions, int64 num_replicas,
                  SpmdPartitionerOptions options,
                  SPMDCollectiveOpsCreator collective_ops_creator)
      : num_partitions_(num_partitions),
        num_replicas_(num_replicas),
        options_(std::move(options)),
        collective_ops_creator_(std::move(collective_ops_creator)) {}

  absl::string_view name() const override { return "spmd-partitioning"; }

  StatusOr<bool> Run(HloModule* module) override;

  // Transforms the given computation with SPMD instructions, replacing it
  // with a new computation.
  StatusOr<bool> PartitionComputation(HloComputation* computation,
                                      const HloSharding& root_sharding,
                                      int64* next_channel_id,
                                      SpmdLogger* logger);

  // Creates all-gather(s) based on HloSharding. Can be overridden to
  // customize. The default uses a single all-gather even if there are
  // multiple sharded dimensions, and adds reshapes and transposes as needed
  // to achieve that. If the collective-ops creator does not provide an
  // all-gather, the partitioner falls back to all-reduce. `selected_dims`
  // specifies the dimensions along which the all-gather happens in the tiled
  // sharding, which allows potentially creating a subgroup all-gather.
  virtual HloInstruction* AllGatherShards(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator);
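
  // Example (illustrative sketch): a backend could subclass the partitioner
  // to customize the all-gather lowering. `MyPartitioner` and
  // `EmitCustomAllGather` are hypothetical names.
  //
  //   class MyPartitioner : public SpmdPartitioner {
  //    public:
  //     using SpmdPartitioner::SpmdPartitioner;
  //     HloInstruction* AllGatherShards(
  //         SpmdBuilder* b, HloInstruction* operand,
  //         const HloSharding& sharding, int64* next_channel_id,
  //         absl::Span<const int64> selected_dims,
  //         const SPMDCollectiveOpsCreator& collectives_creator) override {
  //       return EmitCustomAllGather(b, operand, sharding);  // hypothetical
  //     }
  //   };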

  // Creates all-reduce(s) across devices along selected_dims in sharding.
  // Can be overridden to customize.
  virtual HloInstruction* AllReduceAlongShardingDims(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction);

  const SpmdPartitionerOptions& options() { return options_; }

 protected:
  virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
      HloComputation* computation, int64 num_partitions, int64 num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options);

  HloInstruction* AllGatherShardsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag);
  HloInstruction* AllReduceAlongShardingDimsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction, bool per_dim_ar);

  // Verifies that the shardings of instructions in the module are valid, and
  // also fills in missing sharding information.
  Status PreprocessSharding(HloModule* module);

  const int64 num_partitions_;
  const int64 num_replicas_;

  SpmdPartitionerOptions options_;
  SPMDCollectiveOpsCreator collective_ops_creator_;
};
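
// Example (illustrative sketch): running the pass over a module whose
// instructions carry sharding annotations. `module` is assumed to be an
// HloModule* prepared by the caller.
//
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               options);
//   StatusOr<bool> changed = partitioner.Run(module);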

// Describes the partition state of the data represented by an HLO created
// during the SPMD partitioning pass.
//
// Data on some devices may include a padding region if the base (full) shape
// could not be evenly partitioned.
class PartitionedHlo {
 public:
  // Return value for ReshardAsWindowedInput, which describes the resharded
  // HLO, the window for the user on the shard, and if necessary, the dynamic
  // slice offsets to be applied to the output of the op being sharded.
  struct WindowedInputShardReturnValue {
    HloInstruction* sharded_input;
    Window shard_window;
    absl::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
  };
  // A cache for resharding each partitioned HLO.
  struct ReshardCache {
    struct PerHloCache {
      std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
      std::vector<
          std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
          window_reshard_cache;
    };
    // Use std::unordered_map for pointer stability.
    std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
    // Caches for nested partitioning of grouped sharding. Each string key
    // represents a unique way of grouping devices.
    absl::flat_hash_map<std::string, std::unique_ptr<ReshardCache>>
        groupd_caches;
  };
  struct PartitioningState {
    SpmdBuilder* b;
    HloModule* module;
    int64 num_replicas;
    HloInstruction* partition_id;
    SPMDCollectiveOpsCreator collective_ops_creator;
    int64* next_channel_id;
    ReshardCache* reshard_cache;
    SpmdPartitioner* partitioner;
  };
  PartitionedHlo(HloInstruction* hlo, Shape base_shape,
                 PartitioningState state)
      : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
    CHECK(hlo->has_sharding())
        << "PartitionedHlo is missing sharding: " << hlo->ToString();
    // If a tuple-shaped instruction does not have a tuple sharding, reassign
    // it to use a tuple sharding. The Reshard() implementation assumes this.
    if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
      hlo_->set_sharding(
          hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
    }
  }

  // Reshards the current SPMD instruction to a new sharding. May only modify
  // the reshard cache.
  PartitionedHlo Reshard(const HloSharding& target);

  // Pads the garbage area of the output with the provided value. Normally,
  // unevenly partitioned dimensions are padded on the right, but this
  // function allows specifying left-padded dimensions, which can be used
  // during the handling of kReverse, etc.
  PartitionedHlo PadWithValue(HloInstruction* pad_value,
                              absl::Span<const int64> left_padded_dims = {},
                              absl::Span<const int64> skipped_dims = {}) const;

  // Returns the SPMD instruction.
  HloInstruction* hlo() const { return hlo_; }

  // Returns the sharding of the SPMD instruction.
  const HloSharding& sharding() const { return hlo_->sharding(); }

  // Returns the original full shape of the data.
  const Shape& base_shape() const { return base_shape_; }

  int64 NewChannel() const { return (*state_.next_channel_id)++; }

  // Reshards the HLO to a usable partitioned input for a windowed user. May
  // only modify the reshard cache.
  absl::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
      const Window& window, const HloSharding& target,
      HloInstruction* pad_value, bool mask_invalid_region = true);

  const PartitioningState& state() const { return state_; }

  // Helper function to replicate the data on all devices. May only modify
  // the reshard cache.
  PartitionedHlo Replicate();

  // Helper function to replicate the data for partitions along the given
  // dims.
  HloInstruction* ReplicatePartial(absl::Span<const int64> dims);

 private:
  // Same as Reshard except that it does not explicitly modify the reshard
  // cache, although it may modify the cache indirectly by calling
  // Replicate().
  PartitionedHlo ReshardNoCache(const HloSharding& target);

  // Helper function to broadcast data from a single device to all devices.
  PartitionedHlo Broadcast() const;

  // Helper function to reshard the tensor using AllToAll (instead of the
  // default of Replicate followed by Slice).
  PartitionedHlo ReshardWithAllToAll(
      const HloSharding& target,
      absl::Span<const std::pair<int64, int64>> source_target_dims) const;
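
  // Example (illustrative sketch): wrapping a per-partition HLO and
  // resharding it to replicated form. `state` is assumed to be a
  // PartitioningState prepared by the visitor.
  //
  //   PartitionedHlo partitioned(hlo, original_base_shape, state);
  //   PartitionedHlo replicated =
  //       partitioned.Reshard(HloSharding::Replicate());
  //   HloInstruction* full = replicated.hlo();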

  // Helper function to reshard the tensor using CollectivePermute.
  PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;

  // Helper function to reshard to partial replicate using AllGather.
  absl::optional<PartitionedHlo> ReshardToPartialReplicateWithAllGather(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using DynamicSlice.
  absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using AllToAll.
  absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
      const HloSharding& target);

  // The SPMD instruction.
  HloInstruction* hlo_;

  // The original shape of the data before SPMD transformation is applied.
  Shape base_shape_;

  PartitioningState state_;
};

struct DotConvDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimsMapping {
    int64 lhs;
    int64 rhs;
    int64 output;
    // For convolutions, the mapped index in input_spatial_dimensions().
    int64 spatial;
  };
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
  std::vector<DimsMapping> conv_spatial_dims;
};
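
// Example (illustrative sketch): for a dot with lhs shape [B, M, K] and rhs
// shape [B, K, N] producing [B, M, N], the mapping would be built as:
//
//   DotConvDimsMapping mapping;
//   mapping.batch_dims.push_back(
//       {/*lhs=*/0, /*rhs=*/0, /*output=*/0, /*spatial=*/-1});
//   mapping.contracting_dims.push_back(
//       {/*lhs=*/2, /*rhs=*/1, /*output=*/-1, /*spatial=*/-1});
//   mapping.lhs_non_contracting_dims.push_back(
//       {/*lhs=*/1, /*rhs=*/-1, /*output=*/1, /*spatial=*/-1});
//   mapping.rhs_non_contracting_dims.push_back(
//       {/*lhs=*/-1, /*rhs=*/2, /*output=*/2, /*spatial=*/-1});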

class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
 public:
  SpmdPartitioningVisitor(
      HloComputation* computation, int64 num_partitions, int64 num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options, SpmdPartitioner* partitioner);

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleBroadcast(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* hlo) override;
  Status HandleCustomCall(HloInstruction* hlo) override;
  Status HandleDot(HloInstruction* hlo) override;
  Status HandleDynamicSlice(HloInstruction* hlo) override;
  Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
  Status HandleFft(HloInstruction* hlo) override;
  Status HandleGather(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;
  Status HandleInfeed(HloInstruction* hlo) override;
  Status HandleOutfeed(HloInstruction* hlo) override;
  Status HandlePad(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* hlo) override;
  Status HandleReduce(HloInstruction* hlo) override;
  Status HandleReverse(HloInstruction* hlo) override;
  Status HandleWhile(HloInstruction* hlo) override;
  Status HandleConditional(HloInstruction* hlo) override;
  Status HandleReduceWindow(HloInstruction* hlo) override;
  Status HandleSelectAndScatter(HloInstruction* hlo) override;
  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* hlo) override;
  Status HandleConvolution(HloInstruction* hlo) override;
  Status HandleConcatenate(HloInstruction* hlo) override;
  Status HandleScatter(HloInstruction* hlo) override;
  Status HandleSlice(HloInstruction* hlo) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleTranspose(HloInstruction* hlo) override;
  Status HandleReshape(HloInstruction* hlo) override;
  Status HandleIota(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  // Implementation of dot partitioning given a DotConvDimsMapping.
  Status HandleDotHelper(HloInstruction* hlo,
                         const DotConvDimsMapping& dims_mapping,
                         const std::function<StatusOr<HloInstruction*>(
                             HloInstruction*, HloInstruction*, SpmdBuilder*,
                             const Window& conv_window)>& create_sharded_dot);

  // Common handler for elementwise HLOs.
  Status HandleElementwise(HloInstruction* hlo);

  // Common handler for HLOs that run on a single device.
  Status HandleSingleDevice(const HloInstruction* hlo);

  // Returns the PartitionedHlo that corresponds to the original hlo.
  PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 1);
    return partitioned_instructions_.find(hlo)->second;
  }

  // Sets the PartitionedHlo for the original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const PartitionedHlo& partitioned_hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 0);
    partitioned_instructions_.emplace(hlo, partitioned_hlo);
    changed_ = true;
  }

  // Convenience wrapper that creates a PartitionedHlo from the result of the
  // func and maps it to the given original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const std::function<HloInstruction*()>& func) {
    HloInstruction* new_hlo = func();
    new_hlo->set_sharding(hlo->sharding());
    new_hlo->set_metadata(hlo->metadata());
    SetPartitionedHlo(
        hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
    changed_ = true;
  }

  int64 NewChannel() { return (*next_channel_id_)++; }

  PartitionedHlo::PartitioningState MakePartitioningState() {
    PartitionedHlo::PartitioningState state;
    state.b = &b_;
    state.module = module_;
    state.num_replicas = num_replicas_;
    state.partition_id = partition_id_;
    state.collective_ops_creator = collective_ops_creator_;
    state.next_channel_id = next_channel_id_;
    state.reshard_cache = &reshard_cache_;
    state.partitioner = partitioner_;
    return state;
  }

  SpmdBuilder* builder() { return &b_; }

  StatusOr<bool> DoPartition(HloComputation* computation,
                             const HloSharding& root_sharding,
                             const SpmdPartitionerOptions& options);

  // Information about a loop created for windowed dot-general. Used when
  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
  // finishes traversing the graph.
  struct WindowedDotGeneralLoop {
    HloInstruction* while_loop;
    int64 windowed_operand;
    bool windowed_in_contracting_dims;
    bool windowed_in_batch_dims;
    bool operands_sharded_at_contracting_dims;
  };
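
  // Example (illustrative sketch): inside a handler, the lambda form of
  // SetPartitionedHlo is typically used to build the sharded instruction.
  // The reshard-to-output-sharding pattern below mirrors what an elementwise
  // handler would do.
  //
  //   SetPartitionedHlo(hlo, [&] {
  //     auto operand =
  //         GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
  //     return b_.AddInstruction(
  //         hlo->CloneWithNewOperands(operand->shape(), {operand}));
  //   });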

 private:
  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // Performs code motion for windowed dot-general loops in
  // windowed_dot_general_loops_. Invoked after the visitor finishes
  // traversing the graph.
  Status DoCodeMotionForWindowedDotGeneralLoops(
      HloComputation* computation, const SpmdPartitionerOptions& options);

  bool changed_;
  HloModule* module_;
  int64 num_partitions_;
  int64 num_replicas_;

  SPMDCollectiveOpsCreator collective_ops_creator_;

  // Tracks the next channel id to use for cross-partition all-reduce.
  int64* next_channel_id_;
  SpmdBuilder b_;

  HloInstruction* partition_id_;

  PartitionedHlo::ReshardCache reshard_cache_;

  // Mapping from the instruction in the original computation to the new SPMD
  // partitioned instruction.
  ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;

  std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;

  HloInstruction* visiting_hlo_;
  SpmdLogger* logger_;
  const SpmdPartitionerOptions options_;
  SpmdPartitioner* partitioner_;
  std::vector<HloSharding> visiting_hlo_operand_shardings_;
  absl::optional<HloSharding> visiting_hlo_sharding_;
};

}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_