1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
17 #define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
18 
19 #include <array>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/lib/gtl/inlined_vector.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 
// Tensor format for input/output activations used in convolution operations.
// The mnemonics specify the meaning of each tensor dimension sorted from
// largest to smallest memory stride.
// N = Batch, H = Image Height, W = Image Width, C = Number of Channels.
// For 3-D operations the spatial run (H, W) has an extra leading spatial dim;
// the layouts below are written as "spatial..." to cover both cases.
// NOTE(review): every enumerator has an explicit numeric value; presumably
// other code depends on these numbers, so confirm before renumbering.
// TODO(pauldonnelly): It would probably be better to switch to a registration
// process for tensor formats, so specialized formats could be defined more
// locally to where they are used.
enum TensorFormat {
  // FORMAT_NHWC is the default format in TensorFlow.
  // Layout: [N, spatial..., C].
  FORMAT_NHWC = 0,

  // FORMAT_NCHW often improves performance on GPUs.
  // Layout: [N, C, spatial...].
  FORMAT_NCHW = 1,

  // NCHW_VECT_C is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is laid out in the same order
  // as NCHW, except that the size of the Channels dimension is divided by 4,
  // and a new dimension of size 4 is appended, which packs 4 adjacent channel
  // activations for the same pixel into an int32. Thus an NCHW format tensor
  // with dimensions [N, C, H, W] would have dimensions [N, C/4, H, W, 4] in
  // NCHW_VECT_C format.
  // A pre-condition of this format is that C must be a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,

  // Similar to NHWC, but the size of the W dimension is divided by 4, and a
  // new dimension of size 4 is appended, which packs 4 adjacent activations
  // in the width dimension. Thus an NHWC format tensor with dimensions
  // [N, H, W, C] would have dimensions [N, H, W/4, C, 4] in NHWC_VECT_W
  // format.
  // A pre-condition of this format is that W must be a multiple of 4.
  FORMAT_NHWC_VECT_W = 3,

  // Note: although the current code in this file assumes VECT_C and VECT_W
  // enums imply int8x4 vectors, this should not be relied upon.
  // In the future we may change the meaning of these enums to include vectors
  // of other types such as int16x2, with op implementations automatically
  // determining which format is implied based on the datatype.

  // FORMAT_HWNC is for TPUs.
  // Layout: [spatial..., N, C].
  FORMAT_HWNC = 4,

  // FORMAT_HWCN is for TPUs.
  // Layout: [spatial..., C, N].
  FORMAT_HWCN = 5,
};
69 
// Tensor format for convolutional filters.
// The mnemonics specify the meaning of each tensor dimension sorted
// from largest to smallest memory stride.
// H = Kernel Height, W = Kernel Width, I = Input Channels, O = Output Channels.
// For 3-D filters the spatial run has an extra leading spatial dim; the
// layouts below are written as "spatial..." to cover both cases.
// Note: In cudnnGetFilter4dDescriptor(), 'O' is called 'K', 'I' is called 'C'.
enum FilterTensorFormat {
  // FORMAT_HWIO is the default filter format in TensorFlow.
  // Ops that do not have a 'filter_format' attribute will assume this format.
  // Layout: [spatial..., I, O].
  FORMAT_HWIO = 0,

  // FORMAT_OIHW often improves performance on GPUs.
  // Layout: [O, I, spatial...].
  FORMAT_OIHW = 1,

  // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
  // data format. It is laid out in the same order as OIHW, except that the size
  // of the Input Channels dimension is divided by 4, and a new dimension of
  // size 4 is appended, which packs 4 adjacent input channel weights into an
  // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
  // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
  // A pre-condition of this format is that I must be a multiple of 4.
  FORMAT_OIHW_VECT_I = 2,
};
93 
// Parses an activation tensor format from the given string.
// Returns true if the parsing succeeds, and false if it fails.
bool FormatFromString(const string& format_str, TensorFormat* format);

// Parses a filter tensor format from the given string.
// Returns true if the parsing succeeds, and false if it fails.
bool FilterFormatFromString(const string& format_str,
                            FilterTensorFormat* format);

// Converts an activation tensor format into a string.
string ToString(TensorFormat format);

// Converts a filter tensor format into a string.
string ToString(FilterTensorFormat format);
108 
109 // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
110 // format 'format'.
GetTensorSpatialDims(int num_dims,TensorFormat format)111 inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
112   switch (format) {
113     case FORMAT_NHWC:
114     case FORMAT_NCHW:
115     case FORMAT_HWNC:
116     case FORMAT_HWCN:
117       return num_dims - 2;  // Exclude N,C.
118     case FORMAT_NCHW_VECT_C:
119     case FORMAT_NHWC_VECT_W:
120       // Note: the VECT_W is not counted as an independent spatial dim here,
121       // since it just a component of the width dimension.
122       return num_dims - 3;  // Exclude N,C,VectDim.
123   }
124 }
125 
GetFilterTensorSpatialDims(int num_dims,FilterTensorFormat format)126 inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
127   if (format == FORMAT_OIHW_VECT_I) {
128     return num_dims - 3;  // Exclude O,I,InnerI.
129   } else {
130     return num_dims - 2;  // Exclude O,I.
131   }
132 }
133 
134 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
135 // tensor format 'format'. This is the inverse of GetTensorSpatialDims.
GetTensorDimsFromSpatialDims(int num_spatial_dims,TensorFormat format)136 inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
137                                         TensorFormat format) {
138   switch (format) {
139     case FORMAT_NHWC:
140     case FORMAT_NCHW:
141     case FORMAT_HWNC:
142     case FORMAT_HWCN:
143       return num_spatial_dims + 2;  // Include N,C.
144     case FORMAT_NCHW_VECT_C:
145     case FORMAT_NHWC_VECT_W:
146       return num_spatial_dims + 3;  // Include N,C,VectDim.
147   }
148 }
149 
150 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
151 // filter tensor format 'format'.
GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,FilterTensorFormat format)152 inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
153                                               FilterTensorFormat format) {
154   if (format == FORMAT_OIHW_VECT_I) {
155     return num_spatial_dims + 3;  // Include O,I,InnerI.
156   } else {
157     return num_spatial_dims + 2;  // Include O,I.
158   }
159 }
160 
161 // Returns the index of the batch dimension.
GetTensorBatchDimIndex(int num_dims,TensorFormat format)162 inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
163   switch (format) {
164     case FORMAT_NHWC:
165     case FORMAT_NCHW:
166     case FORMAT_NCHW_VECT_C:
167     case FORMAT_NHWC_VECT_W:
168       return 0;
169     case FORMAT_HWNC:
170       return num_dims - 2;
171     case FORMAT_HWCN:
172       return num_dims - 1;
173     default:
174       LOG(FATAL) << "Unknown format " << format;
175       return -1;  // Avoid compiler warning about missing return value
176   }
177 }
178 
179 // Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
180 // the index of the outer feature dimension (i.e. dimension 1, whose size would
181 // be num_features / 4 in this case).
GetTensorFeatureDimIndex(int num_dims,TensorFormat format)182 inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
183   switch (format) {
184     case FORMAT_NHWC:
185     case FORMAT_HWNC:
186       return num_dims - 1;
187     case FORMAT_NHWC_VECT_W:
188     case FORMAT_HWCN:
189       return num_dims - 2;
190     case FORMAT_NCHW:
191     case FORMAT_NCHW_VECT_C:
192       return 1;
193     default:
194       LOG(FATAL) << "Unknown format " << format;
195       return -1;  // Avoid compiler warning about missing return value
196   }
197 }
198 
199 // Returns the index of the inner feature dimension.
GetTensorInnerFeatureDimIndex(int num_dims,TensorFormat format)200 inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
201   DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
202   return num_dims - 1;
203 }
204 
205 // Returns the index of the inner width dimension.
GetTensorInnerWidthDimIndex(int num_dims,TensorFormat format)206 inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) {
207   DCHECK_EQ(format, FORMAT_NHWC_VECT_W);
208   return num_dims - 1;
209 }
210 
211 // Returns the dimension index of the specified 'spatial_dim' within an
212 // activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns
213 // the index of the outer width dimension (i.e. dimension 2, whose size would
214 // be width / 4 in this case).
GetTensorSpatialDimIndex(int num_dims,TensorFormat format,int spatial_dim)215 inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
216                                     int spatial_dim) {
217   CHECK(spatial_dim >= 0 &&
218         spatial_dim < GetTensorSpatialDims(num_dims, format))
219       << spatial_dim << " " << num_dims << " " << ToString(format);
220   switch (format) {
221     case FORMAT_NHWC:
222     case FORMAT_NHWC_VECT_W:
223       return spatial_dim + 1;
224     case FORMAT_NCHW:
225     case FORMAT_NCHW_VECT_C:
226       return spatial_dim + 2;
227     case FORMAT_HWNC:
228     case FORMAT_HWCN:
229       return spatial_dim;
230     default:
231       LOG(FATAL) << "Unknown format " << format;
232       return -1;  // Avoid compiler warning about missing return value
233   }
234 }
235 
GetFilterTensorSpatialDimIndex(int num_dims,FilterTensorFormat format,int dim)236 inline int GetFilterTensorSpatialDimIndex(int num_dims,
237                                           FilterTensorFormat format, int dim) {
238   CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
239       << dim << " " << num_dims << " " << ToString(format);
240   switch (format) {
241     case FORMAT_HWIO:
242       return dim;
243     case FORMAT_OIHW:
244     case FORMAT_OIHW_VECT_I:
245       return dim + 2;
246     default:
247       LOG(FATAL) << "Unknown format " << format;
248       return -1;  // Avoid compiler warning about missing return value
249   }
250 }
251 
252 // Returns the index of the inner input channels dimension.
GetFilterTensorInnerInputChannelsDimIndex(int num_dims,FilterTensorFormat format)253 inline int GetFilterTensorInnerInputChannelsDimIndex(
254     int num_dims, FilterTensorFormat format) {
255   DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
256   return num_dims - 1;
257 }
258 
259 // Returns the index of the input channels dimension.
260 // If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
261 // outer input channel (i.e. 1), which holds num_input_channels / 4.
GetFilterTensorInputChannelsDimIndex(int num_dims,FilterTensorFormat format)262 inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
263                                                 FilterTensorFormat format) {
264   switch (format) {
265     case FORMAT_HWIO:
266       return num_dims - 2;
267     case FORMAT_OIHW:
268     case FORMAT_OIHW_VECT_I:
269       return 1;
270     default:
271       LOG(FATAL) << "Unknown format " << format;
272       return -1;  // Avoid compiler warning about missing return value
273   }
274 }
275 
276 // Returns the index of the output channels dimension.
GetFilterTensorOutputChannelsDimIndex(int num_dims,FilterTensorFormat format)277 inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
278                                                  FilterTensorFormat format) {
279   switch (format) {
280     case FORMAT_HWIO:
281       return num_dims - 1;
282     case FORMAT_OIHW:
283     case FORMAT_OIHW_VECT_I:
284       return 0;
285     default:
286       LOG(FATAL) << "Unknown format " << format;
287       return -1;  // Avoid compiler warning about missing return value
288   }
289 }
290 
291 // TODO(pauldonnelly): Replace these tensor dimension index functions with
292 // constant structs to improve performance and reduce code size in Compute()
293 // functions.
294 
295 // Return the dimension index for the specified 'dimension' of the specified
296 // data 'tensor_format'.  'dimension' is a char that can be 'N' (batch size),
297 // 'C' (channels), 'H' (height), 'W' (width),  or a numbered spatial dimension:
298 // '0',  .. (NUM_SPATIAL_DIMS-1)..
299 // If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
300 // the outer channel dimension (i.e. 1).
301 template <int NUM_SPATIAL_DIMS>
GetTensorDimIndex(TensorFormat format,char dimension)302 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
303   if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) {
304     // clang-format off
305     switch (dimension) {
306       case 'N': return 0;
307       case '0': return 1;
308       case '1': return 2;
309       case '2': return 3;
310       case 'H': return NUM_SPATIAL_DIMS - 1;
311       case 'W': return NUM_SPATIAL_DIMS;
312       case 'C': return NUM_SPATIAL_DIMS + 1;
313       default:
314         LOG(FATAL) << "Invalid dimension: " << dimension;
315         return -1;  // Avoid compiler warning about missing return value
316     }
317   } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
318     switch (dimension) {
319       case 'N': return 0;
320       case 'C': return 1;
321       case '0': return 2;
322       case '1': return 3;
323       case '2': return 4;
324       case 'H': return NUM_SPATIAL_DIMS;
325       case 'W': return NUM_SPATIAL_DIMS + 1;
326       default:
327         LOG(FATAL) << "Invalid dimension: " << dimension;
328         return -1;  // Avoid compiler warning about missing return value
329     }
330   } else if (format == FORMAT_HWNC) {
331     switch (dimension) {
332       case '0': return 0;
333       case '1': return 1;
334       case '2': return 2;
335       case 'H': return NUM_SPATIAL_DIMS - 2;
336       case 'W': return NUM_SPATIAL_DIMS - 1;
337       case 'N': return NUM_SPATIAL_DIMS;
338       case 'C': return NUM_SPATIAL_DIMS + 1;
339       default:
340         LOG(FATAL) << "Invalid dimension: " << dimension;
341         return -1;  // Avoid compiler warning about missing return value
342     }
343   } else if (format == FORMAT_HWCN) {
344     switch (dimension) {
345       case '0': return 0;
346       case '1': return 1;
347       case '2': return 2;
348       case 'H': return NUM_SPATIAL_DIMS - 2;
349       case 'W': return NUM_SPATIAL_DIMS - 1;
350       case 'C': return NUM_SPATIAL_DIMS;
351       case 'N': return NUM_SPATIAL_DIMS + 1;
352       default:
353         LOG(FATAL) << "Invalid dimension: " << dimension;
354         return -1;  // Avoid compiler warning about missing return value
355     }
356   } else {
357     LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
358     return -1;  // Avoid compiler warning about missing return value
359   }
360   // clang-format on
361 }
362 
363 // Return the dimension index for the specified 'dimension' of the specified
364 // 'filter_tensor_format'.  'dimension' is a char that can be 'O' (num output
365 // channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
366 // numbered spatial dimension: '0',  .. (NUM_SPATIAL_DIMS-1).
367 // If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
368 // outer input channels dimension (i.e. 1).
369 template <int NUM_SPATIAL_DIMS>
GetFilterDimIndex(FilterTensorFormat filter_tensor_format,char dimension)370 inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
371                              char dimension) {
372   // clang-format off
373   if (filter_tensor_format == FORMAT_HWIO) {
374     switch (dimension) {
375       case '0': return 0;
376       case '1': return 1;
377       case '2': return 2;
378       case 'H': return NUM_SPATIAL_DIMS - 2;
379       case 'W': return NUM_SPATIAL_DIMS - 1;
380       case 'I': return NUM_SPATIAL_DIMS;
381       case 'O': return NUM_SPATIAL_DIMS + 1;
382       default:
383         LOG(FATAL) << "Invalid dimension: " << dimension;
384         return -1;  // Avoid compiler warning about missing return value
385     }
386   } else if (filter_tensor_format == FORMAT_OIHW ||
387              filter_tensor_format == FORMAT_OIHW_VECT_I) {
388     switch (dimension) {
389       case 'O': return 0;
390       case 'I': return 1;
391       case '0': return 2;
392       case '1': return 3;
393       case '2': return 4;
394       case 'H': return NUM_SPATIAL_DIMS;
395       case 'W': return NUM_SPATIAL_DIMS + 1;
396       default:
397         LOG(FATAL) << "Invalid dimension: " << dimension;
398         return -1;  // Avoid compiler warning about missing return value
399     }
400   } else {
401     LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
402     return -1;  // Avoid compiler warning about missing return value
403   }
404   // clang-format on
405 }
406 
GetTensorDimIndex(TensorFormat format,char dimension)407 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
408   return GetTensorDimIndex<2>(format, dimension);
409 }
410 
GetTensorDimIndex(TensorFormat format,char dimension,int num_total_dims)411 inline int32 GetTensorDimIndex(TensorFormat format, char dimension,
412                                int num_total_dims) {
413   int32 index = (GetTensorSpatialDims(num_total_dims, format) == 3)
414                     ? GetTensorDimIndex<3>(format, dimension)
415                     : GetTensorDimIndex<2>(format, dimension);
416   CHECK(index >= 0 && index < num_total_dims)  // Crash OK.
417       << "Invalid index from the dimension: " << index << ", " << format << ", "
418       << dimension;
419   return index;
420 }
421 
422 // Return the element from 'dimension_attributes' that corresponds to the
423 // specified 'dimension' according to 'tensor_format'.
424 template <typename T>
GetTensorDim(gtl::ArraySlice<T> dimension_attributes,TensorFormat tensor_format,char dimension)425 T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
426                TensorFormat tensor_format, char dimension) {
427   int index =
428       GetTensorDimIndex(tensor_format, dimension, dimension_attributes.size());
429   return dimension_attributes[index];
430 }
431 
432 // Return the element from 'dimension_attribute' that corresponds to the
433 // specified 'dimension' according to 'filter_tensor_format'.
434 template <typename T>
GetFilterDim(gtl::ArraySlice<T> dimension_attribute,FilterTensorFormat filter_tensor_format,char dimension)435 T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
436                FilterTensorFormat filter_tensor_format, char dimension) {
437   int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
438                                           filter_tensor_format) == 3)
439                   ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
440                   : GetFilterDimIndex<2>(filter_tensor_format, dimension);
441   CHECK(index >= 0 && index < dimension_attribute.size())
442       << "Invalid index from the dimension: " << index << ", "
443       << filter_tensor_format << ", " << dimension;
444   return dimension_attribute[index];
445 }
446 
447 template <typename T>
GetTensorDim(const std::vector<T> & attributes,TensorFormat format,char dimension)448 T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
449                char dimension) {
450   return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
451 }
452 
453 // Return the size of the specified 'dimension' within 'tensor_shape'
454 // according to 'tensor_format'.
GetTensorDim(const TensorShape & tensor_shape,TensorFormat tensor_format,char dimension)455 inline int64 GetTensorDim(const TensorShape& tensor_shape,
456                           TensorFormat tensor_format, char dimension) {
457   return GetTensorDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
458                       tensor_format, dimension);
459 }
460 
461 // Return the size of the specified 'dimension' within 'tensor_shape'
462 // according to 'tensor_filter_format'.
GetFilterDim(const TensorShape & tensor_shape,FilterTensorFormat tensor_filter_format,char dimension)463 inline int64 GetFilterDim(const TensorShape& tensor_shape,
464                           FilterTensorFormat tensor_filter_format,
465                           char dimension) {
466   return GetFilterDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
467                       tensor_filter_format, dimension);
468 }
469 
470 // Return the size of the specified 'dimension' of 'tensor' according to
471 // 'tensor_format'.
GetTensorDim(const Tensor & tensor,TensorFormat tensor_format,char dimension)472 inline int64 GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
473                           char dimension) {
474   return GetTensorDim(tensor.shape(), tensor_format, dimension);
475 }
476 
477 // Return the size of the specified 'dimension' of 'tensor' according to
478 // 'filter_tensor_format'.
GetFilterDim(const Tensor & tensor,FilterTensorFormat filter_tensor_format,char dimension)479 inline int64 GetFilterDim(const Tensor& tensor,
480                           FilterTensorFormat filter_tensor_format,
481                           char dimension) {
482   return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
483 }
484 
GetExplicitPaddingForDim(const std::vector<int64> & explicit_paddings,TensorFormat tensor_format,char dimension,int64 * padding_before,int64 * padding_after)485 inline void GetExplicitPaddingForDim(
486     const std::vector<int64>& explicit_paddings, TensorFormat tensor_format,
487     char dimension, int64* padding_before, int64* padding_after) {
488   int index =
489       GetTensorDimIndex(tensor_format, dimension, explicit_paddings.size() / 2);
490   *padding_before = explicit_paddings[2 * index];
491   *padding_after = explicit_paddings[2 * index + 1];
492 }
493 
// Return the string that specifies the data format for convnet operations.
// The "3d" variant is presumably its counterpart for 3-D convnet operations.
// NOTE(review): the names suggest op attribute spec strings; confirm against
// the definitions in the corresponding .cc file.
string GetConvnetDataFormatAttrString();
string GetConvnet3dDataFormatAttrString();

// Return the string that specifies the filter format for convnet operations
// (the "3d" variant presumably covers 3-D convnet operations).
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
// NOTE(review): despite its placement among the filter-format helpers, the
// name suggests this returns a *data* format string accepted by both 2-D and
// 3-D operations. Confirm against its definition.
string GetConvnetDataFormat2D3DAttrString();
502 
503 // Returns a tensor shape for the specified format and dimension sizes.
504 // Works for both 2D and 3D operations. The output shapes are as follows:
505 // FORMAT_NHWC:        (N, spatial, C); rank = spatial.size() + 2
506 // FORMAT_NCHW:        (N, C, spatial); rank = spatial.size() + 2
507 // FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3
508 // FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3
ShapeFromFormat(TensorFormat format,int64 N,gtl::ArraySlice<int64> spatial,int64 C)509 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
510                                    gtl::ArraySlice<int64> spatial, int64 C) {
511   const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
512   gtl::InlinedVector<int64, 6> dim_sizes(dims);
513   dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
514   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
515     auto dim_size = spatial[dim];
516     if (format == FORMAT_NHWC_VECT_W &&
517         static_cast<size_t>(dim) == spatial.size() - 1) {
518       CHECK_EQ(0, dim_size % 4)
519           << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W="
520           << dim_size;
521       dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4;
522       dim_size /= 4;
523     }
524     dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size;
525   }
526 
527   int feature_index = GetTensorFeatureDimIndex(dims, format);
528   if (format == FORMAT_NCHW_VECT_C) {
529     CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
530                        << C;
531     C /= 4;
532     dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
533   }
534   dim_sizes[feature_index] = C;
535   return TensorShape(dim_sizes);
536 }
537 
538 // Return a tensor shape of the specified 'format', and dimensions.
539 // Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
540 // the output TensorShape has spatial.size() + 3 dimensions, otherwise
541 // it has spatial.size() + 2 dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,gtl::ArraySlice<int64> spatial,int64 I,int64 O)542 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
543                                                gtl::ArraySlice<int64> spatial,
544                                                int64 I, int64 O) {
545   const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
546   gtl::InlinedVector<int64, 6> dim_sizes(dims);
547   dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
548   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
549     dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
550   }
551 
552   if (format == FORMAT_OIHW_VECT_I) {
553     CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
554                        << I;
555     I /= 4;
556     dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
557   }
558   dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
559   return TensorShape(dim_sizes);
560 }
561 
562 // Return a tensor shape of the specified 'format', and dimensions.
ShapeFromFormat(TensorFormat format,int64 N,int64 H,int64 W,int64 C)563 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
564                                    int64 W, int64 C) {
565   return ShapeFromFormat(format, N, {H, W}, C);
566 }
567 
568 // Return a filter tensor shape of the specified 'format', and dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,int64 H,int64 W,int64 I,int64 O)569 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
570                                                int64 H, int64 W, int64 I,
571                                                int64 O) {
572   return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
573 }
574 
575 // Returns a copy of the specified tensor 'src_shape' converted from
576 // 'src_format' to 'dst_format'.
ShapeFromFormat(TensorFormat dst_format,const TensorShape & src_shape,TensorFormat src_format)577 inline TensorShape ShapeFromFormat(TensorFormat dst_format,
578                                    const TensorShape& src_shape,
579                                    TensorFormat src_format) {
580   if (src_format == dst_format) {
581     return src_shape;
582   }
583 
584   const int64 batch = GetTensorDim(src_shape, src_format, 'N');
585   const int64 channels = GetTensorDim(src_shape, src_format, 'C') *
586                          (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
587   const int num_src_spatial_dims =
588       GetTensorSpatialDims(src_shape.dims(), src_format);
589   std::vector<int64> spatial_dims(num_src_spatial_dims);
590   for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) {
591     spatial_dims[spatial_dim] =
592         gtl::ArraySlice<int64>(src_shape.dim_sizes())[GetTensorSpatialDimIndex(
593             src_shape.dims(), src_format, spatial_dim)];
594   }
595   if (src_format == FORMAT_NHWC_VECT_W) {
596     spatial_dims[num_src_spatial_dims - 1] *= 4;
597   }
598   return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels);
599 }
600 
601 // Returns a copy of the specified filter tensor 'src_shape' converted from
602 // 'src_filter_format' to 'dst_filter_format'.
ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,const TensorShape & src_shape,FilterTensorFormat src_filter_format)603 inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
604                                          const TensorShape& src_shape,
605                                          FilterTensorFormat src_filter_format) {
606   if (src_filter_format == dst_filter_format) {
607     return src_shape;
608   }
609 
610   const int64 output_channels = GetFilterDim(src_shape, src_filter_format, 'O');
611   const int64 input_channels =
612       GetFilterDim(src_shape, src_filter_format, 'I') *
613       (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
614 
615   if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
616     return ShapeFromFilterTensorFormat(
617         dst_filter_format,
618         {{GetFilterDim(src_shape, src_filter_format, '0'),
619           GetFilterDim(src_shape, src_filter_format, '1'),
620           GetFilterDim(src_shape, src_filter_format, '2')}},
621         input_channels, output_channels);
622   }
623 
624   return ShapeFromFilterTensorFormat(
625       dst_filter_format,
626       {{GetFilterDim(src_shape, src_filter_format, 'H'),
627         GetFilterDim(src_shape, src_filter_format, 'W')}},
628       input_channels, output_channels);
629 }
630 
631 }  // namespace tensorflow
632 
633 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
634