/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_

#include <memory>
#include <string>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {
namespace spmd {
struct SpmdPartitionerOptions {
  // Always exchange halo on the LHS for all convolutions. If false, the
  // backprop filter convolution exchanges halo on the RHS.
  bool conv_halo_exchange_always_on_lhs = true;

  // The number of instructions to report in the highest-memory-usage profile.
  int64 report_instruction_count = 5;

  // The minimum size in MiB of an einsum operand for it to be considered for
  // the windowed implementation in an HLO loop.
  int64 threshold_for_windowed_einsum_mib = 256;

  // Whether to unroll the windowed einsum loop by a factor of two.
  bool unroll_windowed_einsum = false;

  // Whether to use bidirectional collective permutes in the windowed einsum
  // loop.
  bool bidirectional_windowed_einsum = false;

  // Whether the entry computation's signature may change after partitioning.
  bool allow_module_signature_change = false;

  // Whether to cache all-gathers to avoid repeatedly replicating a tiled
  // tensor. If set to false, the result tends to be more memory-efficient,
  // and the compiler can use the ScheduleAwareAllGatherCSE pass to CSE
  // all-gathers that are relatively close to each other.
  bool cache_all_gather = true;

  // When trading off windowed einsum speed against memory usage, prefer
  // speed if true.
  bool choose_faster_windowed_einsum_over_mem = false;
};
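
// A minimal sketch of configuring the options; the values shown here are
// illustrative, not recommendations:
//
//   SpmdPartitionerOptions options;
//   options.threshold_for_windowed_einsum_mib = 512;  // Window larger einsums.
//   options.cache_all_gather = false;  // Trade speed for memory, then CSE.
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               options);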

// Class to wrap the computation builder to capture information during SPMD
// transformation.
class SpmdBuilder : public HloComputation::Builder {
 public:
  SpmdBuilder(const std::string& name, HloInstruction* hlo)
      : HloComputation::Builder(name) {
    visiting_hlo_ = hlo;
  }
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  const std::vector<HloInstruction*>& derived_instructions(
      HloInstruction* hlo) {
    return instructions_.at(hlo);
  }

  void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; }

  HloInstruction* visiting_hlo() const { return visiting_hlo_; }

  // Wrapper of queries to broadcast_dims_.
  absl::optional<const absl::flat_hash_set<int64>*> BroadcastDimsForCreatedHlo(
      const HloInstruction* hlo) {
    auto it = broadcast_dims_.find(hlo);
    if (it == broadcast_dims_.end()) {
      return absl::nullopt;
    }
    return &it->second;
  }

 private:
  // Currently visiting instruction.
  HloInstruction* visiting_hlo_;

  // Map from the currently visiting (old) instruction to new instructions
  // created during SPMD partitioning.
  HloInstructionMap<std::vector<HloInstruction*>> instructions_;

  // Maps from each created instruction to a set of dimensions that are from
  // broadcasts or elementwise ops over broadcasts. This means elements along
  // these dimensions have the same value.
  absl::flat_hash_map<const HloInstruction*, absl::flat_hash_set<int64>>
      broadcast_dims_;
};
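
// A minimal usage sketch (assumes `old_hlo` is the original instruction being
// partitioned and `new_op` is a freshly created instruction):
//
//   SpmdBuilder b("spmd_example", /*hlo=*/old_hlo);
//   HloInstruction* partitioned = b.AddInstruction(std::move(new_op));
//   // Instructions added while visiting `old_hlo` are recorded, so
//   // b.derived_instructions(old_hlo) now contains `partitioned`.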

// A set of functions that create the cross-partition collective ops.
struct SPMDCollectiveOpsCreator {
  // Function used to create a partition ID HLO.
  std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;

  // Function used to create a cross-partition all-reduce HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64 channel_id)>
      create_cross_partition_all_reduce;

  // Function used to create a cross-partition collective-permute HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand,
      std::vector<std::pair<int64, int64>>& src_dst_pairs,
      int64 next_channel_id)>
      create_cross_partition_collective_permute;

  // Function used to create a cross-partition all-to-all HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, absl::Span<HloInstruction* const> operands,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64 channel_id, absl::optional<int64> split_dimension)>
      create_cross_partition_all_to_all;

  // Function used to create a cross-partition all-gather HLO. This is
  // optional: if it is nullptr, the partitioner will use all-reduce instead.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64 channel_id, int64 all_gather_dimension)>
      create_cross_partition_all_gather;
};

// Creates a default SPMDCollectiveOpsCreator.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
                                                        int64 num_replicas);
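
// A minimal sketch of using the default creator directly (assumes `b` is an
// SpmdBuilder and `operand`, `sum_computation`, and `channel_id` already
// exist):
//
//   SPMDCollectiveOpsCreator creator =
//       GetDefaultCollectiveOpsCreator(/*num_partitions=*/8,
//                                      /*num_replicas=*/1);
//   HloInstruction* pid = creator.create_partition_id(&b);
//   // All-reduce across all 8 partitions as a single subgroup.
//   HloInstruction* ar = creator.create_cross_partition_all_reduce(
//       &b, operand, sum_computation, {{0, 1, 2, 3, 4, 5, 6, 7}}, channel_id);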

// Logger to report memory usage during SPMD partitioning.
class SpmdLogger {
 public:
  explicit SpmdLogger(int64 report_instruction_count)
      : report_instruction_count_(report_instruction_count) {}
  static std::string ReportBeforePartition(const HloModule& module,
                                           int64 report_instruction_count);
  static std::string ReportAfterPartition(const HloModule& module,
                                          int64 report_instruction_count);

  // Registers a log entry for the group of instructions created to transform
  // the given hlo.
  void RegisterLogEntry(HloInstruction* hlo,
                        const std::vector<HloInstruction*>& group);

  std::string MakeReport();

 private:
  template <typename F>
  static std::string ReportMemoryUsage(const HloModule& module, const F& filter,
                                       int64 report_instruction_count);

  // A vector of log entries (one for each original HLO instruction), where
  // the first element of the pair is the size of the HBM used.
  std::vector<std::pair<int64, std::string>> entries_;

  int64 report_instruction_count_;
};
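
// A minimal sketch of the logging flow (assumes `module` is an HloModule and
// `old_hlo`/`new_instructions` come from partitioning one op):
//
//   SpmdLogger logger(/*report_instruction_count=*/5);
//   std::string before = SpmdLogger::ReportBeforePartition(*module, 5);
//   logger.RegisterLogEntry(old_hlo, new_instructions);
//   std::string report = logger.MakeReport();  // Top-5 memory consumers.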

class SpmdPartitioningVisitor;

class SpmdPartitioner : public HloModulePass {
 public:
  SpmdPartitioner(int64 num_partitions, int64 num_replicas,
                  SpmdPartitionerOptions options);
  SpmdPartitioner(int64 num_partitions, int64 num_replicas,
                  SpmdPartitionerOptions options,
                  SPMDCollectiveOpsCreator collective_ops_creator)
      : num_partitions_(num_partitions),
        num_replicas_(num_replicas),
        options_(std::move(options)),
        collective_ops_creator_(std::move(collective_ops_creator)) {}
  absl::string_view name() const override { return "spmd-partitioning"; }
  StatusOr<bool> Run(HloModule* module) override;

  // Transforms the given computation with SPMD instructions, replacing it
  // with a new computation.
  StatusOr<bool> PartitionComputation(HloComputation* computation,
                                      const HloSharding& root_sharding,
                                      int64* next_channel_id,
                                      SpmdLogger* logger);

  // Creates all-gather(s) based on HloSharding. Can be overridden to
  // customize. The default uses a single all-gather even if there are
  // multiple sharded dimensions, and adds potential reshapes and transposes
  // to achieve that. If it returns nullptr, the partitioner will fall back to
  // all-reduce.
  // `selected_dims` specifies the dimensions along which the all-gather
  // happens in the tiled sharding, which allows potentially creating a
  // subgroup all-gather.
  virtual HloInstruction* AllGatherShards(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator);

  // Creates all-reduce(s) across devices along selected_dims in sharding. Can
  // be overridden to customize.
  virtual HloInstruction* AllReduceAlongShardingDims(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction);

  const SpmdPartitionerOptions& options() { return options_; }

 protected:
  virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
      HloComputation* computation, int64 num_partitions, int64 num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options);

  HloInstruction* AllGatherShardsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag);
  HloInstruction* AllReduceAlongShardingDimsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction, bool per_dim_ar);

  // Verifies that the shardings of instructions in the module are valid, and
  // also fills in missing sharding information.
  Status PreprocessSharding(HloModule* module);

  const int64 num_partitions_;
  const int64 num_replicas_;

  SpmdPartitionerOptions options_;
  SPMDCollectiveOpsCreator collective_ops_creator_;
};
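
// A minimal sketch of running the pass over a module whose instructions
// already carry shardings (e.g., from sharding propagation):
//
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               SpmdPartitionerOptions());
//   TF_ASSIGN_OR_RETURN(bool changed, partitioner.Run(module));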

// Describes the partitioning state of the data represented by an HLO created
// during the SPMD partitioning pass.
//
// Data on some devices may include a padding region, if the base (full) shape
// could not be evenly partitioned.
class PartitionedHlo {
 public:
  // Return value for ReshardAsWindowedInput which describes the resharded HLO,
  // the window for the user on the shard, and if necessary, the dynamic slice
  // offsets to be applied to the output of the op being sharded.
  struct WindowedInputShardReturnValue {
    HloInstruction* sharded_input;
    Window shard_window;
    absl::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
  };
  // A cache for resharding each partitioned HLO.
  struct ReshardCache {
    struct PerHloCache {
      std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
      std::vector<
          std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
          window_reshard_cache;
    };
    // Use std::unordered_map for pointer stability.
    std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
    // Caches for nested partitioning of grouped sharding. Each string key
    // represents a unique way of grouping devices.
    absl::flat_hash_map<std::string, std::unique_ptr<ReshardCache>>
        groupd_caches;
  };
  struct PartitioningState {
    SpmdBuilder* b;
    HloModule* module;
    int64 num_replicas;
    HloInstruction* partition_id;
    SPMDCollectiveOpsCreator collective_ops_creator;
    int64* next_channel_id;
    ReshardCache* reshard_cache;
    SpmdPartitioner* partitioner;
  };
  PartitionedHlo(HloInstruction* hlo, Shape base_shape,
                 PartitioningState state)
      : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
    CHECK(hlo->has_sharding())
        << "PartitionedHlo is missing sharding: " << hlo->ToString();
    // If a tuple-shaped instruction does not have a tuple sharding, reassign
    // it to use a tuple sharding. The Reshard() implementation assumes this.
    if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
      hlo_->set_sharding(
          hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
    }
  }

  // Reshards the current SPMD instruction to a new sharding. May only modify
  // the reshard cache.
  PartitionedHlo Reshard(const HloSharding& target);

  // Pads the garbage area of the output with the provided value. Normally,
  // unevenly partitioned dimensions are padded on the right, but this function
  // allows specifying left-padded dimensions, which can be used during the
  // handling of kReverse, etc.
  PartitionedHlo PadWithValue(HloInstruction* pad_value,
                              absl::Span<const int64> left_padded_dims = {},
                              absl::Span<const int64> skipped_dims = {}) const;

  // Returns the SPMD instruction.
  HloInstruction* hlo() const { return hlo_; }

  // Returns the sharding of the SPMD instruction.
  const HloSharding& sharding() const { return hlo_->sharding(); }

  // Returns the original full shape of the data.
  const Shape& base_shape() const { return base_shape_; }

  int64 NewChannel() const { return (*state_.next_channel_id)++; }

  // Reshards the HLO to a usable partitioned input for a windowed user. May
  // only modify the reshard cache.
  absl::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
      const Window& window, const HloSharding& target,
      HloInstruction* pad_value, bool mask_invalid_region = true);

  const PartitioningState& state() const { return state_; }

  // Helper function to replicate the data on all devices. May only modify
  // the reshard cache.
  PartitionedHlo Replicate();

  // Helper function to replicate the data for partitions along the given dims.
  HloInstruction* ReplicatePartial(absl::Span<const int64> dims);

 private:
  // Same as Reshard except that it does not explicitly modify the reshard
  // cache, although it may modify the cache indirectly by calling Replicate().
  PartitionedHlo ReshardNoCache(const HloSharding& target);

  // Helper function to broadcast data from a single device to all devices.
  PartitionedHlo Broadcast() const;

  // Helper function to reshard the tensor using AllToAll (instead of the
  // default of Replicate followed by Slice).
  PartitionedHlo ReshardWithAllToAll(
      const HloSharding& target,
      absl::Span<const std::pair<int64, int64>> source_target_dims) const;

  // Helper function to reshard the tensor using CollectivePermute.
  PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;

  // Helper function to reshard to partial replication using AllGather.
  absl::optional<PartitionedHlo> ReshardToPartialReplicateWithAllGather(
      const HloSharding& target);

  // Helper function to reshard from partial replication using DynamicSlice.
  absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
      const HloSharding& target);

  // Helper function to reshard from partial replication using AllToAll.
  absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
      const HloSharding& target);

  // SPMD instruction.
  HloInstruction* hlo_;

  // The original shape of the data before SPMD transformation is applied.
  Shape base_shape_;

  PartitioningState state_;
};
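
// A minimal sketch of resharding (assumes `hlo` carries a sharding, `shape`
// is its full pre-partitioning shape, and `state` was built by the visitor's
// MakePartitioningState()):
//
//   PartitionedHlo phlo(hlo, shape, state);
//   PartitionedHlo replicated = phlo.Reshard(HloSharding::Replicate());
//   // Repeated calls with the same target hit the reshard cache.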

struct DotConvDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimsMapping {
    int64 lhs;
    int64 rhs;
    int64 output;
    // Convolutions only: the index of this dimension in
    // input_spatial_dimensions().
    int64 spatial;
  };
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
  std::vector<DimsMapping> conv_spatial_dims;
};
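
// A worked example (illustrative): for a batched matmul
// lhs[b, m, k] * rhs[b, k, n] -> out[b, m, n], the mapping would be:
//
//   DotConvDimsMapping mapping;
//   mapping.batch_dims.push_back({/*lhs=*/0, /*rhs=*/0, /*output=*/0,
//                                 /*spatial=*/-1});
//   mapping.contracting_dims.push_back({/*lhs=*/2, /*rhs=*/1, /*output=*/-1,
//                                       /*spatial=*/-1});
//   mapping.lhs_non_contracting_dims.push_back({/*lhs=*/1, /*rhs=*/-1,
//                                               /*output=*/1, /*spatial=*/-1});
//   mapping.rhs_non_contracting_dims.push_back({/*lhs=*/-1, /*rhs=*/2,
//                                               /*output=*/2, /*spatial=*/-1});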

class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
 public:
  SpmdPartitioningVisitor(
      HloComputation* computation, int64 num_partitions, int64 num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options, SpmdPartitioner* partitioner);

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleBroadcast(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* hlo) override;
  Status HandleCustomCall(HloInstruction* hlo) override;
  Status HandleDot(HloInstruction* hlo) override;
  Status HandleDynamicSlice(HloInstruction* hlo) override;
  Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
  Status HandleFft(HloInstruction* hlo) override;
  Status HandleGather(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;
  Status HandleInfeed(HloInstruction* hlo) override;
  Status HandleOutfeed(HloInstruction* hlo) override;
  Status HandlePad(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* hlo) override;
  Status HandleReduce(HloInstruction* hlo) override;
  Status HandleReverse(HloInstruction* hlo) override;
  Status HandleWhile(HloInstruction* hlo) override;
  Status HandleConditional(HloInstruction* hlo) override;
  Status HandleReduceWindow(HloInstruction* hlo) override;
  Status HandleSelectAndScatter(HloInstruction* hlo) override;
  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* hlo) override;
  Status HandleConvolution(HloInstruction* hlo) override;
  Status HandleConcatenate(HloInstruction* hlo) override;
  Status HandleScatter(HloInstruction* hlo) override;
  Status HandleSlice(HloInstruction* hlo) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleTranspose(HloInstruction* hlo) override;
  Status HandleReshape(HloInstruction* hlo) override;
  Status HandleIota(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  // Implementation of dot partitioning given a DotConvDimsMapping.
  Status HandleDotHelper(HloInstruction* hlo,
                         const DotConvDimsMapping& dims_mapping,
                         const std::function<StatusOr<HloInstruction*>(
                             HloInstruction*, HloInstruction*, SpmdBuilder*,
                             const Window& conv_window)>& create_sharded_dot);

  // Common handler for elementwise HLOs.
  Status HandleElementwise(HloInstruction* hlo);

  // Common handler for HLOs that run on a single device.
  Status HandleSingleDevice(const HloInstruction* hlo);

  // Returns the PartitionedHlo that corresponds to the original hlo.
  PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 1);
    return partitioned_instructions_.find(hlo)->second;
  }

  // Sets the PartitionedHlo for the original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const PartitionedHlo& partitioned_hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 0);
    partitioned_instructions_.emplace(hlo, partitioned_hlo);
    changed_ = true;
  }

  // Convenience wrapper that creates a PartitionedHlo from the result of the
  // func and maps it to the given original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const std::function<HloInstruction*()>& func) {
    HloInstruction* new_hlo = func();
    new_hlo->set_sharding(hlo->sharding());
    new_hlo->set_metadata(hlo->metadata());
    SetPartitionedHlo(
        hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
    changed_ = true;
  }

  int64 NewChannel() { return (*next_channel_id_)++; }

  PartitionedHlo::PartitioningState MakePartitioningState() {
    PartitionedHlo::PartitioningState state;
    state.b = &b_;
    state.module = module_;
    state.num_replicas = num_replicas_;
    state.partition_id = partition_id_;
    state.collective_ops_creator = collective_ops_creator_;
    state.next_channel_id = next_channel_id_;
    state.reshard_cache = &reshard_cache_;
    state.partitioner = partitioner_;
    return state;
  }

  SpmdBuilder* builder() { return &b_; }

  StatusOr<bool> DoPartition(HloComputation* computation,
                             const HloSharding& root_sharding,
                             const SpmdPartitionerOptions& options);

  // Information about a loop created for windowed dot-general. Used when
  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
  // finishes traversing the graph.
  struct WindowedDotGeneralLoop {
    HloInstruction* while_loop;
    int64 windowed_operand;
    bool windowed_in_contracting_dims;
    bool windowed_in_batch_dims;
    bool operands_sharded_at_contracting_dims;
  };

 private:
  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // Performs code motion for windowed dot-general loops in
  // windowed_dot_general_loops_. Invoked after the visitor finishes traversing
  // the graph.
  Status DoCodeMotionForWindowedDotGeneralLoops(
      HloComputation* computation, const SpmdPartitionerOptions& options);

  bool changed_;
  HloModule* module_;
  int64 num_partitions_;
  int64 num_replicas_;

  SPMDCollectiveOpsCreator collective_ops_creator_;

  // Tracks the next channel id to use for cross-partition all-reduces.
  int64* next_channel_id_;
  SpmdBuilder b_;

  HloInstruction* partition_id_;

  PartitionedHlo::ReshardCache reshard_cache_;

  // Mapping from an instruction in the original computation to the new SPMD
  // partitioned instruction.
  ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;

  std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;

  HloInstruction* visiting_hlo_;
  SpmdLogger* logger_;
  const SpmdPartitionerOptions options_;
  SpmdPartitioner* partitioner_;
  std::vector<HloSharding> visiting_hlo_operand_shardings_;
  absl::optional<HloSharding> visiting_hlo_sharding_;
};

}  // namespace spmd
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_