1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
17 
18 #include <array>
19 
20 #include "tensorflow/core/framework/shape_inference.h"
21 #include "tensorflow/core/util/padding.h"
22 #include "tensorflow/core/util/tensor_format.h"
23 
24 namespace tensorflow {
25 
26 namespace shape_inference {
27 
// Like GetWindowedOutputSize, but deals with DimensionHandles rather than plain
// integers. Does not support EXPLICIT padding; use
// GetWindowedOutputSizeFromDimsV2 when EXPLICIT padding is required.
//
// The computed output dimension is written to *output_size.
Status GetWindowedOutputSizeFromDims(InferenceContext* c,
                                     DimensionHandle input_size,
                                     DimensionOrConstant filter_size,
                                     int64 stride, Padding padding_type,
                                     DimensionHandle* output_size);

// The V2 version computes the same outputs with an arbitrary dilation_rate, and
// additionally supports EXPLICIT padding. For detailed equations, refer to the
// comments for GetWindowedOutputSizeV2(). The 'padding_before' and
// 'padding_after' parameters are only used if padding_type == EXPLICIT.
Status GetWindowedOutputSizeFromDimsV2(
    InferenceContext* c, DimensionHandle input_size,
    DimensionOrConstant filter_size, int64 dilation_rate, int64 stride,
    Padding padding_type, int64 padding_before, int64 padding_after,
    DimensionHandle* output_size);

// Transfers the shape of input(0) to output(0) unchanged.
Status UnchangedShape(shape_inference::InferenceContext* c);
48 
49 // Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
UnchangedShapeWithRank(shape_inference::InferenceContext * c,int32 rank)50 inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
51                                      int32 rank) {
52   ShapeHandle out;
53   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
54   c->set_output(0, out);
55   return Status::OK();
56 }
57 
58 // Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
UnchangedShapeWithRankAtLeast(shape_inference::InferenceContext * c,int32 rank)59 inline Status UnchangedShapeWithRankAtLeast(
60     shape_inference::InferenceContext* c, int32 rank) {
61   ShapeHandle out;
62   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
63   c->set_output(0, out);
64   return Status::OK();
65 }
66 
67 // Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
UnchangedShapeWithRankAtMost(shape_inference::InferenceContext * c,int32 rank)68 inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
69                                            int32 rank) {
70   ShapeHandle out;
71   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
72   c->set_output(0, out);
73   return Status::OK();
74 }
75 
76 // Shape function for use with ops no outputs.
NoOutputs(shape_inference::InferenceContext * c)77 inline Status NoOutputs(shape_inference::InferenceContext* c) {
78   return Status::OK();
79 }
80 
81 // Shape function for ops that output a single scalar value.
ScalarShape(shape_inference::InferenceContext * c)82 inline Status ScalarShape(shape_inference::InferenceContext* c) {
83   c->set_output(0, c->Scalar());
84   return Status::OK();
85 }
86 
87 // Shape function for binary ops where both inputs and the output match.
MergeBothInputsShapeFn(InferenceContext * c)88 inline Status MergeBothInputsShapeFn(InferenceContext* c) {
89   ShapeHandle out;
90   TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
91   c->set_output(0, out);
92   return Status::OK();
93 }
94 
// Shape function for dataset iterators.
Status DatasetIteratorShape(shape_inference::InferenceContext* c);

// Creates a new shape from the dimension N, the `spatial` dimensions, and the
// dimension C, arranged in the layout given by `format`, and writes it to
// *out. The resulting shape is owned by `context`.
// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context);
105 
// Shape function for MatMul-like operations.
Status MatMulShape(shape_inference::InferenceContext* c);

// Shape function for batched MatMul-like operations with broadcasting across
// batch dimensions.
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c);

// Shape function for BatchMatMul-like operations.
// NOTE(review): unlike BatchMatMulV2Shape this presumably does not broadcast
// batch dimensions — confirm against the implementation.
Status BatchMatMulShape(shape_inference::InferenceContext* c);

// Shape function for Einsum.
Status EinsumShape(shape_inference::InferenceContext* c);

// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);

// Shape function for BiasAddGrad-like operations.
Status BiasAddGradShape(shape_inference::InferenceContext* c);
124 
// Shape function for Conv2D-like operations that support EXPLICIT padding.
Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c);

// Shape function for Conv2D-like operations that do not support EXPLICIT
// padding.
Status Conv2DShape(shape_inference::InferenceContext* c);

// Shape function for Conv3D-like operations.
Status Conv3DShape(shape_inference::InferenceContext* c);

// Shape function for DepthwiseConv2D-like operations that support EXPLICIT
// padding.
Status DepthwiseConv2DNativeShapeWithExplicitPadding(
    shape_inference::InferenceContext* c);

// Shape function for DepthwiseConv2D-like operations that do not support
// EXPLICIT padding.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);

// Shape function for Conv2DBackpropInput.
Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c);
146 
// Shape function for AvgPool-like operations.
Status AvgPoolShape(shape_inference::InferenceContext* c);

// Shape function for AvgPoolGrad-like operations.
Status AvgPoolGradShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
Status FusedBatchNormShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNormV3 operations.
Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c);

// Shape function for _FusedBatchNormEx operations.
Status FusedBatchNormExShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);

// Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations.
Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c);

// Shape function for MatrixDiagV2 and MatrixDiagV3 operations.
Status MatrixDiagV2Shape(shape_inference::InferenceContext* c);

// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations.
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c);
173 
// Shape function for MaxPool-like operations that support EXPLICIT padding.
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c);

// Shape function for MaxPool-like operations that do not support EXPLICIT
// padding.
Status MaxPoolShape(shape_inference::InferenceContext* c);

// Shape function for MaxPoolV2-like operations.
// NOTE(review): the role of `num_inputs` is not documented here — check the
// implementation for which inputs it selects.
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);

// Shape function for MaxPoolGrad-like operations.
Status MaxPoolGradShape(shape_inference::InferenceContext* c);

// Shape function for 3D pooling operations.
Status Pool3DShape(shape_inference::InferenceContext* c);

// Shape function for MaxPool3DGrad-like operations.
Status MaxPool3DGradShape(shape_inference::InferenceContext* c);

// Shape function for AvgPool3DGrad-like operations.
Status AvgPool3DGradShape(shape_inference::InferenceContext* c);
195 
// Shape function for use with ops whose output shapes are unknown.
Status UnknownShape(shape_inference::InferenceContext* c);

// Shape function for reduction operations.
Status ReductionShape(shape_inference::InferenceContext* c);

// Shape function for unsorted segment reduction operations.
Status UnsortedSegmentReductionShapeFn(InferenceContext* c);

// Shape function for concat operations.
// <num_inputs_to_concat> is the number of inputs to concatenate; they are
// taken from inputs [1, num_inputs_to_concat] of the op. Input 0 is the
// concat_dim input.
Status ConcatShape(shape_inference::InferenceContext* c,
                   int num_inputs_to_concat);

// Shape function for ConcatV2 operations.
// NOTE(review): the axis input is presumably not input 0 here (unlike
// ConcatShape) — confirm the input layout in the implementation.
Status ConcatV2Shape(shape_inference::InferenceContext* c);

// Shape function for QuantizedConcatV2 operations.
// <num_inputs_to_concat> is the number of inputs to concatenate.
// NOTE(review): the full input layout (including any quantization range
// inputs) is defined by the implementation — confirm before relying on it.
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
216 
// Helper for shape functions of binary operators that broadcast their inputs:
// computes the broadcast of `shape_x` and `shape_y` and stores it in *out.
// `incompatible_shape_error` controls whether incompatible shapes are reported
// as an error. NOTE(review): confirm in the implementation what *out holds
// when it is false and the shapes are incompatible.
// Note: out cannot be NULL.
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            bool incompatible_shape_error,
                                            ShapeHandle* out);
225 
226 // Shape function for binary operators that broadcast their inputs
227 // and with output to output_index.
BroadcastBinaryOpOutputShapeFn(InferenceContext * c,int output_index)228 inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
229                                              int output_index) {
230   ShapeHandle out;
231   TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
232       c, c->input(0), c->input(1), true, &out));
233   c->set_output(output_index, out);
234   return Status::OK();
235 }
236 
237 // Shape function for binary operators that broadcast their inputs.
238 // Tested by ops/math_ops_test.cc.
BroadcastBinaryOpShapeFn(InferenceContext * c)239 inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
240   return BroadcastBinaryOpOutputShapeFn(c, 0);
241 }
242 
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);

// Shape function for Slice operations.
Status SliceShape(shape_inference::InferenceContext* c);

// Validates that the 3 component tensors of a sparse tensor (indices, values,
// and dense shape) have the proper shapes. This mimics SparseTensor.__init__
// in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape);

// Validates the resource handle consumed by a variable op.
// NOTE(review): presumably fills *shape_and_type with the handle's recorded
// shape/dtype information on success — confirm against the implementation.
Status ValidateVariableResourceHandle(
    InferenceContext* c, std::vector<ShapeAndType>* shape_and_type);
256 
// Shape function for GatherNd operations.
Status GatherNdShape(InferenceContext* c);

// Helper shape function for ScatterNd.../TensorScatter... operations, given
// the shapes of the indices, updates, and input tensors.
Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle updates_shape, ShapeHandle input_shape);

// Shape function for ops with an explicit "shape" attribute.
Status ExplicitShape(InferenceContext* c);

// Shape function for multiple-output ops with an explicit "shapes" attribute.
Status ExplicitShapes(InferenceContext* c);

// Shape function for SparseReduceMax and SparseReduceSum.
Status SparseReduceShapeFn(InferenceContext* c);
272 
273 }  // namespace shape_inference
274 
275 }  // namespace tensorflow
276 
277 #endif  // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
278