/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_

#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// A pass which folds F32 <-> BF16 conversions to their operands or users, when
// it is supported by the backend.
//
// This pass follows the passed-in backend-specific BF16 support rules, but can
// introduce mixed precision in individual HLOs which breaks the assumption of
// some other HLO passes. So it should be used at the end of the HLO
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changes made by this pass.
34 class BFloat16ConversionFolding : public HloModulePass { 35 public: BFloat16ConversionFolding(const BFloat16Support * bfloat16_support)36 explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) 37 : bfloat16_support_(bfloat16_support) {} 38 39 ~BFloat16ConversionFolding() override = default; name()40 absl::string_view name() const override { return "bfloat16-fold"; } 41 42 // Run BF16 conversion folding on the given computation. Returns whether the 43 // computation was changed. 44 StatusOr<bool> Run(HloModule* module) override; 45 46 private: 47 const BFloat16Support* bfloat16_support_; 48 }; 49 50 } // namespace xla 51 52 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ 53