1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
18 
#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
23 
24 namespace xla {
25 
26 class AlgebraicSimplifierOptions {
27  public:
AlgebraicSimplifierOptions()28   AlgebraicSimplifierOptions() {}
29   // Platform dependent callback to determine if a reshape `from_shape` to
30   // `to_shape` is a bitcast.
31   using ReshapeIsBitcastCallback =
32       std::function<bool(const Shape& from_shape, const Shape& to_shape)>;
AlgebraicSimplifierOptions(ReshapeIsBitcastCallback reshape_is_bitcast_callback)33   explicit AlgebraicSimplifierOptions(
34       ReshapeIsBitcastCallback reshape_is_bitcast_callback)
35       : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)) {}
36 
37   // Use the platform specific callback if set. It is not sensible to return
38   // true here if the options are not layout sensitive.
ReshapeIsBitcast(const Shape & from_shape,const Shape & to_shape)39   bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const {
40     if (!is_layout_sensitive_) {
41       return false;
42     }
43     if (!reshape_is_bitcast_callback_) {
44       return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape);
45     }
46     return reshape_is_bitcast_callback_(from_shape, to_shape);
47   }
48 
49   // If is_layout_sensitive is true, then the simplifier preserves layout during
50   // transformation. Otherwise, layout is ignored.
set_is_layout_sensitive(bool is_layout_sensitive)51   void set_is_layout_sensitive(bool is_layout_sensitive) {
52     is_layout_sensitive_ = is_layout_sensitive;
53   }
54 
is_layout_sensitive()55   bool is_layout_sensitive() const { return is_layout_sensitive_; }
56 
57   // Enable dot simplification on platforms where it is profitable.
set_enable_dot_strength_reduction(bool enable_dot_strength_reduction)58   void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) {
59     enable_dot_strength_reduction_ = enable_dot_strength_reduction;
60   }
61 
enable_dot_strength_reduction()62   bool enable_dot_strength_reduction() const {
63     return enable_dot_strength_reduction_;
64   }
65 
66   // Enable convolution simplification on platforms where it is profitable.
set_enable_conv_simplification(bool enable_conv_simplification)67   void set_enable_conv_simplification(bool enable_conv_simplification) {
68     enable_conv_simplification_ = enable_conv_simplification;
69   }
enable_conv_simplification()70   bool enable_conv_simplification() const {
71     return enable_conv_simplification_;
72   }
73 
74   // If enable_window_reduce_replacement is true, the kReduceWindow instruction
75   // can be optimized by replacement with simpler operations.
set_enable_window_reduce_to_reduce_replacement(bool enable_window_reduce_to_reduce_replacement)76   void set_enable_window_reduce_to_reduce_replacement(
77       bool enable_window_reduce_to_reduce_replacement) {
78     enable_window_reduce_to_reduce_replacement_ =
79         enable_window_reduce_to_reduce_replacement;
80   }
81 
enable_window_reduce_to_reduce_replacement()82   bool enable_window_reduce_to_reduce_replacement() const {
83     return enable_window_reduce_to_reduce_replacement_;
84   }
85 
86  private:
87   ReshapeIsBitcastCallback reshape_is_bitcast_callback_;
88   bool is_layout_sensitive_{false};
89   bool enable_dot_strength_reduction_{true};
90   bool enable_conv_simplification_{true};
91   bool enable_window_reduce_to_reduce_replacement_{true};
92 };
93 
94 // A pass which performs algebraic simplifications.
95 class AlgebraicSimplifier : public HloModulePass {
96  public:
97   // If is_layout_sensitive is true, then the simplifier preserves layout during
98   // transformation. Otherwise, layout is ignored.
AlgebraicSimplifier(const AlgebraicSimplifierOptions & options)99   explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options)
100       : options_(options) {}
101   ~AlgebraicSimplifier() override = default;
name()102   absl::string_view name() const override { return "algsimp"; }
103 
104   // Run algebraic simplification on the given computation. Returns whether the
105   // computation was changed.
106   StatusOr<bool> Run(HloModule* module) override;
107 
108  private:
109   AlgebraicSimplifierOptions options_;
110 };
111 
112 }  // namespace xla
113 
114 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
115