1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_def_builder.h"
17
18 #include <limits>
19 #include <vector>
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/array_slice.h"
26 #include "tensorflow/core/lib/strings/scanner.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29
30 using ::tensorflow::strings::Scanner;
31
32 namespace tensorflow {
33
34 namespace {
35
AttrError(StringPiece orig,const string & op_name)36 string AttrError(StringPiece orig, const string& op_name) {
37 return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
38 }
39
ConsumeAttrName(StringPiece * sp,StringPiece * out)40 bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
41 return Scanner(*sp)
42 .One(Scanner::LETTER)
43 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
44 .StopCapture()
45 .AnySpace()
46 .OneLiteral(":")
47 .AnySpace()
48 .GetResult(sp, out);
49 }
50
ConsumeListPrefix(StringPiece * sp)51 bool ConsumeListPrefix(StringPiece* sp) {
52 return Scanner(*sp)
53 .OneLiteral("list")
54 .AnySpace()
55 .OneLiteral("(")
56 .AnySpace()
57 .GetResult(sp);
58 }
59
ConsumeQuotedString(char quote_ch,StringPiece * sp,StringPiece * out)60 bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
61 const string quote_str(1, quote_ch);
62 return Scanner(*sp)
63 .OneLiteral(quote_str.c_str())
64 .RestartCapture()
65 .ScanEscapedUntil(quote_ch)
66 .StopCapture()
67 .OneLiteral(quote_str.c_str())
68 .AnySpace()
69 .GetResult(sp, out);
70 }
71
ConsumeAttrType(StringPiece * sp,StringPiece * out)72 bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
73 return Scanner(*sp)
74 .Many(Scanner::LOWERLETTER_DIGIT)
75 .StopCapture()
76 .AnySpace()
77 .GetResult(sp, out);
78 }
79
ConsumeAttrNumber(StringPiece * sp,int64 * out)80 bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
81 Scanner scan(*sp);
82 StringPiece match;
83 StringPiece remaining;
84
85 scan.AnySpace().RestartCapture();
86 if (scan.Peek() == '-') {
87 scan.OneLiteral("-");
88 }
89 if (!scan.Many(Scanner::DIGIT)
90 .StopCapture()
91 .AnySpace()
92 .GetResult(&remaining, &match)) {
93 return false;
94 }
95 int64 value = 0;
96 if (!strings::safe_strto64(match, &value)) {
97 return false;
98 }
99 *out = value;
100 *sp = remaining;
101 return true;
102 }
103
// Parse guard for FinalizeAttr. When `condition` is false it appends an error
// message (the remaining arguments followed by AttrError(...) context) to
// *errors and returns from the enclosing void function. The expansion site
// must have `errors`, `orig` and `op_def` in scope.
#define VERIFY(condition, ...)                                          \
  do {                                                                  \
    if (condition) break;                                               \
    errors->push_back(                                                  \
        strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \
    return;                                                             \
  } while (false)
112
ConsumeCompoundAttrType(StringPiece * sp,StringPiece * out)113 bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
114 auto capture_begin = sp->begin();
115 if (sp->Consume("numbertype") || sp->Consume("numerictype") ||
116 sp->Consume("quantizedtype") || sp->Consume("realnumbertype") ||
117 sp->Consume("realnumberictype")) {
118 *out = StringPiece(capture_begin, sp->begin() - capture_begin);
119 return true;
120 }
121 return false;
122 }
123
ProcessCompoundType(const StringPiece type_string,AttrValue * allowed)124 bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
125 if (type_string == "numbertype" || type_string == "numerictype") {
126 for (DataType dt : NumberTypes()) {
127 allowed->mutable_list()->add_type(dt);
128 }
129 } else if (type_string == "quantizedtype") {
130 for (DataType dt : QuantizedTypes()) {
131 allowed->mutable_list()->add_type(dt);
132 }
133 } else if (type_string == "realnumbertype" ||
134 type_string == "realnumerictype") {
135 for (DataType dt : RealNumberTypes()) {
136 allowed->mutable_list()->add_type(dt);
137 }
138 } else {
139 return false;
140 }
141 return true;
142 }
143
FinalizeAttr(StringPiece spec,OpDef * op_def,std::vector<string> * errors)144 void FinalizeAttr(StringPiece spec, OpDef* op_def,
145 std::vector<string>* errors) {
146 OpDef::AttrDef* attr = op_def->add_attr();
147 StringPiece orig(spec);
148
149 // Parse "<name>:" at the beginning.
150 StringPiece tmp_name;
151 VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
152 attr->set_name(tmp_name.data(), tmp_name.size());
153
154 // Read "<type>" or "list(<type>)".
155 bool is_list = ConsumeListPrefix(&spec);
156 string type;
157 StringPiece type_string; // Used if type == "type"
158 if (spec.Consume("string")) {
159 type = "string";
160 } else if (spec.Consume("int")) {
161 type = "int";
162 } else if (spec.Consume("float")) {
163 type = "float";
164 } else if (spec.Consume("bool")) {
165 type = "bool";
166 } else if (spec.Consume("type")) {
167 type = "type";
168 } else if (spec.Consume("shape")) {
169 type = "shape";
170 } else if (spec.Consume("tensor")) {
171 type = "tensor";
172 } else if (spec.Consume("func")) {
173 type = "func";
174 } else if (ConsumeCompoundAttrType(&spec, &type_string)) {
175 type = "type";
176 AttrValue* allowed = attr->mutable_allowed_values();
177 VERIFY(ProcessCompoundType(type_string, allowed),
178 "Expected to see a compound type, saw: ", type_string);
179 } else if (spec.Consume("{")) {
180 // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
181 AttrValue* allowed = attr->mutable_allowed_values();
182 str_util::RemoveLeadingWhitespace(&spec);
183 if (spec.starts_with("\"") || spec.starts_with("'")) {
184 type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
185 while (true) {
186 StringPiece escaped_string;
187 VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
188 ConsumeQuotedString('\'', &spec, &escaped_string),
189 "Trouble parsing allowed string at '", spec, "'");
190 string unescaped;
191 string error;
192 VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
193 "Trouble unescaping \"", escaped_string,
194 "\", got error: ", error);
195 allowed->mutable_list()->add_s(unescaped);
196 if (spec.Consume(",")) {
197 str_util::RemoveLeadingWhitespace(&spec);
198 if (spec.Consume("}")) break; // Allow ending with ", }".
199 } else {
200 VERIFY(spec.Consume("}"),
201 "Expected , or } after strings in list, not: '", spec, "'");
202 break;
203 }
204 }
205 } else { // "{ bool, numbertype, string }"
206 type = "type";
207 while (true) {
208 VERIFY(ConsumeAttrType(&spec, &type_string),
209 "Trouble parsing type string at '", spec, "'");
210 if (ProcessCompoundType(type_string, allowed)) {
211 // Processed a compound type.
212 } else {
213 DataType dt;
214 VERIFY(DataTypeFromString(type_string, &dt),
215 "Unrecognized type string '", type_string, "'");
216 allowed->mutable_list()->add_type(dt);
217 }
218 if (spec.Consume(",")) {
219 str_util::RemoveLeadingWhitespace(&spec);
220 if (spec.Consume("}")) break; // Allow ending with ", }".
221 } else {
222 VERIFY(spec.Consume("}"),
223 "Expected , or } after types in list, not: '", spec, "'");
224 break;
225 }
226 }
227 }
228 } else { // if spec.Consume("{")
229 VERIFY(false, "Trouble parsing type string at '", spec, "'");
230 }
231 str_util::RemoveLeadingWhitespace(&spec);
232
233 // Write the type into *attr.
234 if (is_list) {
235 VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'");
236 str_util::RemoveLeadingWhitespace(&spec);
237 attr->set_type(strings::StrCat("list(", type, ")"));
238 } else {
239 attr->set_type(type);
240 }
241
242 // Read optional minimum constraint at the end.
243 if ((is_list || type == "int") && spec.Consume(">=")) {
244 int64 min_limit = -999;
245 VERIFY(ConsumeAttrNumber(&spec, &min_limit),
246 "Could not parse integer lower limit after '>=', found '", spec,
247 "' instead");
248 attr->set_has_minimum(true);
249 attr->set_minimum(min_limit);
250 }
251
252 // Parse default value, if present.
253 if (spec.Consume("=")) {
254 str_util::RemoveLeadingWhitespace(&spec);
255 VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
256 "Could not parse default value '", spec, "'");
257 } else {
258 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
259 }
260 }
261
262 #undef VERIFY
263
// Builds the context suffix appended to Input()/Output() parse errors.
string InOutError(bool is_output, StringPiece orig, const string& op_name) {
  const char* const kind = is_output ? "Output" : "Input";
  return strings::StrCat(" from ", kind, "(\"", orig, "\") for Op ", op_name);
}
268
ConsumeInOutName(StringPiece * sp,StringPiece * out)269 bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
270 return Scanner(*sp)
271 .One(Scanner::LOWERLETTER)
272 .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
273 .StopCapture()
274 .AnySpace()
275 .OneLiteral(":")
276 .AnySpace()
277 .GetResult(sp, out);
278 }
279
ConsumeInOutRefOpen(StringPiece * sp)280 bool ConsumeInOutRefOpen(StringPiece* sp) {
281 return Scanner(*sp)
282 .OneLiteral("Ref")
283 .AnySpace()
284 .OneLiteral("(")
285 .AnySpace()
286 .GetResult(sp);
287 }
288
ConsumeInOutRefClose(StringPiece * sp)289 bool ConsumeInOutRefClose(StringPiece* sp) {
290 return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
291 }
292
ConsumeInOutNameOrType(StringPiece * sp,StringPiece * out)293 bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
294 return Scanner(*sp)
295 .One(Scanner::LETTER)
296 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
297 .StopCapture()
298 .AnySpace()
299 .GetResult(sp, out);
300 }
301
ConsumeInOutTimesType(StringPiece * sp,StringPiece * out)302 bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
303 return Scanner(*sp)
304 .OneLiteral("*")
305 .AnySpace()
306 .RestartCapture()
307 .One(Scanner::LETTER)
308 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
309 .StopCapture()
310 .AnySpace()
311 .GetResult(sp, out);
312 }
313
// Parse guard for FinalizeInputOrOutput. When `condition` is false it appends
// an error message (the remaining arguments followed by InOutError(...)
// context) to *errors and returns from the enclosing void function. The
// expansion site must have `errors`, `is_output`, `orig` and `op_def` in scope.
#define VERIFY(condition, ...)                                      \
  do {                                                              \
    if (condition) break;                                           \
    errors->push_back(strings::StrCat(                              \
        __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \
    return;                                                         \
  } while (false)
322
FinalizeInputOrOutput(StringPiece spec,bool is_output,OpDef * op_def,std::vector<string> * errors)323 void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
324 std::vector<string>* errors) {
325 OpDef::ArgDef* arg =
326 is_output ? op_def->add_output_arg() : op_def->add_input_arg();
327
328 StringPiece orig(spec);
329
330 // Parse "<name>:" at the beginning.
331 StringPiece tmp_name;
332 VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
333 arg->set_name(tmp_name.data(), tmp_name.size());
334
335 // Detect "Ref(...)".
336 if (ConsumeInOutRefOpen(&spec)) {
337 arg->set_is_ref(true);
338 }
339
340 { // Parse "<name|type>" or "<name>*<name|type>".
341 StringPiece first, second, type_or_attr;
342 VERIFY(ConsumeInOutNameOrType(&spec, &first),
343 "Trouble parsing either a type or an attr name at '", spec, "'");
344 if (ConsumeInOutTimesType(&spec, &second)) {
345 arg->set_number_attr(first.data(), first.size());
346 type_or_attr = second;
347 } else {
348 type_or_attr = first;
349 }
350 DataType dt;
351 if (DataTypeFromString(type_or_attr, &dt)) {
352 arg->set_type(dt);
353 } else {
354 const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
355 VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
356 if (attr->type() == "type") {
357 arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
358 } else {
359 VERIFY(attr->type() == "list(type)", "Reference to attr '",
360 type_or_attr, "' with type ", attr->type(),
361 " that isn't type or list(type)");
362 arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
363 }
364 }
365 }
366
367 // Closing ) for Ref(.
368 if (arg->is_ref()) {
369 VERIFY(ConsumeInOutRefClose(&spec),
370 "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
371 }
372
373 // Should not have anything else.
374 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
375
376 // Int attrs that are the length of an input or output get a default
377 // minimum of 1.
378 if (!arg->number_attr().empty()) {
379 OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
380 if (attr != nullptr && !attr->has_minimum()) {
381 attr->set_has_minimum(true);
382 attr->set_minimum(1);
383 }
384 } else if (!arg->type_list_attr().empty()) {
385 // If an input or output has type specified by a list(type) attr,
386 // it gets a default minimum of 1 as well.
387 OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
388 if (attr != nullptr && attr->type() == "list(type)" &&
389 !attr->has_minimum()) {
390 attr->set_has_minimum(true);
391 attr->set_minimum(1);
392 }
393 }
394
395 // If the arg's dtype is resource we should mark the op as stateful as it
396 // likely touches a resource manager. This deliberately doesn't cover inputs /
397 // outputs which resolve to resource via Attrs as those mostly operate on
398 // resource handles as an opaque type (as opposed to ops which explicitly take
399 // / produce resources).
400 if (arg->type() == DT_RESOURCE) {
401 op_def->set_is_stateful(true);
402 }
403 }
404
405 #undef VERIFY
406
num_leading_spaces(StringPiece s)407 int num_leading_spaces(StringPiece s) {
408 size_t i = 0;
409 while (i < s.size() && s[i] == ' ') {
410 ++i;
411 }
412 return i;
413 }
414
ConsumeDocNameColon(StringPiece * sp,StringPiece * out)415 bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
416 return Scanner(*sp)
417 .One(Scanner::LETTER)
418 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
419 .StopCapture()
420 .AnySpace()
421 .OneLiteral(":")
422 .AnySpace()
423 .GetResult(sp, out);
424 }
425
IsDocNameColon(StringPiece s)426 bool IsDocNameColon(StringPiece s) {
427 return ConsumeDocNameColon(&s, nullptr /* out */);
428 }
429
FinalizeDoc(const string & text,OpDef * op_def,std::vector<string> * errors)430 void FinalizeDoc(const string& text, OpDef* op_def,
431 std::vector<string>* errors) {
432 std::vector<string> lines = str_util::Split(text, '\n');
433
434 // Remove trailing spaces.
435 for (string& line : lines) {
436 str_util::StripTrailingWhitespace(&line);
437 }
438
439 // First non-blank line -> summary.
440 int l = 0;
441 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
442 if (static_cast<size_t>(l) < lines.size()) {
443 op_def->set_summary(lines[l]);
444 ++l;
445 }
446 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
447
448 // Lines until we see name: -> description.
449 int start_l = l;
450 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
451 ++l;
452 }
453 int end_l = l;
454 // Trim trailing blank lines from the description.
455 while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
456 string desc = str_util::Join(
457 gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
458 if (!desc.empty()) op_def->set_description(desc);
459
460 // name: description
461 // possibly continued on the next line
462 // if so, we remove the minimum indent
463 StringPiece name;
464 std::vector<StringPiece> description;
465 while (static_cast<size_t>(l) < lines.size()) {
466 description.clear();
467 description.push_back(lines[l]);
468 ConsumeDocNameColon(&description.back(), &name);
469 ++l;
470 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
471 description.push_back(lines[l]);
472 ++l;
473 }
474 // Remove any trailing blank lines.
475 while (!description.empty() && description.back().empty()) {
476 description.pop_back();
477 }
478 // Compute the minimum indent of all lines after the first.
479 int min_indent = -1;
480 for (size_t i = 1; i < description.size(); ++i) {
481 if (!description[i].empty()) {
482 int indent = num_leading_spaces(description[i]);
483 if (min_indent < 0 || indent < min_indent) min_indent = indent;
484 }
485 }
486 // Remove min_indent spaces from all lines after the first.
487 for (size_t i = 1; i < description.size(); ++i) {
488 if (!description[i].empty()) description[i].remove_prefix(min_indent);
489 }
490 // Concatenate lines into a single string.
491 const string complete(str_util::Join(description, "\n"));
492
493 // Find name.
494 bool found = false;
495 for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
496 if (op_def->input_arg(i).name() == name) {
497 op_def->mutable_input_arg(i)->set_description(complete);
498 found = true;
499 }
500 }
501 for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
502 if (op_def->output_arg(i).name() == name) {
503 op_def->mutable_output_arg(i)->set_description(complete);
504 found = true;
505 }
506 }
507 for (int i = 0; !found && i < op_def->attr_size(); ++i) {
508 if (op_def->attr(i).name() == name) {
509 op_def->mutable_attr(i)->set_description(complete);
510 found = true;
511 }
512 }
513 if (!found) {
514 errors->push_back(
515 strings::StrCat("No matching input/output/attr for name '", name,
516 "' from Doc() for Op ", op_def->name()));
517 return;
518 }
519 }
520 }
521
522 } // namespace
523
OpDefBuilder(StringPiece op_name)524 OpDefBuilder::OpDefBuilder(StringPiece op_name) {
525 op_def()->set_name(op_name.ToString()); // NOLINT
526 }
527
Attr(StringPiece spec)528 OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
529 attrs_.emplace_back(spec.data(), spec.size());
530 return *this;
531 }
532
Input(StringPiece spec)533 OpDefBuilder& OpDefBuilder::Input(StringPiece spec) {
534 inputs_.emplace_back(spec.data(), spec.size());
535 return *this;
536 }
537
Output(StringPiece spec)538 OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
539 outputs_.emplace_back(spec.data(), spec.size());
540 return *this;
541 }
542
543 #ifndef TF_LEAN_BINARY
Doc(StringPiece text)544 OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
545 if (!doc_.empty()) {
546 errors_.push_back(
547 strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
548 } else {
549 doc_.assign(text.data(), text.size());
550 }
551 return *this;
552 }
553 #endif
554
SetIsCommutative()555 OpDefBuilder& OpDefBuilder::SetIsCommutative() {
556 op_def()->set_is_commutative(true);
557 return *this;
558 }
559
SetIsAggregate()560 OpDefBuilder& OpDefBuilder::SetIsAggregate() {
561 op_def()->set_is_aggregate(true);
562 return *this;
563 }
564
SetIsStateful()565 OpDefBuilder& OpDefBuilder::SetIsStateful() {
566 op_def()->set_is_stateful(true);
567 return *this;
568 }
569
SetAllowsUninitializedInput()570 OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
571 op_def()->set_allows_uninitialized_input(true);
572 return *this;
573 }
574
Deprecated(int version,StringPiece explanation)575 OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
576 if (op_def()->has_deprecation()) {
577 errors_.push_back(
578 strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
579 } else {
580 OpDeprecation* deprecation = op_def()->mutable_deprecation();
581 deprecation->set_version(version);
582 deprecation->set_explanation(explanation.ToString());
583 }
584 return *this;
585 }
586
SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))587 OpDefBuilder& OpDefBuilder::SetShapeFn(
588 Status (*fn)(shape_inference::InferenceContext*)) {
589 if (op_reg_data_.shape_inference_fn != nullptr) {
590 errors_.push_back(
591 strings::StrCat("SetShapeFn called twice for Op ", op_def()->name()));
592 } else {
593 op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn);
594 }
595 return *this;
596 }
597
Finalize(OpRegistrationData * op_reg_data) const598 Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
599 std::vector<string> errors = errors_;
600 *op_reg_data = op_reg_data_;
601
602 OpDef* op_def = &op_reg_data->op_def;
603 for (StringPiece attr : attrs_) {
604 FinalizeAttr(attr, op_def, &errors);
605 }
606 for (StringPiece input : inputs_) {
607 FinalizeInputOrOutput(input, false, op_def, &errors);
608 }
609 for (StringPiece output : outputs_) {
610 FinalizeInputOrOutput(output, true, op_def, &errors);
611 }
612 FinalizeDoc(doc_, op_def, &errors);
613
614 if (errors.empty()) return Status::OK();
615 return errors::InvalidArgument(str_util::Join(errors, "\n"));
616 }
617
618 } // namespace tensorflow
619