/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

namespace xla {

class HloRematerialization {
 public:
  using ShapeSizeFunction = std::function<int64(const Shape&)>;

  // Helper struct that communicates the before/after sizes for the
  // rematerialization process.
  struct RematerializationSizes {
    int64 before_bytes;
    int64 after_bytes;
  };

  // Rematerializes HLO instructions in the given module to reduce peak memory
  // use below memory_limit_bytes, where memory use is defined as the total
  // size of all live HLO instruction values. Parameters and constants are
  // included in memory use estimates. Method parameters:
  //
  //   size_function: Function which returns the size in bytes of the
  //     top-level buffer of the given shape.
  //
  //   memory_limit_bytes: The target number of bytes below which
  //     rematerialization attempts to reduce memory use.
  //
  //   hlo_module: HLO module in which to rematerialize instructions.
  //
  //   sequence: Should point to an empty HloModuleSequence. Upon return,
  //     contains the HLO instruction order which was used for
  //     rematerialization. This is the order in which HLO instructions should
  //     be emitted to minimize memory use.
  //
  //   sizes: Optional output parameter which indicates the peak memory usage
  //     of the HLO module before and after rematerialization.
  //
  // Returns whether any instructions were rematerialized. If memory use is
  // already below the given limit, then no instructions are rematerialized
  // and false is returned.
  //
  // CSE will undo the effects of this optimization and should not be run
  // after this pass. In general, this pass should be run very late,
  // immediately before code generation.
  static StatusOr<bool> RematerializeAndSchedule(
      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
      HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm,
      SequentialHloOrdering::HloModuleSequence* sequence,
      RematerializationSizes* sizes = nullptr);
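
  // A minimal usage sketch (illustrative only, not part of this interface;
  // the size lambda, memory limit, and scheduler algorithm value below are
  // assumptions, and the call is presumed to run inside a function returning
  // Status or StatusOr so that TF_ASSIGN_OR_RETURN applies):
  //
  //   HloRematerialization::ShapeSizeFunction size_fn =
  //       [](const Shape& shape) {
  //         return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
  //       };
  //   SequentialHloOrdering::HloModuleSequence sequence;
  //   TF_ASSIGN_OR_RETURN(
  //       bool changed,
  //       HloRematerialization::RematerializeAndSchedule(
  //           size_fn, /*memory_limit_bytes=*/int64{1} << 30, module,
  //           scheduler_algorithm, &sequence));
  //
  // Afterwards, 'changed' indicates whether any instruction was
  // rematerialized, and 'sequence' holds the instruction order to use when
  // emitting the module.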

 protected:
  HloRematerialization(SchedulerAlgorithm scheduler_algorithm,
                       const ShapeSizeFunction& size_function)
      : scheduler_algorithm_(scheduler_algorithm),
        size_function_(size_function) {}
  ~HloRematerialization() {}

  // Runs rematerialization on the given module. Returns whether the module
  // was changed. 'memory_limit' is the target maximum peak memory usage of
  // the module. 'sequence' should be an empty HloModuleSequence. Upon return,
  // 'sequence' contains the memory-minimizing order in which to emit the HLO
  // instructions.
  StatusOr<bool> Run(HloModule* module,
                     SequentialHloOrdering::HloModuleSequence* sequence,
                     int64 memory_limit, RematerializationSizes* sizes);

  // Rematerializes instructions within the given computation. 'sequence'
  // holds the order in which the computation's instructions will be emitted
  // in the backend. Rematerialized instructions will be added to the HLO
  // computation and inserted into 'sequence'.
  StatusOr<bool> RematerializeComputation(
      HloComputation* computation,
      SequentialHloOrdering::HloModuleSequence* sequence,
      int64 computation_memory_limit);

  // Computes and returns the peak memory used by the given computation. The
  // peak memory is the maximum total size of all live HLO instruction values
  // at any program point. 'order' is the order in which the HLO instructions
  // will be emitted, which is used to determine the lifespans of HLO values.
  StatusOr<int64> ComputePeakMemory(
      const HloComputation* computation,
      const std::vector<const HloInstruction*>& order) const;

  // Returns the peak memory usage of the computations called by the given
  // instruction. Returns zero if the instruction calls no computations.
  StatusOr<int64> CalledComputationsMemoryUsage(
      const HloInstruction* instruction) const;

  // The algorithm to use for HLO scheduling.
  SchedulerAlgorithm scheduler_algorithm_;

  // Function which computes the size of the top-level buffer of a shape.
  const ShapeSizeFunction size_function_;

  // Call graph of the hlo_module.
  std::unique_ptr<CallGraph> call_graph_;

  // The peak memory usage of each computation. The map contains only those
  // computations called from a sequential context
  // (CallContext::kSequential). These values are updated as rematerialization
  // occurs.
  tensorflow::gtl::FlatMap<const HloComputation*, int64>
      computation_peak_memory_;

  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // Set of computations which have had rematerialization applied.
  // Rematerialization is only applied once per computation.
  tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_;

  // Count of the total instructions rematerialized.
  int64 instructions_rematerialized_ = 0;

  // Count of the net instructions added to the HLO module by
  // rematerialization. This can differ from instructions_rematerialized_
  // because some rematerializations are effectively moves in the HLO
  // schedule. In these cases, the rematerialized instruction replaces all
  // uses of the original instruction and the original instruction becomes
  // dead. Hence, no net instructions were added.
  int64 net_instructions_added_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_