1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_ 18 19 #include "tensorflow/core/grappler/costs/cost_estimator.h" 20 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" 21 #include "tensorflow/core/grappler/costs/virtual_scheduler.h" 22 #include "tensorflow/core/grappler/grappler_item.h" 23 #include "tensorflow/core/lib/core/status.h" 24 25 namespace tensorflow { 26 class CostGraphDef; 27 class GraphDef; 28 } // namespace tensorflow 29 30 namespace tensorflow { 31 namespace grappler { 32 33 class Cluster; 34 struct GrapplerItem; 35 36 // Estimate the cost of running a Grappler item based on the theoretical 37 // performance of the hardware that will run the model. Note that this 38 // internally uses static shape inference. An option for aggressive shape 39 // inference is provided to minimize unknown shapes, and this is only applicable 40 // with static shape inference. 41 class AnalyticalCostEstimator : public CostEstimator { 42 public: 43 AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes, 44 bool use_aggressive_shape_inference); 45 AnalyticalCostEstimator(Cluster* cluster, 46 std::unique_ptr<OpLevelCostEstimator> node_estimator, 47 std::unique_ptr<ReadyNodeManager> node_manager, 48 bool use_static_shapes, 49 bool use_aggressive_shape_inference); 50 AnalyticalCostEstimator(Cluster* cluster, 51 std::unique_ptr<OpLevelCostEstimator> node_estimator, 52 std::unique_ptr<ReadyNodeManager> node_manager, 53 std::unique_ptr<VirtualPlacer> placer, 54 bool use_static_shapes, 55 bool use_aggressive_shape_inference); ~AnalyticalCostEstimator()56 ~AnalyticalCostEstimator() override {} 57 58 // This implementation always returns OK. 59 Status Initialize(const GrapplerItem& item) override; 60 61 // Predict the performance of each node of the optimized graph and annotate 62 // the RunMetadata with the corresponding estimates. Also returns the 63 // expected cost for the whole graph. 64 Status PredictCosts(const GraphDef& optimized_graph, 65 RunMetadata* run_metadata, Costs* cost) const override; 66 GetScheduler()67 const VirtualScheduler* GetScheduler() const { return scheduler_.get(); } 68 69 private: 70 const GrapplerItem* item_; 71 std::unique_ptr<OpLevelCostEstimator> node_estimator_; 72 std::unique_ptr<ReadyNodeManager> node_manager_; 73 std::unique_ptr<VirtualScheduler> scheduler_; 74 75 bool use_static_shapes_; 76 bool use_aggressive_shape_inference_; 77 }; 78 79 } // end namespace grappler 80 } // end namespace tensorflow 81 82 #endif // TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_ 83