1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // DEPRECATED: Use the C++ API defined in tensorflow/cc instead. 17 18 #ifndef TENSORFLOW_CORE_GRAPH_TESTLIB_H_ 19 #define TENSORFLOW_CORE_GRAPH_TESTLIB_H_ 20 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 #include "tensorflow/core/graph/graph.h" 27 #include "tensorflow/core/graph/types.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 namespace test { 32 namespace graph { 33 34 // Converts "g" into its corresponding GraphDef "def". 35 ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.") 36 void ToGraphDef(Graph* g, GraphDef* def); 37 38 // A few helpers to construct a graph. 39 40 // Adds a node in "g" producing a constant "tensor". 41 Node* Constant(Graph* g, const Tensor& tensor); 42 Node* Constant(Graph* g, const Tensor& tensor, const string& name); 43 44 // Adds a node in "g" producing a constant "tensor" on the host. 45 // The given node which, unlike the regular Constant above, always 46 // stores its output on the host. This is necessary for use 47 // in GPU tests where the test Op in question runs on the device 48 // but requires some arguments to be pinned to the host. 49 Node* HostConstant(Graph* g, const Tensor& tensor); 50 Node* HostConstant(Graph* g, const Tensor& tensor, const string& name); 51 52 // Adds a variable in "g" of the given "shape" and "dtype". 53 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape); 54 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape, 55 const string& name); 56 57 // Adds an assign node in "g" which assigns "val" into "var". 58 Node* Assign(Graph* g, Node* var, Node* val); 59 60 // Adds a send node "g" sending "input" as a named "tensor" from 61 // "sender" to "receiver". 62 Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, 63 const uint64 sender_incarnation, const string& receiver); 64 65 // Adds a recv node in "g" receiving a named "tensor" from "sender" 66 // to "receiver". 67 Node* Recv(Graph* g, const string& tensor, const string& type, 68 const string& sender, const uint64 sender_incarnation, 69 const string& receiver); 70 71 // Adds a cumsum "node" in "g" doing cumsum(data, axes). 72 Node* Cumsum(Graph* g, Node* data, Node* axes, bool exclusive = false, 73 bool reverse = false); 74 75 // Adds a reduction "node" in "g" doing sum(data, axes). "reduce" is 76 // a reduction, e.g., Sum, Max, Min, Mean, etc. 77 Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, 78 bool keep_dims = false); 79 80 // Adds a Matmul node in g doing in0.contract(in1). 81 Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, 82 bool transpose_b); 83 84 // Adds a Matmul node in g doing in0.contract(in1). 85 Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y); 86 87 // Adds a Quantize node into g that quantize floats into QUINT8. The range of 88 // the input float tensor is assumed to be [-1, 1]. 89 Node* QuantizeToUINT8(Graph* g, Node* data); 90 91 // Adds a unary function "func" "node" in "g" taking "input". 92 Node* Unary(Graph* g, const string& func, Node* input, int index = 0); 93 94 // Adds an identity node in "g" taking "input" and producing an 95 // identity copy. 96 Node* Identity(Graph* g, Node* input, int index = 0); 97 98 // Adds a binary function "func" node in "g" taking "in0" and "in1". 99 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1); 100 101 // Adds a function "func" node in "g" taking inputs "ins". 102 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins); 103 104 // Adds a binary add node in "g" doing in0 + in1. 105 Node* Add(Graph* g, Node* in0, Node* in1); 106 107 // Reverses <axis> dimensions of <tensor>> 108 Node* Reverse(Graph* g, Node* tensor, Node* axis); 109 110 // Generates random unit uniform distribution of the input shape. 111 Node* RandomUniform(Graph* g, Node* input, DataType dtype); 112 113 // Generates random unit normal distribution of the input shape. 114 Node* RandomGaussian(Graph* g, Node* input, DataType dtype); 115 116 // Generates random gamma distribution with the given shape and alpha[s]. 117 // Output dtype determined by alpha. 118 Node* RandomGamma(Graph* g, Node* shape, Node* alpha); 119 120 // Generates random poisson distribution with the given shape and lam[s]. 121 // Output dtype determined by lam. 122 Node* RandomPoisson(Graph* g, Node* shape, Node* lam); 123 124 // Rolls tensor by an offset of <shift> along the corresponding 125 // <axis> dimensions. 126 Node* Roll(Graph* g, Node* input, Node* shift, Node* axis); 127 128 // Generates random parameters from the truncated standard normal distribution 129 // of the nput shape 130 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype); 131 132 // Adds an error node in "g". The node's computation always 133 // generates an error with the given error message "errmsg". 134 Node* Error(Graph* g, Node* input, const string& errmsg); 135 136 // Adds a node that generates a invalid ref output. 137 Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type); 138 139 // Adds a node in "g". Its Compute() sleeps a while and outputs the 140 // input (i.e., same as identity). 141 Node* Delay(Graph* g, Node* input, Microseconds delay_micros); 142 143 // Adds a no-op "node" in "g", with control inputs from all nodes in 144 // control_inputs vector. 145 Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs); 146 147 // Adds a Switch node in "g". If "in1" is true, it forwards "in0" to 148 // output 1. Otherwise, it forwards "in0" to output 0. 149 Node* Switch(Graph* g, Node* in0, Node* in1); 150 151 // Adds an Enter node in "g", which enters a new frame. 152 Node* Enter(Graph* g, Node* input, const string& frame_name); 153 154 // Adds an Exit node in "g", which exits a frame. 155 Node* Exit(Graph* g, Node* input); 156 157 // Adds a Merge node in "g" with two inputs "in0" and "in1". 158 Node* Merge(Graph* g, Node* in0, Node* in1); 159 160 // Adds a Merge node in "g". The first input is "in0", the remaining 161 // inputs are only given by their names in remaining_in. 162 Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in); 163 164 // Adds a NextIteration node in "g", which makes its input available 165 // to the next iteration. 166 Node* Next(Graph* g, const string& name, Node* input); 167 168 // Adds a LoopCond node in "g", representing the "pivot" termination 169 // condition of a loop. 170 Node* LoopCond(Graph* g, Node* input); 171 172 // Adds a less node in "g", which returns true iff "in0" < "in1". 173 Node* Less(Graph* g, Node* in0, Node* in1); 174 175 // Adds a select node in "g", which outputs either "inx" or "iny" 176 // depending on the boolean value of "c". 177 Node* Select(Graph* g, Node* c, Node* inx, Node* iny); 178 179 // Casts "in" into data type "dst". 180 Node* Cast(Graph* g, Node* in, DataType dst); 181 182 // Perform gather op on params "in0" with indices "in1" and axis "axis". 183 Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis); 184 185 // Gets a tensor stored in the session state. 186 Node* GetSessionTensor(Graph* g, Node* in); 187 188 // Adds a Concat node in "g". The first input is "concat_dim", the 189 // dimension to concatenate on, and the tensors to concatenate are 190 // given in "tensors". 191 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors); 192 193 // Adds a ConcatV2 node in "g". The last input is "concat_dim", the 194 // dimension to concatenate on, and the tensors to concatenate are 195 // given in "tensors". 196 Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim); 197 198 // Add a Relu node in "g". 199 Node* Relu(Graph* g, Node* in); 200 201 // Add a Relu6 node in "g". 202 Node* Relu6(Graph* g, Node* in); 203 204 // Add a BiasAdd node in "g". 205 Node* BiasAdd(Graph* g, Node* value, Node* bias); 206 207 // Add a Conv2D node in "g". 208 Node* Conv2D(Graph* g, Node* in0, Node* in1); 209 210 // Add a Diag node in "g". 211 Node* Diag(Graph* g, Node* in, DataType type); 212 213 // Add a DiagPart node in "g". 214 Node* DiagPart(Graph* g, Node* in, DataType type); 215 216 // Add a CheckNumerics node in "g". 217 Node* CheckNumerics(Graph* g, Node* in, const string& message); 218 219 // Add an _Arg node in "g". 220 Node* Arg(Graph* g, int64 index, DataType type); 221 222 // Add a _Retval node in "g". 223 Node* Retval(Graph* g, int64 index, Node* in); 224 225 } // end namespace graph 226 } // end namespace test 227 } // end namespace tensorflow 228 229 #endif // TENSORFLOW_CORE_GRAPH_TESTLIB_H_ 230