1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17// Package internal generates Go source code with functions for TensorFlow operations.
18//
19// The basic outline of the generated API is as follows:
20//
21// - One function for each TensorFlow operation
22// - The arguments to the function are the inputs and required attributes of the operation
23// - The function returns the outputs
24// - A function is also generated for each optional attribute of the operation.
25//
26// There is a possibility that there are name collisions between the functions
27// generated for ops and the functions generated for optional attributes. For
28// now, we ignore those, but will need to revisit if a collision is actually
29// encountered.
30package internal
31
32/*
33#include <stdlib.h>
34
35#include "tensorflow/c/c_api.h"
36*/
37import "C"
38
39import (
40	"fmt"
41	"io"
42	"io/ioutil"
43	"path"
44	"reflect"
45	"strings"
46	"text/template"
47	"unsafe"
48
49	"github.com/golang/protobuf/proto"
50	pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
51)
52
53// GenerateFunctionsForRegisteredOps writes a Go source code file to w
54// containing functions for each TensorFlow operation registered in the address
55// space of the calling process.
56// apidefDirs should be a contain of directories containing api_def_*.pbtxt
57// files to load.
58func GenerateFunctionsForRegisteredOps(
59	w io.Writer, apidefDirs []string) error {
60	ops, apimap, err := registeredOps()
61	if err != nil {
62		return err
63	}
64	for _, dir := range apidefDirs {
65		if err = updateAPIDefs(apimap, dir); err != nil {
66			return err
67		}
68	}
69	return generateFunctionsForOps(w, ops, apimap)
70}
71
72func registeredOps() (*pb.OpList, *apiDefMap, error) {
73	buf := C.TF_GetAllOpList()
74	defer C.TF_DeleteBuffer(buf)
75	var (
76		list = new(pb.OpList)
77		size = int(buf.length)
78		// A []byte backed by C memory.
79		// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
80		data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size]
81		err  = proto.Unmarshal(data, list)
82	)
83	if err != nil {
84		return nil, nil, err
85	}
86	apimap, err := newAPIDefMap(list)
87	return list, apimap, err
88}
89
90func updateAPIDefs(m *apiDefMap, dir string) error {
91	files, err := ioutil.ReadDir(dir)
92	if err != nil {
93		return err
94	}
95	for _, file := range files {
96		data, err := ioutil.ReadFile(path.Join(dir, file.Name()))
97		if err != nil {
98			return fmt.Errorf("failed to read %q: %v", file.Name(), err)
99		}
100		if err = m.Put(string(data)); err != nil {
101			return fmt.Errorf("failed to process %q: %v", file.Name(), err)
102		}
103	}
104	return nil
105}
106
107func generateFunctionsForOps(w io.Writer, ops *pb.OpList, apimap *apiDefMap) error {
108	thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
109	if err := tmplHeader.Execute(w, thisPackage); err != nil {
110		return err
111	}
112	blacklist := map[string]bool{
113		"Const":           true,
114		"PyFunc":          true,
115		"PyFuncStateless": true,
116	}
117	for _, op := range ops.Op {
118		if blacklist[op.Name] {
119			continue
120		}
121		apidef, err := apimap.Get(op.Name)
122		if err != nil {
123			return err
124		}
125		if err := generateFunctionForOp(w, op, apidef); err != nil {
126			return err
127		}
128	}
129	return nil
130}
131
132func generateFunctionForOp(w io.Writer, op *pb.OpDef, apidef *pb.ApiDef) error {
133	if strings.HasPrefix(op.Name, "_") { // Internal operation
134		return nil
135	}
136	// Ignore operations where the Go types corresponding to the TensorFlow
137	// type haven't been worked out (such as "func"s).
138	for _, a := range op.Attr {
139		if _, err := goType(a.Type); err != nil {
140			return nil
141		}
142	}
143	// Also, haven't figured out reference types yet, so ignore those too.
144	for _, a := range op.InputArg {
145		if a.IsRef {
146			return nil
147		}
148	}
149	for _, a := range op.OutputArg {
150		if a.IsRef {
151			return nil
152		}
153	}
154	if apidef.Summary == "" {
155		// Undocumented operation, perhaps a sign of not being ready to
156		// export.
157		return nil
158	}
159	tmplArgs, err := newTmplArgs(op, apidef)
160	if err != nil {
161		return err
162	}
163	return tmplOp.Execute(w, tmplArgs)
164}
165
166var (
167	// Go keywords that cannot be used as identifiers.
168	// From https://golang.org/ref/spec#Keywords
169	keywords = []string{
170		"break", "default", "func", "interface", "select", "case",
171		"defer", "go", "map", "struct", "chan", "else", "goto",
172		"package", "switch", "const", "fallthrough", "if", "range",
173		"type", "continue", "for", "import", "return", "var",
174	}
175
	// tmplHeader renders the preamble of the generated file: a DO NOT EDIT
	// banner naming the generator (the template's data, printed via {{.}}),
	// the "package op" clause, the import of the TensorFlow Go package as
	// tf, the un-exported optionalAttr type and the makeOutputList helper
	// that the per-op functions rely on for list-valued outputs.
	tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT
// This file was machine generated by {{.}}
//
// WARNING: This generation of wrapper function for TensorFlow ops is in an
// experimental state. The generated API can change without notice.

package op

import tf "github.com/tensorflow/tensorflow/tensorflow/go"

// optionalAttr is an intentionally un-exported type to hide
// details of how optional attributes to operations are implemented.
type optionalAttr map[string]interface{}

func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) {
	size, err := op.OutputListSize(output)
	if err != nil {
		return nil, start, err
	}
	list := make([]tf.Output, size)
	for i := 0; i < size; i++ {
		list[i] = op.Output(start + i)
	}
	return list, start + size, nil
}
`))
202
	// tmplOp renders the generated code for a single operation. Its data
	// is expected to provide the fields and methods of tmplArgs (Op,
	// APIDef, RequiredAttrs, OptionalAttrs, InArgs, InArgsReordered,
	// OutArgs, HasAttrs, DescribeArguments, DescribeOutputs,
	// HasListOutput).
	//
	// The output consists of, in order:
	//   - when the op has optional attributes, a "<Op>Attr" function type
	//     and one setter function per optional attribute,
	//   - a doc comment built from the ApiDef summary and description, the
	//     deprecation notice and the argument/output descriptions,
	//   - the wrapper function itself, taking the scope, the inputs (in
	//     InArgsReordered order), the required attributes and, if any, a
	//     variadic list of optional attributes, and returning either the
	//     outputs or the created *tf.Operation.
	//
	// The helpers registered in the FuncMap are available to the template
	// under the names given as map keys.
	tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{
		"MakeComment":       makeComment,
		"GoType":            goType,
		"CamelCase":         camelCase,
		"Identifier":        identifier,
		"IsListArg":         isListArg,
		"IsListAttr":        isListAttr,
		"StripLeadingColon": stripLeadingColon,
	}).Parse(`
{{if .OptionalAttrs -}}
{{/* Type for specifying all optional attributes. */ -}}
// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}.
type {{.Op.Name}}Attr func(optionalAttr)

{{range .OptionalAttrs}}
// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value.
{{- if .Description}}
//
// value: {{MakeComment .Description}}
{{- end}}
// If not specified, defaults to {{StripLeadingColon .DefaultValue}}
{{- if .HasMinimum}}
//
// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
{{- end}}
func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
	return func(m optionalAttr) {
		m[{{printf "%q" .Name}}] = value
	}
}
{{end}}
{{end}}

{{- /* Create a godoc friendly comment. */ -}}

// {{MakeComment .APIDef.Summary}}

{{- with .Op.Deprecation}}
//
// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
{{- end -}}

{{- with .APIDef.Description}}
//
// {{MakeComment .}}
{{- end -}}

{{- if .DescribeArguments}}
//
// Arguments:
{{- range .InArgsReordered}}
//	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- range .RequiredAttrs}}
//	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- end -}}

{{- if (not .Op.OutputArg) }}
//
// Returns the created operation.
{{- else }}
{{- if .DescribeOutputs}}
//
{{- if ((len .OutArgs) eq 1) }}
// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}}
{{- else }}
// Returns:
{{- range .OutArgs}}
//	{{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- end -}}
{{- end -}}
{{- end -}}
{{- /*

  The function signature.
  Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang
*/}}
func {{.Op.Name}}

{{- /*
  Fill in input arguments:
  (1) The Scope
  (2) All input arguments (which may be either []tf.Output or tf.Output)
  (3) All required attributes
  (4) Variadic list of optional attributes
*/ -}}

(scope *Scope
{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}
{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}}
{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
)

{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}}

{{if .OutArgs -}}
({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}})
{{- else -}}
(o *tf.Operation)
{{- end }} {
	if scope.Err() != nil {
		return
	}
	{{if .HasAttrs -}}
	attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}}
	{{if .OptionalAttrs -}}
	for _, a := range optional {
		a(attrs)
	}
	{{end -}}
	{{end -}}
	opspec := tf.OpSpec{
		Type: {{printf "%q" .Op.Name}},
		{{if .InArgs -}}
		Input: []tf.Input{
			{{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}}
		},
		{{- end}}
		{{- if .HasAttrs}}
		Attrs: attrs,
		{{- end}}
	}
	{{- if .OutArgs}}
	{{- if .HasListOutput}}
	op := scope.AddOperation(opspec)
	if scope.Err() != nil {
		return
	}
	var idx int
	var err error
	{{- range $i, $a := .OutArgs}}
	{{- if $a.IsListArg}}
	if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
		scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
		return
	}
	{{- else }}
	{{Identifier .RenameTo}} = op.Output(idx)
	{{- end }}{{- /* if IsListArg */}}
	{{- end }}{{- /* range .OutArgs */}}
	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}}
	{{- else }}
	op := scope.AddOperation(opspec)
	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
	{{- end }}{{- /* if .HasListOutput */}}
	{{- else }}
	return scope.AddOperation(opspec)
	{{- end }}{{- /* if .OutArgs */}}
}
`))
355)
356
357type attrWrapper struct {
358	op  *pb.OpDef_AttrDef
359	api *pb.ApiDef_Attr
360}
361
362func (a *attrWrapper) Name() string              { return a.api.Name }
363func (a *attrWrapper) RenameTo() string          { return a.api.RenameTo }
364func (a *attrWrapper) Description() string       { return a.api.Description }
365func (a *attrWrapper) Type() string              { return a.op.Type }
366func (a *attrWrapper) IsListAttr() bool          { return isListAttr(a.op) }
367func (a *attrWrapper) HasMinimum() bool          { return a.op.HasMinimum }
368func (a *attrWrapper) Minimum() int64            { return a.op.Minimum }
369func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue }
370
371type argWrapper struct {
372	op  *pb.OpDef_ArgDef
373	api *pb.ApiDef_Arg
374}
375
376func (a *argWrapper) Name() string        { return a.api.Name }
377func (a *argWrapper) RenameTo() string    { return a.api.RenameTo }
378func (a *argWrapper) Description() string { return a.api.Description }
379func (a *argWrapper) IsListArg() bool     { return isListArg(a.op) }
380
// tmplArgs is the data handed to tmplOp when rendering the code for one
// operation. It carries the raw OpDef and ApiDef plus the attribute and
// argument wrappers that newTmplArgs derives from them. InArgs follows the
// order of Op.InputArg and OutArgs follows the order of Op.OutputArg.
type tmplArgs struct {
	Op     *pb.OpDef
	APIDef *pb.ApiDef
	// Op.Attr is split into two categories
	// (1) Required: These must be specified by the client and are thus
	//     included in the function signature.
	// (2) Optional: These need not be specified (as they have default
	//     values) and thus do not appear in the function signature.
	// Attributes inferred from the inputs appear in neither list.
	RequiredAttrs []*attrWrapper
	OptionalAttrs []*attrWrapper
	InArgs        []*argWrapper
	// Input arguments ordered based on arg_order field of ApiDef.
	InArgsReordered []*argWrapper
	OutArgs         []*argWrapper
}
396
397func newTmplArgs(op *pb.OpDef, apidef *pb.ApiDef) (*tmplArgs, error) {
398	ret := tmplArgs{Op: op, APIDef: apidef}
399
400	// Setup InArgs field
401	for i, in := range op.InputArg {
402		argCombined := argWrapper{op: in, api: apidef.InArg[i]}
403		ret.InArgs = append(ret.InArgs, &argCombined)
404	}
405
406	// Setup OutArgs field
407	for i, out := range op.OutputArg {
408		argCombined := argWrapper{op: out, api: apidef.OutArg[i]}
409		ret.OutArgs = append(ret.OutArgs, &argCombined)
410	}
411
412	// Setup InArgsReordered field
413	for _, argName := range apidef.ArgOrder {
414		// Find the argument in op.InputArg
415		argIndex := -1
416		for i, in := range op.InputArg {
417			if in.Name == argName {
418				argIndex = i
419				break
420			}
421		}
422		if argIndex == -1 {
423			return nil, fmt.Errorf(
424				"couldn't find argument %s in ApiDef for op %s",
425				argName, op.Name)
426		}
427		argCombined := argWrapper{
428			op: op.InputArg[argIndex], api: apidef.InArg[argIndex]}
429		ret.InArgsReordered = append(ret.InArgsReordered, &argCombined)
430	}
431
432	if len(op.Attr) == 0 {
433		return &ret, nil
434	}
435	// Attributes related to the InputArg's type are inferred automatically
436	// and are not exposed to the client.
437	inferred := make(map[string]bool)
438	for _, in := range op.InputArg {
439		switch {
440		case in.TypeAttr != "":
441			inferred[in.TypeAttr] = true
442		case in.TypeListAttr != "":
443			inferred[in.TypeListAttr] = true
444		}
445		if in.NumberAttr != "" {
446			inferred[in.NumberAttr] = true
447		}
448	}
449	for i, attr := range op.Attr {
450		if inferred[attr.Name] {
451			continue
452		}
453		attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]}
454		if attr.DefaultValue == nil {
455			ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined)
456		} else {
457			ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined)
458		}
459	}
460	return &ret, nil
461}
462
463func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
464func (a *tmplArgs) DescribeArguments() bool {
465	for _, arg := range a.InArgs {
466		if arg.Description() != "" {
467			return true
468		}
469	}
470	for _, attr := range a.RequiredAttrs {
471		if attr.Description() != "" {
472			return true
473		}
474	}
475	return false
476
477}
478func (a *tmplArgs) DescribeOutputs() bool {
479	for _, arg := range a.OutArgs {
480		if arg.Description() != "" {
481			return true
482		}
483	}
484	return false
485}
486func (a *tmplArgs) HasListOutput() bool {
487	for _, arg := range a.OutArgs {
488		if arg.IsListArg() {
489			return true
490		}
491	}
492	return false
493}
494
// makeComment prepares multi-line text for use after a leading "// " in the
// generated code by inserting "// " after every newline of lines. The first
// line gets no marker; the caller is expected to supply it.
func makeComment(lines string) string {
	const marker = "// "
	return strings.Replace(lines, "\n", "\n"+marker, -1)
}
498
499// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.)
500// to the corresponding type in Go.
501func goType(tfType string) (string, error) {
502	list, tfType := parseTFType(tfType)
503	var gotype string
504	switch tfType {
505	case "int":
506		gotype = "int64"
507	case "float":
508		gotype = "float32"
509	case "bool":
510		gotype = "bool"
511	case "type":
512		gotype = "tf.DataType"
513	case "shape":
514		gotype = "tf.Shape"
515	case "tensor":
516		gotype = "tf.Tensor"
517	case "string":
518		gotype = "string"
519	default:
520		return "", fmt.Errorf("%q is not a recognized DataType", tfType)
521	}
522	if list {
523		gotype = "[]" + gotype
524	}
525	return gotype, nil
526}
527
// camelCase converts a snake_case name to CamelCase: the name is split on
// "_", the first byte of every segment is upper-cased and the segments are
// concatenated.
//
// Empty segments, which arise from an empty input or from leading, trailing
// or consecutive underscores, contribute nothing to the result. Names are
// assumed to be ASCII; only the first byte of each segment is upper-cased.
func camelCase(snakeCase string) string {
	words := strings.Split(snakeCase, "_")
	for i, w := range words {
		if w == "" {
			// Nothing to capitalize; indexing w[0] here would panic.
			continue
		}
		words[i] = strings.ToUpper(string(w[0])) + w[1:]
	}
	return strings.Join(words, "")
}
535
536// identifier creates an identifier for s usable in the generated Go source
537// code.
538//
539// Avoids collisions with keywords and other identifiers used in the generated
540// code.
541func identifier(s string) string {
542	// Identifiers used in the generated code.
543	if s == "tf" || s == "scope" || s == "err" || s == "op" {
544		return s + "_"
545	}
546	for _, k := range keywords {
547		if s == k {
548			// Alternatively, make the first letter upper case.
549			return s + "_"
550		}
551	}
552	return s
553}
554
555func isListArg(argdef *pb.OpDef_ArgDef) bool {
556	return argdef.TypeListAttr != "" || argdef.NumberAttr != ""
557}
558
559func isListAttr(attrdef *pb.OpDef_AttrDef) bool {
560	list, _ := parseTFType(attrdef.Type)
561	return list
562}
563
// stripLeadingColon returns the text of s.String() that follows its first
// colon, or the whole text when it contains no colon.
//
// This is useful when 's' corresponds to a "oneof" protocol buffer message.
// For example, consider the protocol buffer message:
//   oneof value { bool b = 1;  int64 i = 2; }
// String() on a Go corresponding object (using proto.CompactTextString) will
// print "b:true", or "i:7" etc. This function strips out the leading "b:" or
// "i:".
func stripLeadingColon(s fmt.Stringer) string {
	text := s.String()
	colon := strings.Index(text, ":")
	if colon < 0 {
		return text
	}
	return text[colon+1:]
}
580
// parseTFType splits a TensorFlow attribute type string into its parts. For
// a type of the form "list(X)" it returns (true, "X"); any other string is
// returned unchanged with list set to false.
func parseTFType(tfType string) (list bool, typ string) {
	const (
		opening = "list("
		closing = ")"
	)
	if !strings.HasPrefix(tfType, opening) || !strings.HasSuffix(tfType, closing) {
		return false, tfType
	}
	// Prefix and suffix cannot overlap, so the slice bounds are in range.
	return true, tfType[len(opening) : len(tfType)-len(closing)]
}
591