/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
18 
19 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
22 
23 namespace xla {
24 
// A pass which folds F32 <-> BF16 conversions to their operands or users, when
// it is supported by the backend.
//
// This pass follows the passed-in backend-specific BF16 support rules, but can
// introduce mixed precision in individual HLOs which breaks the assumption of
// some other HLO passes. So it should be used at the end of the HLO
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changes made by this pass.
34 class BFloat16ConversionFolding : public HloModulePass {
35  public:
BFloat16ConversionFolding(const BFloat16Support * bfloat16_support)36   explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
37       : bfloat16_support_(bfloat16_support) {}
38 
39   ~BFloat16ConversionFolding() override = default;
name()40   absl::string_view name() const override { return "bfloat16-fold"; }
41 
42   // Run BF16 conversion folding on the given computation. Returns whether the
43   // computation was changed.
44   StatusOr<bool> Run(HloModule* module) override;
45 
46  private:
47   const BFloat16Support* bfloat16_support_;
48 };
49 
50 }  // namespace xla
51 
52 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
53