1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package tensorflow
18
19// #include <stdlib.h>
20// #include "tensorflow/c/c_api.h"
21import "C"
22
23import "unsafe"
24
// Operation represents an operation that has been added to a Graph.
//
// The zero value is not usable: c must point at a live TF_Operation owned by
// the Graph referenced by g.
type Operation struct {
	// c is the underlying C handle for this operation.
	c *C.TF_Operation
	// A reference to the Graph to prevent it from
	// being GCed while the Operation is still alive.
	g *Graph
}
32
33// Name returns the name of the operation.
34func (op *Operation) Name() string {
35	return C.GoString(C.TF_OperationName(op.c))
36}
37
38// Type returns the name of the operator used by this operation.
39func (op *Operation) Type() string {
40	return C.GoString(C.TF_OperationOpType(op.c))
41}
42
43// NumOutputs returns the number of outputs of op.
44func (op *Operation) NumOutputs() int {
45	return int(C.TF_OperationNumOutputs(op.c))
46}
47
48// OutputListSize returns the size of the list of Outputs that is produced by a
49// named output of op.
50//
51// An Operation has multiple named outputs, each of which produces either
52// a single tensor or a list of tensors. This method returns the size of
53// the list of tensors for a specific output of the operation, identified
54// by its name.
55func (op *Operation) OutputListSize(output string) (int, error) {
56	cname := C.CString(output)
57	defer C.free(unsafe.Pointer(cname))
58	status := newStatus()
59	n := C.TF_OperationOutputListLength(op.c, cname, status.c)
60	return int(n), status.Err()
61}
62
63// Output returns the i-th output of op.
64func (op *Operation) Output(i int) Output {
65	return Output{op, i}
66}
67
// Output represents one of the outputs of an operation in the graph. Each
// Output has a DataType and a (possibly incomplete) Shape. It may be passed as
// an input argument to a function for adding operations to a graph, or to a
// Session's Run() method to fetch that output as a tensor.
//
// Methods that consult the C API (such as DataType and Shape) panic if Op is
// nil.
type Output struct {
	// Op is the Operation that produces this Output.
	Op *Operation

	// Index specifies the index of the output within the Operation.
	Index int
}
79
80// DataType returns the type of elements in the tensor produced by p.
81func (p Output) DataType() DataType {
82	return DataType(C.TF_OperationOutputType(p.c()))
83}
84
85// Shape returns the (possibly incomplete) shape of the tensor produced p.
86func (p Output) Shape() Shape {
87	status := newStatus()
88	port := p.c()
89	ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c)
90	if err := status.Err(); err != nil {
91		// This should not be possible since an error only occurs if
92		// the operation does not belong to the graph.  It should not
93		// be possible to construct such an Operation object.
94		return Shape{}
95	}
96	if ndims < 0 {
97		return Shape{}
98	}
99	if ndims == 0 {
100		return ScalarShape()
101	}
102	dims := make([]C.int64_t, ndims)
103	C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c)
104	if err := status.Err(); err != nil {
105		// Same as above, should not be possible.
106		return Shape{}
107	}
108	ret := Shape{dims: make([]int64, ndims)}
109	for i := 0; i < int(ndims); i++ {
110		ret.dims[i] = int64(dims[i])
111	}
112	return ret
113}
114
115func (p Output) c() C.TF_Output {
116	if p.Op == nil {
117		// Attempt to provide a more useful panic message than "nil
118		// pointer dereference".
119		panic("nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.")
120	}
121	return C.TF_Output{oper: p.Op.c, index: C.int(p.Index)}
122}
123
124func (p Output) canBeAnInput() {}
125
// Input is the interface for specifying inputs to an operation being added to
// a Graph.
//
// Operations can have multiple inputs, each of which could be either a tensor
// produced by another operation (an Output object), or a list of tensors
// produced by other operations (an OutputList). Thus, this interface is
// implemented by both Output and OutputList.
//
// See OpSpec.Input for more information.
type Input interface {
	// Unexported to preclude implementations outside this package.
	// It is a marker method: both in-package implementations have empty
	// bodies.
	canBeAnInput()
}
139
140// OutputList represents a list of Outputs that can be provided as input to
141// another operation.
142type OutputList []Output
143
144func (l OutputList) canBeAnInput() {}
145