1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package tensorflow
18
19// #include <stdlib.h>
20// #include "tensorflow/c/c_api.h"
21import "C"
22
23import "unsafe"
24
// Operation that has been added to the graph.
//
// An Operation wraps a C TF_Operation handle. Its fields are unexported; in
// this file, Operations are built by Operation.Output's callers and by
// Output.Consumers / Consumer.Producer, which copy the Graph reference from an
// existing Operation.
type Operation struct {
	// c is the underlying C handle passed to the TF_Operation* C API calls.
	c *C.TF_Operation
	// A reference to the Graph to prevent it from
	// being GCed while the Operation is still alive.
	g *Graph
}
32
33// Name returns the name of the operation.
34func (op *Operation) Name() string {
35	return C.GoString(C.TF_OperationName(op.c))
36}
37
38// Type returns the name of the operator used by this operation.
39func (op *Operation) Type() string {
40	return C.GoString(C.TF_OperationOpType(op.c))
41}
42
43// NumOutputs returns the number of outputs of op.
44func (op *Operation) NumOutputs() int {
45	return int(C.TF_OperationNumOutputs(op.c))
46}
47
48// Device returns a specification of the device on which this operation
49// will be executed, or the empty string if there is no such specification.
50func (op *Operation) Device() string {
51	return C.GoString(C.TF_OperationDevice(op.c))
52}
53
54// OutputListSize returns the size of the list of Outputs that is produced by a
55// named output of op.
56//
57// An Operation has multiple named outputs, each of which produces either
58// a single tensor or a list of tensors. This method returns the size of
59// the list of tensors for a specific output of the operation, identified
60// by its name.
61func (op *Operation) OutputListSize(output string) (int, error) {
62	cname := C.CString(output)
63	defer C.free(unsafe.Pointer(cname))
64	status := newStatus()
65	n := C.TF_OperationOutputListLength(op.c, cname, status.c)
66	return int(n), status.Err()
67}
68
69// Output returns the i-th output of op.
70func (op *Operation) Output(i int) Output {
71	return Output{op, i}
72}
73
74// NumInputs returns the number of inputs of op.
75func (op *Operation) NumInputs() int {
76	return int(C.TF_OperationNumInputs(op.c))
77}
78
// Output represents one of the outputs of an operation in the graph. It has a
// DataType and a (possibly incomplete) Shape. It may be passed as an input
// argument to a function for adding operations to a graph, or to a Session's
// Run() method to fetch that output as a tensor.
type Output struct {
	// Op is the Operation that produces this Output.
	Op *Operation

	// Index specifies the index of the output within the Operation.
	Index int
}
90
91// DataType returns the type of elements in the tensor produced by p.
92func (p Output) DataType() DataType {
93	return DataType(C.TF_OperationOutputType(p.c()))
94}
95
96// Shape returns the (possibly incomplete) shape of the tensor produced p.
97func (p Output) Shape() Shape {
98	status := newStatus()
99	port := p.c()
100	ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c)
101	if err := status.Err(); err != nil {
102		// This should not be possible since an error only occurs if
103		// the operation does not belong to the graph.  It should not
104		// be possible to construct such an Operation object.
105		return Shape{}
106	}
107	if ndims < 0 {
108		return Shape{}
109	}
110	if ndims == 0 {
111		return ScalarShape()
112	}
113	dims := make([]C.int64_t, ndims)
114	C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c)
115	if err := status.Err(); err != nil {
116		// Same as above, should not be possible.
117		return Shape{}
118	}
119	ret := Shape{dims: make([]int64, ndims)}
120	for i := 0; i < int(ndims); i++ {
121		ret.dims[i] = int64(dims[i])
122	}
123	return ret
124}
125
126func (p Output) c() C.TF_Output {
127	if p.Op == nil {
128		// Attempt to provide a more useful panic message than "nil
129		// pointer dereference".
130		panic("nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.")
131	}
132	return C.TF_Output{oper: p.Op.c, index: C.int(p.Index)}
133}
134
135func (p Output) canBeAnInput() {}
136
137// Consumers returns the inputs that consume this output.
138func (p Output) Consumers() []Consumer {
139	max := int(C.TF_OperationOutputNumConsumers(p.c()))
140	if max == 0 {
141		return nil
142	}
143	inputs := make([]C.TF_Input, max)
144	n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max))
145	inputs = inputs[:int(n)]
146
147	var consumers []Consumer
148	for _, consumer := range inputs {
149		consumers = append(consumers, Consumer{
150			Index: int(consumer.index),
151			Op: &Operation{
152				c: consumer.oper,
153				g: p.Op.g,
154			},
155		})
156	}
157
158	return consumers
159}
160
// Consumer identifies a specific input of an operation that consumes the output
// of another operation. Consumers are obtained from Output.Consumers; use
// Producer to walk back to the Output feeding the input.
type Consumer struct {
	// Op is the Operation that is consuming the output of another operation.
	Op *Operation

	// Index is the index of the input within Op that the output of another
	// operation is connected to.
	Index int
}
171
172func (p Consumer) c() C.TF_Input {
173	if p.Op == nil {
174		// Attempt to provide a more useful panic message than "nil
175		// pointer dereference".
176		panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers")
177	}
178	return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)}
179}
180
181// DataType returns the type of the input.
182func (p Consumer) DataType() DataType {
183	return DataType(C.TF_OperationInputType(p.c()))
184}
185
186// Producer returns the Output that is connected to this Consumer.
187func (p Consumer) Producer() Output {
188	output := C.TF_OperationInput(p.c())
189	return Output{
190		Op: &Operation{
191			c: output.oper,
192			g: p.Op.g,
193		},
194		Index: int(output.index),
195	}
196}
197
// Input is the interface for specifying inputs to an operation being added to
// a Graph.
//
// Operations can have multiple inputs, each of which could be either a tensor
// produced by another operation (an Output object), or a list of tensors
// produced by other operations (an OutputList). Thus, this interface is
// implemented by both Output and OutputList, and because its only method is
// unexported, by nothing outside this package.
//
// See OpSpec.Input for more information.
type Input interface {
	// Unexported to preclude implementations outside this package.
	canBeAnInput()
}
211
// OutputList represents a list of Outputs that can be provided as a single
// (list-valued) input to another operation. It implements Input.
type OutputList []Output
215
216func (l OutputList) canBeAnInput() {}
217