/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include <functional>
#include <memory>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
29 namespace xla {
30 
31 // HLO pass which rematerializes instructions to reduce peak memory use, where
32 // memory use is defined as the total size of all live HLO instruction
33 // values. Parameters and constants are included in memory use estimates.
34 //
35 // CSE will undo the effects of this optimization and should not be run after
36 // this pass. In general, this pass should be run very late, immediately before
37 // code generation.
38 class HloRematerialization : public HloModulePass {
39  public:
40   using ShapeSizeFunction = std::function<int64(const Shape&)>;
41 
42   // Helper struct that communicates the before / after sizes for the
43   // rematerialization process.
44   struct RematerializationSizes {
45     int64 before_bytes;
46     int64 after_bytes;
47   };
48 
49   // Constructor parameters:
50   //
51   //   size_function: Function which returns the size in bytes of the top-level
52   //     buffer of the given shape.
53   //
54   //   memory_limit_bytes: The threshold number of bytes to reduce memory use to
55   //     via rematerialization.
56   //
57   //   sizes: Pointer to data structure which records the peak memory usage of
58   //     the HLO module before/after rematerialization. Value are set during
59   //     Run(). Can be nullptr.
HloRematerialization(const ShapeSizeFunction & size_function,int64 memory_limit_bytes,RematerializationSizes * sizes)60   HloRematerialization(const ShapeSizeFunction& size_function,
61                        int64 memory_limit_bytes, RematerializationSizes* sizes)
62       : size_function_(size_function),
63         memory_limit_bytes_(memory_limit_bytes),
64         sizes_(sizes) {}
~HloRematerialization()65   ~HloRematerialization() {}
66 
name()67   absl::string_view name() const override { return "rematerialization"; }
68 
69   // Runs rematerialization on the given module. Returns whether the module was
70   // changed. Requires that the module has a schedule set
71   // (HloModule::has_schedule() is true) before running. Returns whether any
72   // instructions were rematerialized. If memory use is already below the limit
73   // specified in the constructor then no instructions are rematerialized and
74   // false is returned.
75   StatusOr<bool> Run(HloModule* module) override;
76 
77  protected:
78   // Rematerializes instructions within the given computation. 'order' is the
79   // order in which the computation's instructions will be emitted in the
80   // backend. Rematerialized instructions will be added to the HLO computation
81   // and inserted into 'order'.
82   StatusOr<bool> RematerializeComputation(HloComputation* computation,
83                                           HloSchedule* schedule,
84                                           int64 memory_limit_bytes);
85 
86   // Computes and returns the peak memory used by the given computation. The
87   // peak memory is the maximum total size of all live HLO instruction values at
88   // any program point. 'order' is the order in which the HLO instructions will
89   // be emitted which is used to determine lifespans of HLO values.
90   StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
91                                     const HloInstructionSequence& order) const;
92 
93   // Returns the peak memory usage of the called computations for the given
94   // instruction. Zero is returned if the instruction calls no computations.
95   StatusOr<int64> CalledComputationsMemoryUsage(
96       const HloInstruction* instruction) const;
97 
98   // Selects an algorithm to use for HLO scheduling.
99   MemorySchedulerAlgorithm scheduler_algorithm_;
100 
101   // Function which computes the size of the top-level buffer of a shape.
102   const ShapeSizeFunction size_function_;
103 
104   // The threshold number of bytes to reduce memory use to via
105   // rematerialization.
106   const int64 memory_limit_bytes_;
107 
108   // Pointer to data structure which records the peak memory usage of the HLO
109   // module before/after rematerialization
110   RematerializationSizes* sizes_;
111 
112   // Call graph of the hlo_module.
113   std::unique_ptr<CallGraph> call_graph_;
114 
115   // The peak memory usage of each computation. The map contains only those
116   // computations called from sequential context
117   // (CallContext::kSequential). These values are updated as rematerialization
118   // occurs.
119   absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_;
120 
121   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
122 
123   // Set of computations which have had rematerialization
124   // applied. Rematerialization is only applied once per computation.
125   absl::flat_hash_set<const HloComputation*> rematerialized_computations_;
126 
127   // Count of the total instructions rematerialized.
128   int64 instructions_rematerialized_ = 0;
129 
130   // Count of the net instructions added to the HLO module by
131   // rematerialization. This can be different than instructions_rematerialized_
132   // because some rematerializations are effectively moves in the HLO
133   // schedule. In these cases, the rematerialization instruction replaces all
134   // uses of the original instruction and the original instruction is
135   // dead. Hence, no net instructions were added.
136   int64 net_instructions_added_ = 0;
137 };
138 
139 }  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_