1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ 18 19 #include <vector> 20 21 #include "tensorflow/core/framework/attr_value.pb.h" 22 #include "tensorflow/core/framework/graph.pb.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/grappler/grappler_item.h" 25 #include "tensorflow/core/grappler/utils.h" 26 #include "tensorflow/core/lib/random/random.h" 27 #include "tensorflow/core/platform/test.h" 28 #include "tensorflow/core/public/session_options.h" 29 30 namespace tensorflow { 31 namespace grappler { 32 33 class GrapplerTest : public ::testing::Test { 34 public: 35 GrapplerTest(); 36 37 protected: 38 std::vector<Tensor> EvaluateNodes( 39 const GraphDef& graph, const std::vector<string>& node_names) const; 40 41 std::vector<Tensor> EvaluateNodes( 42 const GraphDef& graph, const std::vector<string>& node_names, 43 const std::vector<std::pair<string, Tensor>>& inputs) const; 44 45 std::vector<Tensor> EvaluateFetchNodes(const GrapplerItem& item) const; 46 47 NodeDef* AddNode(const string& name, const string& op, 48 const std::vector<string>& inputs, 49 const std::vector<std::pair<string, AttrValue>>& attributes, 50 GraphDef* graph) const; 51 52 // Checks if two graphs are equal. Both graphs must have the same set of nodes 53 // with the same inputs and attributes. Nodes can be in different order. 54 // 55 // NOTE: This function uses EXPECT/ASSERT macros to check node properties 56 // equality, and adds all failuires to the current test. 57 void CompareGraphs(GraphDef want, GraphDef got) const; 58 59 // Checks if two nodes have the same name, op, inputs and attributes. 60 // 61 // NOTE: This function uses EXPECT/ASSERT macros to check node properties 62 // equality, and adds all failuires to the current test. 63 void CompareNodes(const NodeDef& want, const NodeDef& got) const; 64 65 // Checks if two functions are equal. Both functions must have the same set of 66 // nodes with the same inputs and attributes. Nodes can be in different order. 67 // 68 // NOTE: This function uses EXPECT/ASSERT macros to check node properties 69 // equality, and adds all failures to the current test. 70 void CompareFunctions(FunctionDef want, FunctionDef got) const; 71 72 // Checks if node 'src' is directly connected to the input($position) of 73 // 'dst'. 74 bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src, 75 const string& dst, int position = 0); 76 77 // Counts nodes of the given op-type in a graph. 78 int CountOpNodes(const GraphDef& graph, const string& op); 79 80 // Get a random tensor with given shape. 81 template <DataType DTYPE> GenerateRandomTensor(const TensorShape & shape)82 Tensor GenerateRandomTensor(const TensorShape& shape) const { 83 typedef typename EnumToDataType<DTYPE>::Type T; 84 Tensor tensor(DTYPE, shape); 85 for (auto i = 0; i < tensor.NumElements(); i++) 86 tensor.flat<T>()(i) = i + random::New64() % 10; 87 return tensor; 88 } 89 90 // Get a constant tensor with given shape. 91 template <DataType DTYPE> GenerateConstantTensor(const TensorShape & shape,typename EnumToDataType<DTYPE>::Type value)92 Tensor GenerateConstantTensor( 93 const TensorShape& shape, 94 typename EnumToDataType<DTYPE>::Type value) const { 95 typedef typename EnumToDataType<DTYPE>::Type T; 96 Tensor tensor(DTYPE, shape); 97 for (auto i = 0; i < tensor.NumElements(); i++) tensor.flat<T>()(i) = value; 98 return tensor; 99 } 100 101 private: 102 SessionOptions options_; 103 }; 104 105 } // end namespace grappler 106 } // end namespace tensorflow 107 108 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ 109