1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
17 #define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
18 
19 #include "absl/types/optional.h"
20 #include "tensorflow/core/graph/graph.h"
21 #include "tensorflow/core/protobuf/config.pb.h"
22 
23 namespace tensorflow {
24 
// Outcome of deciding whether the MLIR bridge runs for a graph, together with
// the reason for that decision (explicit user request vs. graph analysis).
// Enumerator values are spelled out explicitly and follow the sequence 0..3.
enum class MlirBridgeRolloutPolicy {
  // User explicitly turned the MLIR bridge off; it must not be run.
  kDisabledByUser = 0,
  // User explicitly turned the MLIR bridge on; it must be run. If the MLIR
  // bridge errors, the fallback path must NOT be used.
  kEnabledByUser = 1,
  // No explicit user choice. Graph analysis of the model's features decided
  // that the MLIR bridge should not be run.
  kDisabledAfterGraphAnalysis = 2,
  // No explicit user choice. Graph analysis of the model's features decided
  // that the MLIR bridge should be run. If the MLIR bridge errors, the
  // fallback path should be used whenever possible.
  kEnabledAfterGraphAnalysis = 3,
};
39 
40 // Analyzes the user requested policy as well as the contents of the graph and
41 // returns true when the MLIR Bridge should be run.
42 //
43 // If the user explicitly requests the bridge be enabled or disabled, this
44 // function will respect the request. If the user does not explicitly request
45 // enabled or disabled, it will decide whether or not to run the bridge.
46 //
47 // The config_proto param is a required input for all TF1 graphs but it is
48 // redundant for TF2 graphs.
49 // If getting rollout policy involves graph analysis, `record_stats` is used
50 // to decide whether to emit metrics on unsupported features of the graph.
51 MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
52     const tensorflow::Graph& graph,
53     absl::optional<tensorflow::ConfigProto> config_proto,
54     bool uses_uninitialized_resource_args, bool record_stats = false);
55 
56 }  // namespace tensorflow
57 
58 #endif  // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
59