1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference_testutil.h"
16
17 #include "tensorflow/core/framework/node_def_util.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/lib/gtl/map_util.h"
20 #include "tensorflow/core/lib/strings/numbers.h"
21 #include "tensorflow/core/lib/strings/scanner.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23
24 namespace tensorflow {
25 namespace shape_inference {
26
27 using errors::Unknown;
28
InferShapes(ShapeInferenceTestOp op,const string & ins,const string & expected_outs)29 Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
30 const string& ins,
31 const string& expected_outs) {
32 const OpRegistrationData* op_reg_data;
33 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
34
35 std::vector<string> ins_v = str_util::Split(ins, ';');
36 std::unique_ptr<const NodeDef> new_node_def;
37
38 InferenceContext::ShapeManager manager;
39 std::vector<ShapeHandle> in_shapes;
40 for (const string& spec : ins_v) {
41 ShapeHandle shape;
42 TF_RETURN_IF_ERROR(MakeShapeFromString(&manager, spec, &shape));
43 in_shapes.push_back(shape);
44 }
45
46 std::vector<std::unique_ptr<std::vector<shape_inference::ShapeAndType>>>
47 input_resource_handle_shapes_and_types;
48 for (const auto p : op.input_resource_handle_shapes_and_types) {
49 if (p == nullptr) {
50 input_resource_handle_shapes_and_types.push_back(nullptr);
51 } else {
52 std::unique_ptr<std::vector<ShapeAndType>> v(
53 new std::vector<ShapeAndType>());
54 for (const auto& shape_and_type : *p) {
55 ShapeHandle shape;
56 TF_RETURN_IF_ERROR(
57 MakeShapeFromString(&manager, shape_and_type.first, &shape));
58 v->emplace_back(shape, shape_and_type.second);
59 }
60 input_resource_handle_shapes_and_types.emplace_back(v.release());
61 }
62 }
63 shape_inference::InferenceContext c(
64 op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes,
65 op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types));
66 TF_RETURN_IF_ERROR(c.construction_status());
67 if (op_reg_data->shape_inference_fn == nullptr) {
68 return errors::InvalidArgument(
69 "No shape inference function exists for op '", op.name,
70 "', did you forget to define it?");
71 }
72
73 TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
74
75 const int num_outputs = c.num_outputs();
76
77 if (expected_outs == "e") {
78 return Unknown("Shape inference should have returned error");
79 }
80
81 // Verify the output shape.
82 std::vector<string> expected_outs_v = str_util::Split(expected_outs, ';');
83 if (num_outputs != expected_outs_v.size()) {
84 return Unknown("The expected output string lists the wrong number of ",
85 "outputs. It lists ", expected_outs_v.size(),
86 " but should list ", num_outputs);
87 }
88 for (int i = 0; i < num_outputs; ++i) {
89 StringPiece expected(expected_outs_v[i]);
90 shape_inference::ShapeHandle out = c.output(i);
91
92 string err_prefix = strings::StrCat("Output ", i);
93 string err_suffix =
94 strings::StrCat(". Output shape was ", c.DebugString(out));
95
96 int in_index = -1;
97 for (int i = 0; i < c.num_inputs(); ++i) {
98 if (c.input(i).SameHandle(out)) {
99 in_index = i;
100 }
101 }
102
103 if (str_util::StartsWith(expected, "in")) {
104 if (in_index == -1) {
105 return Unknown(err_prefix,
106 " should have matched an input shape by "
107 "handle, but matched no input shape. This means the ",
108 "shape function was expected to pass an input "
109 "ShapeHandle through for this output, but did not",
110 err_suffix);
111 }
112 auto v = str_util::Split(expected, '|');
113 if (std::find(v.begin(), v.end(), strings::StrCat("in", in_index)) ==
114 v.end()) {
115 return Unknown(
116 err_prefix, " matched input ", in_index,
117 " by handle, but should have matched one of (", expected,
118 ") instead. This means the shape function passed the ShapeHandle ",
119 "for input ", in_index,
120 " to the output, but should have passed a different input ",
121 "ShapeHandle through", err_suffix);
122 }
123 continue;
124 }
125 if (in_index != -1) {
126 return Unknown(err_prefix, " matched input ", in_index,
127 " by ShapeHandle, but was expected to not match an input ",
128 "shape by handle", err_suffix);
129 }
130 if (expected == "?") {
131 if (c.RankKnown(out)) {
132 return Unknown(err_prefix, " expected to be unknown", err_suffix);
133 }
134 continue;
135 }
136
137 // Verify the dimensions.
138 CHECK(str_util::StartsWith(expected, "[") &&
139 str_util::EndsWith(expected, "]"))
140 << expected;
141 expected.remove_prefix(1);
142 expected.remove_suffix(1);
143
144 // Split expected as a dimension.
145 auto expected_dims = str_util::Split(expected, ',');
146 if (!c.RankKnown(out)) {
147 return Unknown(err_prefix, " expected rank ", expected_dims.size(),
148 " but was ?", err_suffix);
149 }
150 if (c.Rank(out) != expected_dims.size()) {
151 return Unknown(err_prefix, " expected rank ", expected_dims.size(),
152 " but was ", c.Rank(out), err_suffix);
153 }
154 for (int j = 0; j < expected_dims.size(); ++j) {
155 err_prefix = strings::StrCat("Output dim ", i, ",", j);
156 StringPiece expected_dim(expected_dims[j]);
157 DimensionHandle out_dim = c.Dim(out, j);
158
159 std::pair<int, int> in_dim_idx(-1, -1);
160 for (int i = 0; i < c.num_inputs(); ++i) {
161 auto in = c.input(i);
162 for (int j = 0; j < c.Rank(in); ++j) {
163 if (c.Dim(in, j).SameHandle(out_dim)) {
164 in_dim_idx = std::make_pair(i, j);
165 }
166 }
167 }
168
169 if (expected_dim == "?") {
170 if (in_dim_idx.first != -1) {
171 return Unknown(err_prefix,
172 " expected to be an unknown but matched input d",
173 in_dim_idx.first, "_", in_dim_idx.second,
174 ". The shape function passed through ",
175 "a DimensionHandle from an input instead of making ",
176 "a new unknown dimension", err_suffix);
177 } else if (c.ValueKnown(out_dim)) {
178 return Unknown(err_prefix, " expected to be unknown but was ",
179 c.Value(out_dim), err_suffix);
180 }
181 } else if (str_util::StartsWith(expected_dim, "d")) {
182 // Compare the dimension values.
183 auto v = str_util::Split(expected_dim, '|');
184 if (in_dim_idx.first == -1) {
185 return Unknown(
186 err_prefix, " was expected to match the dimension of an input, ",
187 "but did not match any input dimension. The shape ",
188 "function was expected to pass through a ",
189 "DimensionHandle for an input, but did not", err_suffix);
190 }
191 if (std::find(v.begin(), v.end(),
192 strings::StrCat("d", in_dim_idx.first, "_",
193 in_dim_idx.second)) == v.end()) {
194 return Unknown(err_prefix, " matched input d", in_dim_idx.first, "_",
195 in_dim_idx.second,
196 ", but should have matched one of (", expected_dim,
197 "). The shape function passed through "
198 "the DimensionHandle for an input, but ",
199 "was expected to pass a different one", err_suffix);
200 }
201 } else {
202 // Parse it as a value.
203 int64 value = -1;
204 if (!strings::safe_strto64(expected_dim, &value)) {
205 return Unknown(err_prefix, ": the expected dimension value '",
206 expected_dim, "' failed to parse as int64",
207 err_suffix);
208 }
209 if (in_dim_idx.first != -1) {
210 return Unknown( //
211 err_prefix, " expected to be ", value, " but matched input d",
212 in_dim_idx.first, "_", in_dim_idx.second,
213 ". The shape function was not expected to pass a DimensionHandle "
214 "from the input to the output, but did. Note that even if the "
215 "passed through output has the same dimension value as the "
216 "expected value, this is considered a failure for the test; "
217 "switch to using d#_# syntax if passing through the "
218 "DimensionHandle should be the expected behavior",
219 err_suffix);
220 } else if (value != c.Value(out_dim)) {
221 return Unknown(err_prefix, " expected to be ", value, " but was ",
222 c.DebugString(out_dim), err_suffix);
223 }
224 }
225 }
226 }
227 return Status::OK();
228 }
229
230 // static
MakeShapeFromString(InferenceContext::ShapeManager * manager,const string & spec,ShapeHandle * output)231 Status ShapeInferenceTestutil::MakeShapeFromString(
232 InferenceContext::ShapeManager* manager, const string& spec,
233 ShapeHandle* output) {
234 if (spec == "?") {
235 *output = manager->UnknownShape();
236 return Status::OK();
237 }
238
239 std::vector<DimensionHandle> dims;
240 strings::Scanner scanner(spec);
241 scanner.OneLiteral("[");
242 while (scanner.Peek() != ']') {
243 if (scanner.Peek() == '?') {
244 scanner.OneLiteral("?");
245 dims.push_back(manager->MakeDim(InferenceContext::kUnknownDim));
246 } else {
247 scanner.RestartCapture().Many(strings::Scanner::DIGIT);
248 StringPiece match;
249 int64 dim_size = 0;
250
251 if (!scanner.GetResult(nullptr, &match) ||
252 !strings::safe_strto64(match, &dim_size)) {
253 return errors::InvalidArgument("Could not parse number in ", spec);
254 }
255
256 dims.push_back(manager->MakeDim(dim_size));
257 }
258
259 if (scanner.Peek() == ',') {
260 scanner.OneLiteral(",");
261 } else if (scanner.Peek() != ']') {
262 return errors::InvalidArgument(
263 "Invalid input spec (] not found in dim shape): ", spec);
264 }
265 }
266 if (!scanner.OneLiteral("]").Eos().GetResult()) {
267 return errors::InvalidArgument("Malformed shape spec: did not end in ']'.");
268 }
269 *output = manager->MakeShape(dims);
270
271 return Status::OK();
272 }
273
274 } // namespace shape_inference
275 } // namespace tensorflow
276