1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_
18 
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"
27 
28 // Classes to maintain a static registry of Graph based passes to be applied to
29 // a function graph.
30 
31 namespace tensorflow {
32 
33 // A pass to be registered with the FunctionOptimizationPassRegistry. This pass
34 // takes in a DeviceSet (available devices for executing the Graph), ConfigProto
35 // (session configuration parameters), Graph (computation),
36 // FunctionLibraryDefinition (mapping between function names and function
37 // definitions of the Graph), control ret/target node names (names of nodes that
38 // must execute but their data outputs, if they have any, are irrelevant), and
39 // whether control ret nodes (via thier name) were updated. Mutations to the
40 // Graph and other associated arguments are performed inplace by the pass.
41 class FunctionOptimizationPass {
42  public:
~FunctionOptimizationPass()43   virtual ~FunctionOptimizationPass() {}
44   virtual Status Run(const DeviceSet& device_set,
45                      const ConfigProto& config_proto,
46                      std::unique_ptr<Graph>* graph,
47                      FunctionLibraryDefinition* flib_def,
48                      std::vector<std::string>* control_ret_node_names,
49                      bool* control_rets_updated) = 0;
50 };
51 
52 // A global function optimization pass registry that is used to hold one
53 // FunctionOptimizationPass. Passes registered to this registry will run before
54 // passes registered in OptimizationPassRegistry.
55 class FunctionOptimizationPassRegistry {
56  public:
57   // Initializes registry with a pass. Only one pass should be set. An assertion
58   // will be triggered if the registry already has a pass set and is being
59   // initialized with another pass.
60   void Init(std::unique_ptr<FunctionOptimizationPass> pass);
61 
62   // Runs a pass if the registry contains one.
63   Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
64              std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
65              std::vector<std::string>* control_ret_node_names,
66              bool* control_rets_updated);
67 
68   // Returns the global registry of function graph passes.
69   static FunctionOptimizationPassRegistry& Global();
70 
71  private:
72   std::unique_ptr<FunctionOptimizationPass> pass_;
73 };
74 
75 namespace function_optimization_registration {
76 
77 class FunctionOptimizationPassRegistration {
78  public:
FunctionOptimizationPassRegistration(std::unique_ptr<FunctionOptimizationPass> pass)79   explicit FunctionOptimizationPassRegistration(
80       std::unique_ptr<FunctionOptimizationPass> pass) {
81     FunctionOptimizationPassRegistry::Global().Init(std::move(pass));
82   }
83 };
84 
85 }  // namespace function_optimization_registration
86 
87 }  // namespace tensorflow
88 
89 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_
90