1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
18 
19 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
20 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
21 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
22 #include "tensorflow/core/grappler/utils/grappler_test.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 
28 class ArithmeticOptimizerTest : public GrapplerTest {
29  protected:
30   // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
31   // longer have any output consumers.
OptimizeAndPrune(ArithmeticOptimizer * optimizer,GrapplerItem * item,GraphDef * output)32   void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
33                         GraphDef* output) {
34     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
35     item->graph.Swap(output);
36     output->Clear();
37     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
38   }
39 
40   // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
OptimizeTwice(ArithmeticOptimizer * optimizer,GrapplerItem * item,GraphDef * output)41   void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
42                      GraphDef* output) {
43     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
44     item->graph.Swap(output);
45     output->Clear();
46     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
47   }
48 
49   // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
50   // Optionally run a constant folding pass before pruning.
51   void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
52                              GraphDef* output, bool const_folding = false) {
53     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
54 
55     item->graph.Swap(output);
56     output->Clear();
57     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
58 
59     if (const_folding) {
60       item->graph.Swap(output);
61       output->Clear();
62       TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
63                        .Optimize(nullptr, *item, output));
64     }
65 
66     item->graph.Swap(output);
67     output->Clear();
68     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
69   }
70 
71   // TODO(ezhulenev): Make private. After migration to stages each test
72   // should explicitly enable required optimization for tests isolation
DisableAllStages(ArithmeticOptimizer * optimizer)73   void DisableAllStages(ArithmeticOptimizer* optimizer) {
74     ArithmeticOptimizer::ArithmeticOptimizerOptions options;
75     options.dedup_computations = false;
76     options.combine_add_to_addn = false;
77     options.convert_sqrt_div_to_rsqrt_mul = false;
78     options.convert_pow = false;
79     options.convert_log1p = false;
80     options.optimize_max_or_min_of_monotonic = false;
81     options.fold_conjugate_into_transpose = false;
82     options.fold_multiply_into_conv = false;
83     options.fold_transpose_into_matmul = false;
84     options.hoist_common_factor_out_of_aggregation = false;
85     options.hoist_cwise_unary_chains = false;
86     options.minimize_broadcasts = false;
87     options.remove_identity_transpose = false;
88     options.remove_involution = false;
89     options.remove_idempotent = false;
90     options.remove_redundant_bitcast = false;
91     options.remove_redundant_cast = false;
92     options.remove_redundant_reshape = false;
93     options.remove_negation = false;
94     options.remove_logical_not = false;
95     options.reorder_cast_like_and_value_preserving = false;
96     options.replace_mul_with_square = false;
97     options.simplify_aggregation = false;
98     options.unary_ops_composition = false;
99     optimizer->options_ = options;
100   }
101 
DisableAddToAddNCombining(ArithmeticOptimizer * optimizer)102   void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
103     optimizer->options_.combine_add_to_addn = false;
104   }
105 
EnableOnlyAddToAddNCombining(ArithmeticOptimizer * optimizer)106   void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
107     DisableAllStages(optimizer);
108     optimizer->options_.combine_add_to_addn = true;
109   }
110 
EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer * optimizer)111   void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
112     DisableAllStages(optimizer);
113     optimizer->options_.fold_conjugate_into_transpose = true;
114   }
115 
EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer * optimizer)116   void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
117     DisableAllStages(optimizer);
118     optimizer->options_.fold_multiply_into_conv = true;
119   }
120 
EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer * optimizer)121   void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
122     DisableAllStages(optimizer);
123     optimizer->options_.fold_transpose_into_matmul = true;
124   }
125 
EnableOnlyHoistCommonFactor(ArithmeticOptimizer * optimizer)126   void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
127     DisableAllStages(optimizer);
128     optimizer->options_.hoist_common_factor_out_of_aggregation = true;
129   }
130 
EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer * optimizer)131   void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
132     DisableAllStages(optimizer);
133     optimizer->options_.minimize_broadcasts = true;
134   }
135 
EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer * optimizer)136   void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
137     DisableAllStages(optimizer);
138     optimizer->options_.remove_identity_transpose = true;
139   }
140 
EnableOnlyRemoveInvolution(ArithmeticOptimizer * optimizer)141   void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
142     DisableAllStages(optimizer);
143     optimizer->options_.remove_involution = true;
144   }
145 
EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer * optimizer)146   void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
147     DisableAllStages(optimizer);
148     optimizer->options_.remove_redundant_bitcast = true;
149   }
150 
EnableOnlyRemoveRedundantCast(ArithmeticOptimizer * optimizer)151   void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
152     DisableAllStages(optimizer);
153     optimizer->options_.remove_redundant_cast = true;
154   }
155 
EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer * optimizer)156   void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
157     DisableAllStages(optimizer);
158     optimizer->options_.remove_redundant_reshape = true;
159   }
160 
EnableOnlyRemoveNegation(ArithmeticOptimizer * optimizer)161   void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
162     DisableAllStages(optimizer);
163     optimizer->options_.remove_negation = true;
164   }
165 
EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer * optimizer)166   void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
167     DisableAllStages(optimizer);
168     optimizer->options_.reorder_cast_like_and_value_preserving = true;
169   }
170 
EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer * optimizer)171   void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
172     DisableAllStages(optimizer);
173     optimizer->options_.replace_mul_with_square = true;
174   }
175 
EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer * optimizer)176   void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
177     DisableAllStages(optimizer);
178     optimizer->options_.hoist_cwise_unary_chains = true;
179   }
180 
EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer * optimizer)181   void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
182     DisableAllStages(optimizer);
183     optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
184   }
185 
EnableOnlyLogSoftmax(ArithmeticOptimizer * optimizer)186   void EnableOnlyLogSoftmax(ArithmeticOptimizer* optimizer) {
187     DisableAllStages(optimizer);
188     optimizer->options_.convert_log_softmax = true;
189   }
190 
EnableOnlyConvertPow(ArithmeticOptimizer * optimizer)191   void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
192     DisableAllStages(optimizer);
193     optimizer->options_.convert_pow = true;
194   }
195 
EnableOnlyFuseSquaredDiff(ArithmeticOptimizer * optimizer)196   void EnableOnlyFuseSquaredDiff(ArithmeticOptimizer* optimizer) {
197     DisableAllStages(optimizer);
198     optimizer->options_.fuse_squared_diff = true;
199   }
200 
EnableOnlyRemoveIdempotent(ArithmeticOptimizer * optimizer)201   void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
202     DisableAllStages(optimizer);
203     optimizer->options_.remove_idempotent = true;
204   }
205 
EnableOnlyRemoveLogicalNot(ArithmeticOptimizer * optimizer)206   void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
207     DisableAllStages(optimizer);
208     optimizer->options_.remove_logical_not = true;
209   }
210 
EnableOnlySimplifyAggregation(ArithmeticOptimizer * optimizer)211   void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
212     DisableAllStages(optimizer);
213     optimizer->options_.simplify_aggregation = true;
214   }
215 
EnableOnlyLog1p(ArithmeticOptimizer * optimizer)216   void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
217     DisableAllStages(optimizer);
218     optimizer->options_.convert_log1p = true;
219   }
220 
EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer * optimizer)221   void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
222     DisableAllStages(optimizer);
223     optimizer->options_.optimize_max_or_min_of_monotonic = true;
224   }
225 
EnableOnlyExpm1(ArithmeticOptimizer * optimizer)226   void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
227     DisableAllStages(optimizer);
228     optimizer->options_.convert_expm1 = true;
229   }
230 
EnableOnlyUnaryOpsComposition(ArithmeticOptimizer * optimizer)231   void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
232     DisableAllStages(optimizer);
233     optimizer->options_.unary_ops_composition = true;
234   }
235 
EnableOnlyRemoveStackStridedSliceSameAxis(ArithmeticOptimizer * optimizer)236   void EnableOnlyRemoveStackStridedSliceSameAxis(
237       ArithmeticOptimizer* optimizer) {
238     DisableAllStages(optimizer);
239     optimizer->options_.remove_stack_strided_slice_same_axis = true;
240   }
241 };
242 
243 }  // end namespace grappler
244 }  // end namespace tensorflow
245 
246 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
247