1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Classes to maintain a static registry of whole-graph optimization
17 // passes to be applied by the Session when it initializes a graph.
18 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
19 #define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
20 
21 #include <functional>
22 #include <map>
23 #include <vector>
24 
25 #include "tensorflow/core/common_runtime/device_set.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/graph/costmodel.h"
28 #include "tensorflow/core/graph/graph.h"
29 
30 namespace tensorflow {
31 struct SessionOptions;
32 
33 // All the parameters used by an optimization pass are packaged in
34 // this struct. They should be enough for the optimization pass to use
35 // as a key into a state dictionary if it wants to keep state across
36 // calls.
37 struct GraphOptimizationPassOptions {
38   // Filled in by DirectSession for PRE_PLACEMENT optimizations. Can be empty.
39   string session_handle;
40   const SessionOptions* session_options = nullptr;
41   const CostModel* cost_model = nullptr;
42 
43   FunctionLibraryDefinition* flib_def = nullptr;  // Not owned.
44   // The DeviceSet contains all the devices known to the system and is
45   // filled in for optimizations run by the session master, i.e.,
46   // PRE_PLACEMENT, POST_PLACEMENT, and POST_REWRITE_FOR_EXEC. It is
47   // nullptr for POST_PARTITIONING optimizations which are run at the
48   // workers.
49   const DeviceSet* device_set = nullptr;  // Not owned.
50 
51   // The graph to optimize, for optimization passes that run before
52   // partitioning. Null for post-partitioning passes.
53   // An optimization pass may replace *graph with a new graph object.
54   std::unique_ptr<Graph>* graph = nullptr;
55 
56   // Graphs for each partition, if running post-partitioning. Optimization
57   // passes may alter the graphs, but must not add or remove partitions.
58   // Null for pre-partitioning passes.
59   std::unordered_map<string, std::unique_ptr<Graph>>* partition_graphs =
60       nullptr;
61 
62   // Indicator of whether or not the graph was derived from a function.
63   bool is_function_graph = false;
64 };
65 
66 // Optimization passes are implemented by inheriting from
67 // GraphOptimizationPass.
68 class GraphOptimizationPass {
69  public:
~GraphOptimizationPass()70   virtual ~GraphOptimizationPass() {}
71   virtual Status Run(const GraphOptimizationPassOptions& options) = 0;
set_name(const string & name)72   void set_name(const string& name) { name_ = name; }
name()73   string name() const { return name_; }
74 
75  private:
76   // The name of the optimization pass, which is the same as the inherited
77   // class name.
78   string name_;
79 };
80 
81 // The key is a 'phase' number. Phases are executed in increasing
82 // order. Within each phase the order of passes is undefined.
83 typedef std::map<int, std::vector<std::unique_ptr<GraphOptimizationPass>>>
84     GraphOptimizationPasses;
85 
86 // A global OptimizationPassRegistry is used to hold all passes.
87 class OptimizationPassRegistry {
88  public:
89   // Groups of passes are run at different points in initialization.
90   enum Grouping {
91     PRE_PLACEMENT,          // after cost model assignment, before placement.
92     POST_PLACEMENT,         // after placement.
93     POST_REWRITE_FOR_EXEC,  // after re-write using feed/fetch endpoints.
94     POST_PARTITIONING,      // after partitioning
95   };
96 
97   // Add an optimization pass to the registry.
98   void Register(Grouping grouping, int phase,
99                 std::unique_ptr<GraphOptimizationPass> pass);
100 
groups()101   const std::map<Grouping, GraphOptimizationPasses>& groups() {
102     return groups_;
103   }
104 
105   // Run all passes in grouping, ordered by phase, with the same
106   // options.
107   Status RunGrouping(Grouping grouping,
108                      const GraphOptimizationPassOptions& options);
109 
110   // Returns the global registry of optimization passes.
111   static OptimizationPassRegistry* Global();
112 
113   // Prints registered optimization passes for debugging.
114   void LogGrouping(Grouping grouping, int vlog_level);
115   void LogAllGroupings(int vlog_level);
116 
117  private:
118   std::map<Grouping, GraphOptimizationPasses> groups_;
119 };
120 
121 namespace optimization_registration {
122 
123 class OptimizationPassRegistration {
124  public:
OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,int phase,std::unique_ptr<GraphOptimizationPass> pass,string optimization_pass_name)125   OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,
126                                int phase,
127                                std::unique_ptr<GraphOptimizationPass> pass,
128                                string optimization_pass_name) {
129     pass->set_name(optimization_pass_name);
130     OptimizationPassRegistry::Global()->Register(grouping, phase,
131                                                  std::move(pass));
132   }
133 };
134 
135 }  // namespace optimization_registration
136 
// Registers the GraphOptimizationPass subclass `pass_type` with the global
// registry under `grouping` and `phase`. `pass_type` is instantiated with
// `new pass_type()`, and its stringized name becomes the pass name.
#define REGISTER_OPTIMIZATION(grouping, phase, pass_type) \
  REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, pass_type)

// Extra level of indirection so that __COUNTER__ is expanded to a number
// before it reaches the ## token paste in REGISTER_OPTIMIZATION_UNIQ.
// Otherwise the generated identifier would literally contain "__COUNTER__".
#define REGISTER_OPTIMIZATION_UNIQ_HELPER(counter, grouping, phase, pass_type) \
  REGISTER_OPTIMIZATION_UNIQ(counter, grouping, phase, pass_type)

// Defines a uniquely named static OptimizationPassRegistration. Its
// constructor performs the actual registration.
#define REGISTER_OPTIMIZATION_UNIQ(counter, grouping, phase, pass_type)        \
  static ::tensorflow::optimization_registration::OptimizationPassRegistration \
      register_optimization_##counter(                                         \
          grouping, phase,                                                     \
          ::std::unique_ptr<::tensorflow::GraphOptimizationPass>(              \
              new pass_type()),                                                \
          #pass_type)
150 
151 }  // namespace tensorflow
152 
153 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
154