/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

namespace xla {

class HloRematerialization {
 public:
  using ShapeSizeFunction = std::function<int64(const Shape&)>;
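
  // For example, a size function might be built on ShapeUtil::ByteSizeOf (an
  // illustrative sketch; using the host pointer size is an assumption, not a
  // requirement):
  //
  //   HloRematerialization::ShapeSizeFunction size_fn =
  //       [](const Shape& shape) {
  //         return ShapeUtil::ByteSizeOf(shape,
  //                                      /*pointer_size=*/sizeof(void*));
  //       };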

  // Helper struct that reports the peak memory usage of the module before and
  // after rematerialization.
  struct RematerializationSizes {
    int64 before_bytes;
    int64 after_bytes;
  };
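  //
  // For example, a caller might log the effect of the pass (an illustrative
  // sketch; see the RematerializeAndSchedule example below for how the struct
  // is filled in):
  //
  //   RematerializationSizes sizes;
  //   ... run rematerialization, passing &sizes ...
  //   VLOG(1) << "Peak memory: " << sizes.before_bytes << " -> "
  //           << sizes.after_bytes << " bytes";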

  // Rematerialize HLO instructions in the given module to reduce peak memory
  // use below memory_limit_bytes, where memory use is defined as the total
  // size of all live HLO instruction values. Parameters and constants are
  // included in memory use estimates. Method parameters:
  //
  //   size_function: Function that returns the size in bytes of the top-level
  //     buffer of the given shape.
  //
  //   memory_limit_bytes: The threshold number of bytes to reduce memory use
  //     to via rematerialization.
  //
  //   hlo_module: HLO module to rematerialize instructions in.
  //
  //   scheduler_algorithm: Algorithm to use when scheduling (ordering) the
  //     module's HLO instructions.
  //
  //   sequence: Should point to an empty HloModuleSequence. Upon return, it
  //     contains the HLO instruction order that was used for
  //     rematerialization. This is the order in which HLO instructions should
  //     be emitted to minimize memory use.
  //
  //   sizes: Optional out-parameter that reports the peak memory usage of the
  //     HLO module before and after rematerialization.
  //
  // Returns whether any instructions were rematerialized. If memory use is
  // already below the given limit, then no instructions are rematerialized
  // and false is returned.
  //
  // CSE will undo the effects of this optimization and should not be run
  // after this pass. In general, this pass should be run very late,
  // immediately before code generation.
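  //
  // Example usage (an illustrative sketch from inside a function that returns
  // Status; size_fn is a ShapeSizeFunction such as the one sketched above,
  // and module, scheduler_algorithm, and the 1 GiB limit are assumptions):
  //
  //   SequentialHloOrdering::HloModuleSequence sequence;
  //   HloRematerialization::RematerializationSizes sizes;
  //   TF_ASSIGN_OR_RETURN(
  //       bool changed,
  //       HloRematerialization::RematerializeAndSchedule(
  //           size_fn, /*memory_limit_bytes=*/int64{1} << 30, module,
  //           scheduler_algorithm, &sequence, &sizes));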
  static StatusOr<bool> RematerializeAndSchedule(
      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
      HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm,
      SequentialHloOrdering::HloModuleSequence* sequence,
      RematerializationSizes* sizes = nullptr);

 protected:
  HloRematerialization(SchedulerAlgorithm scheduler_algorithm,
                       const ShapeSizeFunction& size_function)
      : scheduler_algorithm_(scheduler_algorithm),
        size_function_(size_function) {}
  ~HloRematerialization() {}

  // Runs rematerialization on the given module. Returns whether the module
  // was changed. memory_limit is the target maximum peak memory usage of the
  // module. sequence should point to an empty HloModuleSequence. Upon return,
  // sequence contains the memory-minimizing order in which to emit the HLO
  // instructions.
  StatusOr<bool> Run(HloModule* module,
                     SequentialHloOrdering::HloModuleSequence* sequence,
                     int64 memory_limit, RematerializationSizes* sizes);

  // Rematerializes instructions within the given computation. 'sequence' maps
  // each computation to the order in which its instructions will be emitted
  // in the backend. Rematerialized instructions will be added to the HLO
  // computation and inserted into the computation's order in 'sequence'.
  StatusOr<bool> RematerializeComputation(
      HloComputation* computation,
      SequentialHloOrdering::HloModuleSequence* sequence,
      int64 computation_memory_limit);

  // Computes and returns the peak memory used by the given computation. The
  // peak memory is the maximum total size of all live HLO instruction values
  // at any program point. 'order' is the order in which the HLO instructions
  // will be emitted, and is used to determine the lifespans of HLO values.
  StatusOr<int64> ComputePeakMemory(
      const HloComputation* computation,
      const std::vector<const HloInstruction*>& order) const;
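  //
  // Conceptually, this is a single pass over 'order' (a simplified sketch
  // that ignores buffer aliasing and called computations, both of which the
  // real implementation handles; DeadAfter is a hypothetical helper that
  // returns the values whose last use is 'instruction'):
  //
  //   int64 live_bytes = 0;
  //   int64 peak_bytes = 0;
  //   for (const HloInstruction* instruction : order) {
  //     live_bytes += size_function_(instruction->shape());
  //     peak_bytes = std::max(peak_bytes, live_bytes);
  //     for (const HloInstruction* dead : DeadAfter(instruction)) {
  //       live_bytes -= size_function_(dead->shape());
  //     }
  //   }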

  // Returns the peak memory usage of the called computations for the given
  // instruction. Zero is returned if the instruction calls no computations.
  StatusOr<int64> CalledComputationsMemoryUsage(
      const HloInstruction* instruction) const;

  // The algorithm to use for scheduling (ordering) HLO instructions.
  SchedulerAlgorithm scheduler_algorithm_;

  // Function which computes the size of the top-level buffer of a shape.
  const ShapeSizeFunction size_function_;

  // Call graph of the module being rematerialized.
  std::unique_ptr<CallGraph> call_graph_;

  // The peak memory usage of each computation. The map contains only those
  // computations called from a sequential context
  // (CallContext::kSequential). These values are updated as rematerialization
  // occurs.
  tensorflow::gtl::FlatMap<const HloComputation*, int64>
      computation_peak_memory_;

  // Tuple points-to analysis of the HLO module.
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // Set of computations which have had rematerialization
  // applied. Rematerialization is applied only once per computation.
  tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_;

  // Count of the total instructions rematerialized.
  int64 instructions_rematerialized_ = 0;

  // Count of the net instructions added to the HLO module by
  // rematerialization. This can be different from instructions_rematerialized_
  // because some rematerializations are effectively moves in the HLO
  // schedule. In these cases, the rematerialized instruction replaces all
  // uses of the original instruction and the original instruction becomes
  // dead. Hence, no net instructions were added.
  int64 net_instructions_added_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_