1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/core/common_runtime/device_set.h" 24 #include "tensorflow/core/framework/function.h" 25 #include "tensorflow/core/graph/graph.h" 26 #include "tensorflow/core/protobuf/config.pb.h" 27 28 // Classes to maintain a static registry of Graph based passes to be applied to 29 // a function graph. 30 31 namespace tensorflow { 32 33 // A pass to be registered with the FunctionOptimizationPassRegistry. This pass 34 // takes in a DeviceSet (available devices for executing the Graph), ConfigProto 35 // (session configuration parameters), Graph (computation), 36 // FunctionLibraryDefinition (mapping between function names and function 37 // definitions of the Graph), control ret/target node names (names of nodes that 38 // must execute but their data outputs, if they have any, are irrelevant), and 39 // whether control ret nodes (via their name) were updated. Mutations to the 40 // Graph and other associated arguments are performed in place by the pass. 
41 class FunctionOptimizationPass { 42 public: ~FunctionOptimizationPass()43 virtual ~FunctionOptimizationPass() {} 44 virtual Status Run(const DeviceSet& device_set, 45 const ConfigProto& config_proto, 46 std::unique_ptr<Graph>* graph, 47 FunctionLibraryDefinition* flib_def, 48 std::vector<std::string>* control_ret_node_names, 49 bool* control_rets_updated) = 0; 50 }; 51 52 // A global function optimization pass registry that is used to hold one 53 // FunctionOptimizationPass. Passes registered to this registry will run before 54 // passes registered in OptimizationPassRegistry. 55 class FunctionOptimizationPassRegistry { 56 public: 57 // Initializes registry with a pass. Only one pass should be set. An assertion 58 // will be triggered if the registry already has a pass set and is being 59 // initialized with another pass. 60 void Init(std::unique_ptr<FunctionOptimizationPass> pass); 61 62 // Runs a pass if the registry contains one. 63 Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, 64 std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def, 65 std::vector<std::string>* control_ret_node_names, 66 bool* control_rets_updated); 67 68 // Returns the global registry of function graph passes. 69 static FunctionOptimizationPassRegistry& Global(); 70 71 private: 72 std::unique_ptr<FunctionOptimizationPass> pass_; 73 }; 74 75 namespace function_optimization_registration { 76 77 class FunctionOptimizationPassRegistration { 78 public: FunctionOptimizationPassRegistration(std::unique_ptr<FunctionOptimizationPass> pass)79 explicit FunctionOptimizationPassRegistration( 80 std::unique_ptr<FunctionOptimizationPass> pass) { 81 FunctionOptimizationPassRegistry::Global().Init(std::move(pass)); 82 } 83 }; 84 85 } // namespace function_optimization_registration 86 87 } // namespace tensorflow 88 89 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ 90