/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

namespace xla {

// HLO pass which rematerializes instructions to reduce peak memory use, where
// memory use is defined as the total size of all live HLO instruction
// values. Parameters and constants are included in memory use estimates.
//
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
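//
// Example usage (a minimal sketch, not part of this interface; the size
// function, pointer size, and memory limit below are illustrative
// assumptions, and 'module' is a pre-scheduled HloModule*):
//
//   // Size of a shape's top-level buffer, assuming 8-byte pointers.
//   HloRematerialization::ShapeSizeFunction size_fn = [](const Shape& shape) {
//     return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
//   };
//   HloRematerialization::RematerializationSizes sizes;
//   HloRematerialization remat(size_fn, /*memory_limit_bytes=*/int64{1} << 30,
//                              &sizes);
//   // 'module' must already have a schedule (HloModule::has_schedule()).
//   TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));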
class HloRematerialization : public HloModulePass {
 public:
  using ShapeSizeFunction = std::function<int64(const Shape&)>;

  // Helper struct that communicates the before / after sizes for the
  // rematerialization process.
  struct RematerializationSizes {
    int64 before_bytes;
    int64 after_bytes;
  };

  // Constructor parameters:
  //
  //   size_function: Function which returns the size in bytes of the top-level
  //     buffer of the given shape.
  //
  //   memory_limit_bytes: The threshold number of bytes to reduce memory use
  //     to via rematerialization.
  //
  //   sizes: Pointer to data structure which records the peak memory usage of
  //     the HLO module before/after rematerialization. Values are set during
  //     Run(). Can be nullptr.
  HloRematerialization(const ShapeSizeFunction& size_function,
                       int64 memory_limit_bytes, RematerializationSizes* sizes)
      : size_function_(size_function),
        memory_limit_bytes_(memory_limit_bytes),
        sizes_(sizes) {}
  ~HloRematerialization() {}

  absl::string_view name() const override { return "rematerialization"; }

  // Runs rematerialization on the given module. Requires that the module has
  // a schedule set (HloModule::has_schedule() is true) before running. Returns
  // whether any instructions were rematerialized. If memory use is already
  // below the limit specified in the constructor, no instructions are
  // rematerialized and false is returned.
  StatusOr<bool> Run(HloModule* module) override;

 protected:
  // Rematerializes instructions within the given computation. 'schedule'
  // contains the order in which the computation's instructions will be emitted
  // in the backend. Rematerialized instructions will be added to the HLO
  // computation and inserted into 'schedule'.
  StatusOr<bool> RematerializeComputation(HloComputation* computation,
                                          HloSchedule* schedule,
                                          int64 memory_limit_bytes);

  // Computes and returns the peak memory used by the given computation. The
  // peak memory is the maximum total size of all live HLO instruction values
  // at any program point. 'order' is the order in which the HLO instructions
  // will be emitted, which is used to determine the lifespans of HLO values.
  StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
                                    const HloInstructionSequence& order) const;

  // Returns the peak memory usage of the called computations for the given
  // instruction. Zero is returned if the instruction calls no computations.
  StatusOr<int64> CalledComputationsMemoryUsage(
      const HloInstruction* instruction) const;

  // Selects an algorithm to use for HLO scheduling.
  MemorySchedulerAlgorithm scheduler_algorithm_;

  // Function which computes the size of the top-level buffer of a shape.
  const ShapeSizeFunction size_function_;

  // The threshold number of bytes to reduce memory use to via
  // rematerialization.
  const int64 memory_limit_bytes_;

  // Pointer to data structure which records the peak memory usage of the HLO
  // module before/after rematerialization.
  RematerializationSizes* sizes_;

  // Call graph of the hlo_module.
  std::unique_ptr<CallGraph> call_graph_;

  // The peak memory usage of each computation. The map contains only those
  // computations called from a sequential context
  // (CallContext::kSequential). These values are updated as rematerialization
  // occurs.
  absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_;

  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // Set of computations which have had rematerialization applied.
  // Rematerialization is only applied once per computation.
  absl::flat_hash_set<const HloComputation*> rematerialized_computations_;

  // Count of the total instructions rematerialized.
  int64 instructions_rematerialized_ = 0;

  // Count of the net instructions added to the HLO module by
  // rematerialization. This can differ from instructions_rematerialized_
  // because some rematerializations are effectively moves in the HLO
  // schedule. In these cases, the rematerialization instruction replaces all
  // uses of the original instruction and the original instruction is
  // dead. Hence, no net instructions were added.
  int64 net_instructions_added_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_