1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
18 
19 #include "tensorflow/core/grappler/costs/cost_estimator.h"
20 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
21 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/lib/core/status.h"
24 
25 namespace tensorflow {
26 class CostGraphDef;
27 class GraphDef;
28 }  // namespace tensorflow
29 
30 namespace tensorflow {
31 namespace grappler {
32 
33 class Cluster;
34 struct GrapplerItem;
35 
36 // Estimate the cost of running a Grappler item based on the theoretical
37 // performance of the hardware that will run the model. Note that this
38 // internally uses static shape inference. An option for aggressive shape
39 // inference is provided to minimize unknown shapes, and this is only applicable
40 // with static shape inference.
41 class AnalyticalCostEstimator : public CostEstimator {
42  public:
43   AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes,
44                           bool use_aggressive_shape_inference);
45   AnalyticalCostEstimator(Cluster* cluster,
46                           std::unique_ptr<OpLevelCostEstimator> node_estimator,
47                           std::unique_ptr<ReadyNodeManager> node_manager,
48                           bool use_static_shapes,
49                           bool use_aggressive_shape_inference);
50   AnalyticalCostEstimator(Cluster* cluster,
51                           std::unique_ptr<OpLevelCostEstimator> node_estimator,
52                           std::unique_ptr<ReadyNodeManager> node_manager,
53                           std::unique_ptr<VirtualPlacer> placer,
54                           bool use_static_shapes,
55                           bool use_aggressive_shape_inference);
~AnalyticalCostEstimator()56   ~AnalyticalCostEstimator() override {}
57 
58   // This implementation always returns OK.
59   Status Initialize(const GrapplerItem& item) override;
60 
61   // Predict the performance of each node of the optimized graph and annotate
62   // the RunMetadata with the corresponding estimates. Also returns the
63   // expected cost for the whole graph.
64   Status PredictCosts(const GraphDef& optimized_graph,
65                       RunMetadata* run_metadata, Costs* cost) const override;
66 
GetScheduler()67   const VirtualScheduler* GetScheduler() const { return scheduler_.get(); }
68 
69  private:
70   const GrapplerItem* item_;
71   std::unique_ptr<OpLevelCostEstimator> node_estimator_;
72   std::unique_ptr<ReadyNodeManager> node_manager_;
73   std::unique_ptr<VirtualScheduler> scheduler_;
74 
75   bool use_static_shapes_;
76   bool use_aggressive_shape_inference_;
77 };
78 
79 }  // end namespace grappler
80 }  // end namespace tensorflow
81 
82 #endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
83