1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/graph.pb.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/grappler/grappler_item.h"
25 #include "tensorflow/core/grappler/utils.h"
26 #include "tensorflow/core/lib/random/random.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/public/session_options.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 
33 class GrapplerTest : public ::testing::Test {
34  public:
35   GrapplerTest();
36 
37  protected:
38   std::vector<Tensor> EvaluateNodes(
39       const GraphDef& graph, const std::vector<string>& node_names) const;
40 
41   std::vector<Tensor> EvaluateNodes(
42       const GraphDef& graph, const std::vector<string>& node_names,
43       const std::vector<std::pair<string, Tensor>>& inputs) const;
44 
45   std::vector<Tensor> EvaluateFetchNodes(const GrapplerItem& item) const;
46 
47   NodeDef* AddNode(const string& name, const string& op,
48                    const std::vector<string>& inputs,
49                    const std::vector<std::pair<string, AttrValue>>& attributes,
50                    GraphDef* graph) const;
51 
52   // Checks if two graphs are equal. Both graphs must have the same set of nodes
53   // with the same inputs and attributes. Nodes can be in different order.
54   //
55   // NOTE: This function uses EXPECT/ASSERT macros to check node properties
56   // equality, and adds all failuires to the current test.
57   void CompareGraphs(GraphDef want, GraphDef got) const;
58 
59   // Checks if two nodes have the same name, op, inputs and attributes.
60   //
61   // NOTE: This function uses EXPECT/ASSERT macros to check node properties
62   // equality, and adds all failuires to the current test.
63   void CompareNodes(const NodeDef& want, const NodeDef& got) const;
64 
65   // Checks if two functions are equal. Both functions must have the same set of
66   // nodes with the same inputs and attributes. Nodes can be in different order.
67   //
68   // NOTE: This function uses EXPECT/ASSERT macros to check node properties
69   // equality, and adds all failures to the current test.
70   void CompareFunctions(FunctionDef want, FunctionDef got) const;
71 
72   // Checks if node 'src' is directly connected to the input($position) of
73   // 'dst'.
74   bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src,
75                                 const string& dst, int position = 0);
76 
77   // Counts nodes of the given op-type in a graph.
78   int CountOpNodes(const GraphDef& graph, const string& op);
79 
80   // Get a random tensor with given shape.
81   template <DataType DTYPE>
GenerateRandomTensor(const TensorShape & shape)82   Tensor GenerateRandomTensor(const TensorShape& shape) const {
83     typedef typename EnumToDataType<DTYPE>::Type T;
84     Tensor tensor(DTYPE, shape);
85     for (auto i = 0; i < tensor.NumElements(); i++)
86       tensor.flat<T>()(i) = i + random::New64() % 10;
87     return tensor;
88   }
89 
90   // Get a constant tensor with given shape.
91   template <DataType DTYPE>
GenerateConstantTensor(const TensorShape & shape,typename EnumToDataType<DTYPE>::Type value)92   Tensor GenerateConstantTensor(
93       const TensorShape& shape,
94       typename EnumToDataType<DTYPE>::Type value) const {
95     typedef typename EnumToDataType<DTYPE>::Type T;
96     Tensor tensor(DTYPE, shape);
97     for (auto i = 0; i < tensor.NumElements(); i++) tensor.flat<T>()(i) = value;
98     return tensor;
99   }
100 
101  private:
102   SessionOptions options_;
103 };
104 
105 }  // end namespace grappler
106 }  // end namespace tensorflow
107 
108 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_
109