1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
17 #define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
18 
19 #include "tensorflow/core/kernels/eigen_convolution_helpers.h"
20 
21 // Note this header is used in both TF and TFLite.
22 namespace Eigen {
23 
24 namespace internal {
25 
26 // WARNING: Most of the code here implicitly assumes that the matrix is in
27 // ColMajor layout. This is guaranteed by the tensor contraction (see
28 // TensorContraction.h).
29 //
30 // Inside Eigen a tensor contraction is represented by a matrix multiplication.
31 // We don't want to actually extract image patches and reshape the result into
32 // a matrix (this involves allocating huge extra memory), so the patch
33 // extraction and reshape operations are implicit.
34 //
35 // TensorContractionInputMapper takes a matrix index and returns the coefficient
36 // (or the packet) of the "virtual tensor" that would be at that index if we
37 // were to actually reshape the result of patch extraction.
38 //
39 // TensorContractionSubMapper provides a similar view into the "virtual matrix"
40 // at the given vertical and horizontal offsets.
41 //
42 // "Virtual matrix" dimensions:
43 //   *0: kernelChannels * kernelRows * kernelCols;
44 //    1: out_height * out_width * OTHERS (e.g. batches, etc...)
45 //
46 // *) extracted patches are contiguous in memory (innermost dimension assuming
47 //    col major layout)
48 //
49 // With these dimensions:
50 //   row - offset within a single patch (in code: patchId)
51 //   col - index of the extracted patch (in code: patchIndex)
52 //         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
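//
// Worked example (illustration only, with assumed sizes): for kernelChannels
// == 4, kernelRows == 3, kernelCols == 3 the "virtual matrix" has
// 4 * 3 * 3 == 36 rows, and a row index (patchId) decomposes as
//   depth     = patchId % 4
//   rowOffset = (patchId / 4) % 3
//   colOffset =  patchId / (4 * 3)
// e.g. patchId == 22 addresses depth 2, rowOffset 2, colOffset 1 within its
// patch. This matches the decomposition done by loadCoeff() below with the
// precomputed fast divisors.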
53 //
54 // TODO(ezhulenev): Consolidate this part of the code with the image patch
55 // extraction code since they are both very similar.
56 
57 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
58           typename Device, typename Scalar_, typename Index,
59           typename nocontract_t, typename contract_t, int Side, int packet_size,
60           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
61 class TensorContractionInputMapper<
62     Scalar_, Index, Side,
63     TensorEvaluator<
64         const TensorReshapingOp<NewDimension,
65                                 const TensorImagePatchOp<Rows, Cols, ArgType> >,
66         Device>,
67     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
68     inner_dim_reordered, Alignment> {
69  public:
70   typedef Scalar_ Scalar;
71 
72   typedef TensorContractionInputMapper<
73       Scalar, Index, Side,
74       TensorEvaluator<
75           const TensorReshapingOp<
76               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
77           Device>,
78       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
79       inner_dim_reordered, Alignment>
80       Self;
81 
82   typedef TensorContractionSubMapper<
83       Scalar, Index, Side,
84       TensorEvaluator<
85           const TensorReshapingOp<
86               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
87           Device>,
88       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
89       inner_dim_reordered, Alignment>
90       SubMapper;
91 
92   typedef SubMapper VectorMapper;
93   typedef SubMapper LinearMapper;
94   typedef typename packet_traits<Scalar>::type Packet;
95 
96   typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;
97 
98   EIGEN_DEVICE_FUNC
99   TensorContractionInputMapper(
100       const TensorEvaluator<
101           const TensorReshapingOp<
102               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
103           Device>& tensor,
104       const nocontract_t&, const nocontract_t&, const contract_t&,
105       const contract_t&)
106       : m_impl(tensor.impl().impl()) {
107     Index patch_rows;
108     Index patch_depth;
109     if (internal::traits<ArgType>::Layout == ColMajor) {
110       patch_depth = tensor.impl().dimensions()[0];
111       patch_rows = tensor.impl().dimensions()[1];
112       m_patch_cols = tensor.impl().dimensions()[2];
113       m_num_patches = tensor.impl().dimensions()[3];
114     } else {
115       const size_t NumDims = tensor.impl().dimensions().size();
116       patch_depth = tensor.impl().dimensions()[NumDims - 1];
117       patch_rows = tensor.impl().dimensions()[NumDims - 2];
118       m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
119       m_num_patches = tensor.impl().dimensions()[NumDims - 4];
120     }
121 
122     // Strides for navigating through the single patch.
123     m_patch_row_stride = patch_depth;
124     m_patch_col_stride = patch_rows * m_patch_row_stride;
125 
126     m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
127     m_patch_col_inflate_strides = tensor.impl().colInflateStride();
128 
129     m_colStride = patch_rows;
130 
131     m_outputRows = tensor.impl().outputRows();
132     m_outputCols = tensor.impl().outputCols();
133     m_row_strides = tensor.impl().userRowStride();
134     m_col_strides = tensor.impl().userColStride();
135 
136     m_in_row_strides = tensor.impl().userInRowStride();
137     m_in_col_strides = tensor.impl().userInColStride();
138 
139     if (internal::traits<ArgType>::Layout == ColMajor) {
140       m_inputRows = tensor.impl().impl().dimensions()[1];
141       m_inputCols = tensor.impl().impl().dimensions()[2];
142     } else {
143       const int NumDims = tensor.impl().impl().dimensions().size();
144       m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
145       m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
146     }
147 
148     m_rowInputStride = patch_depth;
149     m_colInputStride = patch_depth * m_inputRows;
150     m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
151 
152     m_rowPaddingTop = tensor.impl().rowPaddingTop();
153     m_colPaddingLeft = tensor.impl().colPaddingLeft();
154 
155     m_fastPatchRowStride =
156         internal::TensorIntDivisor<Index>(m_patch_row_stride);
157     m_fastPatchColStride =
158         internal::TensorIntDivisor<Index>(m_patch_col_stride);
159     m_fastInputRowStride =
160         internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
161     m_fastInputColStride =
162         internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
163     m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
164     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
165     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
166     m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
167   }
168 
169   EIGEN_DEVICE_FUNC
170   TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
171       : m_impl(base_mapper.m_impl) {
172     m_patch_cols = base_mapper.m_patch_cols;
173     m_num_patches = base_mapper.m_num_patches;
174 
175     m_patch_row_stride = base_mapper.m_patch_row_stride;
176     m_patch_col_stride = base_mapper.m_patch_col_stride;
177 
178     m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
179     m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
180 
181     m_colStride = base_mapper.m_colStride;
182 
183     m_rowInputStride = base_mapper.m_rowInputStride;
184     m_colInputStride = base_mapper.m_colInputStride;
185     m_patchInputStride = base_mapper.m_patchInputStride;
186 
187     m_inputRows = base_mapper.m_inputRows;
188     m_inputCols = base_mapper.m_inputCols;
189 
190     m_outputRows = base_mapper.m_outputRows;
191     m_outputCols = base_mapper.m_outputCols;
192     m_row_strides = base_mapper.m_row_strides;
193     m_col_strides = base_mapper.m_col_strides;
194 
195     m_in_row_strides = base_mapper.m_in_row_strides;
196     m_in_col_strides = base_mapper.m_in_col_strides;
197 
198     m_rowPaddingTop = base_mapper.m_rowPaddingTop;
199     m_colPaddingLeft = base_mapper.m_colPaddingLeft;
200 
201     m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
202     m_fastPatchColStride = base_mapper.m_fastPatchColStride;
203     m_fastInputRowStride = base_mapper.m_fastInputRowStride;
204     m_fastInputColStride = base_mapper.m_fastInputColStride;
205     m_fastNumPatches = base_mapper.m_fastNumPatches;
206     m_fastColStride = base_mapper.m_fastColStride;
207     m_fastOutputRows = base_mapper.m_fastOutputRows;
208     m_fastDimZero = base_mapper.m_fastDimZero;
209   }
210 
211   // If true, turns off some optimizations for loading packets since the
212   // image patches are "non-standard", e.g. there are non-trivial strides or
213   // inflations in the input.
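  // For example, non-trivial in_row/in_col strides arise for dilated (atrous)
  // patches, and non-trivial inflate strides arise when the input is logically
  // upsampled (e.g. in gradients of strided convolutions); both cases fall
  // back to the coefficient-by-coefficient loads below.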
214   EIGEN_DEVICE_FUNC
215   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
216     return m_in_row_strides != 1 || m_in_col_strides != 1 ||
217            m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
218   }
219 
220   EIGEN_DEVICE_FUNC
221   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
222     return SubMapper(*this, i, j);
223   }
224 
225   EIGEN_DEVICE_FUNC
226   EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
227     return LinearMapper(*this, i, j);
228   }
229 
230   EIGEN_DEVICE_FUNC
231   EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
232     Index rowIndex, colIndex, otherIndex;
233     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
234     return loadCoeff(row, rowIndex, colIndex, otherIndex);
235   }
236 
237   // Load the coefficient at the patchIndex location instead of the usual
238   // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by
239   // the gpu code.
240   // EIGEN_DEVICE_FUNC
241   EIGEN_DEVICE_FUNC
242   EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
243     Index rowIndex, colIndex, otherIndex;
244     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
245     return loadCoeff(row, rowIndex, colIndex, otherIndex);
246   }
247 
248   EIGEN_DEVICE_FUNC
249   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
250     Index rowIndex, colIndex, otherIndex;
251     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
252     return loadPacket(row, rowIndex, colIndex, otherIndex);
253   }
254 
255   // Load the packet at the patchIndex location instead of the usual m_rowIndex,
256   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
257   EIGEN_DEVICE_FUNC
258   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
259     Index rowIndex, colIndex, otherIndex;
260     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
261     return loadPacket(row, rowIndex, colIndex, otherIndex);
262   }
263 
264   EIGEN_DEVICE_FUNC
265   EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
266     return m_impl;
267   }
268 
269   EIGEN_DEVICE_FUNC
270   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
271   EIGEN_DEVICE_FUNC
272   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
273   EIGEN_DEVICE_FUNC
274   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
275 
276  private:
277   friend class TensorContractionSubMapper<
278       Scalar, Index, Side,
279       TensorEvaluator<
280           const TensorReshapingOp<
281               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
282           Device>,
283       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
284       inner_dim_reordered, Alignment>;
285 
286   // Load coefficient from a patch specified by the "within patch offset"
287   // (patchId) and the precomputed indices of the first element of the patch.
288   EIGEN_DEVICE_FUNC
289   EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
290                                        Index colIndex, Index otherIndex) const {
291     // Find the offset of the element wrt the location of the first element.
292     const Index patchOffset = patchId / m_fastDimZero;
293 
294     const Index colOffset = patchOffset / m_fastColStride;
295     const Index inputCol = colIndex + colOffset * m_in_col_strides;
296     const Index origInputCol =
297         (m_patch_col_inflate_strides == 1)
298             ? inputCol
299             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
300 
301     const Index rowOffset = patchOffset - colOffset * m_colStride;
302     const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
303     const Index origInputRow =
304         (m_patch_row_inflate_strides == 1)
305             ? inputRow
306             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
307     if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
308         origInputRow >= m_inputRows ||
309         (inputCol != origInputCol * m_patch_col_inflate_strides) ||
310         (inputRow != origInputRow * m_patch_row_inflate_strides)) {
311       return Scalar(0);
312     }
313     const Index depth = patchId - patchOffset * patchDepth();
314     const Index inputIndex = depth + origInputRow * m_rowInputStride +
315                              origInputCol * m_colInputStride + otherIndex;
316     return m_impl.coeff(inputIndex);
317   }
318 
319   // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
320   // and `in_strides` equal to 1 (template specialization without templates).
321   EIGEN_DEVICE_FUNC
322   EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
323                                                Index colIndex,
324                                                Index otherIndex) const {
325     eigen_assert(!nonStandardPatches());
326 
327     // Find the offset of the element wrt the location of the first element.
328     const Index patchOffset = patchId / m_fastDimZero;
329     const Index colOffset = patchOffset / m_fastColStride;
330     const Index rowOffset = patchOffset - colOffset * m_colStride;
331     const Index inputCol = colIndex + colOffset;
332     const Index inputRow = rowIndex + rowOffset;
333     if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
334         inputRow >= m_inputRows) {
335       return Scalar(0);
336     }
337     const Index depth = patchId - patchOffset * patchDepth();
338     const Index inputIndex = depth + inputRow * m_rowInputStride +
339                              inputCol * m_colInputStride + otherIndex;
340     return m_impl.coeff(inputIndex);
341   }
342 
343   // Load packet from a patch specified by the "within patch offset"
344   // (patchId) and the precomputed indices of the first element of the patch.
345   EIGEN_DEVICE_FUNC
346   EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
347                                         Index colIndex,
348                                         Index otherIndex) const {
349     const Index packetSize = internal::unpacket_traits<Packet>::size;
350     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
351     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
352 
353     if (nonStandardPatches()) {
354       return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
355     }
356     typedef decltype(m_impl) TensorEvaluatorT;
357     return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex,
358                                                         colIndex, otherIndex);
359   }
360 
361   // Helper function to load a 'partial' packet - this is the single-column
362   // part of a packet that is split across two columns. In the 'partial'
363   // packet, the elements corresponding to the column (specified through
364   // colOffset) are loaded and the remaining elements are zero-filled. This
365   // function is called from loadPacketStandardFromTwoColumns().
366   // This code path is exercised only when the packet type supports masked load
367   // and when the partial packet load is available in the TensorEvaluator.
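  // Illustration (assumed packet size 4): for span == {1, 2} the returned
  // partial packet is [0, x, y, 0], where x and y are the coefficients at
  // patch offsets patchId and patchId + 1 in the column given by colOffset.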
368   EIGEN_DEVICE_FUNC
369   EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
370       Index rowIndex, Index colIndex, Index otherIndex, Index patchId,
371       const Index span[], const Index patchOffsets[], Index colOffset) const {
372     const Index inputCol = colIndex + colOffset;
373     const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
374                                  patchOffsets[1] - colOffset * m_colStride};
375     const Index inputRows[2] = {rowIndex + rowOffsets[0],
376                                 rowIndex + rowOffsets[1]};
377 
378     if (inputRows[0] >= m_inputRows || inputRows[1] < 0 ||
379         inputCol >= m_inputCols || inputCol < 0) {
380       // Partial packet is all zeros
381       return internal::pset1<Packet>(Scalar(0));
382     } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
383       // From inputIndex-span[0], we need to load elements starting from
384       // index span[0] all the way up to (and including) span[1].
385       const Index depth = patchId - patchOffsets[0] * patchDepth();
386       const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
387                                inputCol * m_colInputStride + otherIndex;
388       return m_impl.template partialPacket<Packet>(
389           inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
390     } else {
391       // Using slow path for this partial packet.
392       // We need to load elements starting from index span[0] all the way
393       // up to (and including) span[1]. We split this load into 3 parts:
394       // 0 : span[0]-1 - Zeros will be loaded for these indices
395       // span[0] : span[1] - Elements will be loaded here for these indices
396       // span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
397       const Index packetSize = internal::unpacket_traits<Packet>::size;
398       EIGEN_ALIGN_MAX
399       typename internal::remove_const<Scalar>::type values[packetSize];
400       for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
401       for (int i = span[0]; i < span[1] + 1; ++i)
402         values[i] =
403             loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
404       for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
405       return internal::pload<Packet>(values);
406     }
407   }
408 
409   // Helper function to load a packet that is split across two columns.
410   // If required, this function is called from loadPacketStandard() when the
411   // packet type supports masked load and when the partial packet load is
412   // available in the TensorEvaluator.
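  // Illustration (assumed sizes): with patchDepth() == 3 and patchRows() == 3,
  // a packet of size 4 starting at patchId == 7 covers offsets 7..10; offsets
  // 7-8 fall into patch column 0 and offsets 9-10 into patch column 1, so the
  // result is por([v7, v8, 0, 0], [0, 0, v9, v10]).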
413   EIGEN_DEVICE_FUNC
414   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(
415       Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
416       const Index patchOffsets[], const Index colOffsets[]) const {
417     eigen_assert(colOffsets[1] == colOffsets[0] + 1);
418     const Index packetSize = internal::unpacket_traits<Packet>::size;
419 
420     // Packet to load will be split into 2 parts where each part spans a single
421     // column. First determine where to split.
422     const Index patchIdSplit =
423         ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
424     const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
425 
426     // patchIds[i]:          patchId corresponding to partial packet i
427     // spans[i]:             Start and end indices corresponding to the elements
428     //                       to be loaded for partial packet i
429     // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
430     const Index patchIds[2] = {patchId, patchIdSplit + 1};
431     const Index spans[2][2] = {{0, patchIdSplit - patchId},
432                                {patchIdSplit - patchId + 1, packetSize - 1}};
433     const Index patchOffsets2Cols[2][2] = {
434         {patchOffsets[0], patchOffsetSplit},
435         {patchOffsetSplit + 1, patchOffsets[1]}};
436 
437     // Load partial packets and do bit-wise OR to generate required packet
438     return internal::por<Packet>(
439         loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0],
440                                   spans[0], patchOffsets2Cols[0],
441                                   colOffsets[0]),
442         loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1],
443                                   spans[1], patchOffsets2Cols[1],
444                                   colOffsets[1]));
445   }
446 
447   // Helper function to load a packet that is present in a single column.
448   // If required, this function is called from loadPacketStandard().
449   EIGEN_DEVICE_FUNC
450   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(
451       Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
452       const Index patchOffsets[], const Index colOffsets[],
453       const Index inputCols[]) const {
454     eigen_assert(colOffsets[0] == colOffsets[1]);
455     const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
456                                  patchOffsets[1] - colOffsets[1] * m_colStride};
457     eigen_assert(rowOffsets[0] <= rowOffsets[1]);
458     const Index inputRows[2] = {rowIndex + rowOffsets[0],
459                                 rowIndex + rowOffsets[1]};
460 
461     if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
462       // all zeros
463       return internal::pset1<Packet>(Scalar(0));
464     }
465 
466     if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
467       // no padding
468       const Index depth = patchId - patchOffsets[0] * patchDepth();
469       const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
470                                inputCols[0] * m_colInputStride + otherIndex;
471       return m_impl.template packet<Unaligned>(inputIndex);
472     }
473     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
474   }
475 
476   // Load standard packet from a patch specified by the "within patch offset"
477   // (patchId) and the precomputed indices of the first element of the patch.
478   // This function will be called if partial packet loading is not available
479   // for the TensorEvaluator or if the packet type does not support masked
480   // load.
481   template <typename PacketT, typename TensorEvaluatorT>
482   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
483       !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
484       PacketT>::type
485   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
486                      Index otherIndex) const {
487     const Index packetSize = internal::unpacket_traits<Packet>::size;
488     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
489     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
490 
491     eigen_assert(!nonStandardPatches());
492 
493     if ((patchDepth() % packetSize) == 0) {
494       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
495     }
496 
497     // Offsets and input calculation here are identical to
498     // loadCoeffStandard(...), but repeated twice.
499     const Index patchOffsets[2] = {patchId / m_fastDimZero,
500                                    (patchId + packetSize - 1) / m_fastDimZero};
501     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
502                                  patchOffsets[1] / m_fastColStride};
503     const Index inputCols[2] = {colIndex + colOffsets[0],
504                                 colIndex + colOffsets[1]};
505 
506     if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
507       // all zeros
508       return internal::pset1<Packet>(Scalar(0));
509     }
510     if (inputCols[0] == inputCols[1]) {
511       return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
512                                                 otherIndex, patchOffsets,
513                                                 colOffsets, inputCols);
514     }
515     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
516   }
517 
518   // Load standard packet from a patch specified by the "within patch offset"
519   // (patchId) and the precomputed indices of the first element of the patch.
520   // This function will be called if partial packet loading is available for
521   // the TensorEvaluator and if the packet type supports masked load.
522   // The only difference between this and the other case is that if the packet
523   // to load is split across two columns, then in this case instead of going to
524   // the slow (element-by-element) load, we load two packets - each containing
525   // elements from one of the columns (rest of the elements of the packets are
526   // zeroes), and then combine these two packets to generate the required
527   // packet. The idea is to enable fast load (if possible) of these 'partial'
528   // packets.
529   template <typename PacketT, typename TensorEvaluatorT>
530   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
531       TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
532       PacketT>::type
533   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
534                      Index otherIndex) const {
535     const Index packetSize = internal::unpacket_traits<PacketT>::size;
536     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
537     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
538 
539     eigen_assert(!nonStandardPatches());
540 
541     if ((patchDepth() % packetSize) == 0) {
542       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
543     }
544 
545     // Offsets and input calculation here are identical to
546     // loadCoeffStandard(...), but repeated twice.
547     const Index patchOffsets[2] = {patchId / m_fastDimZero,
548                                    (patchId + packetSize - 1) / m_fastDimZero};
549     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
550                                  patchOffsets[1] / m_fastColStride};
551     const Index inputCols[2] = {colIndex + colOffsets[0],
552                                 colIndex + colOffsets[1]};
553 
554     if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
555       // all zeros
556       return internal::pset1<PacketT>(Scalar(0));
557     }
558     if (inputCols[0] == inputCols[1]) {
559       return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
560                                                 otherIndex, patchOffsets,
561                                                 colOffsets, inputCols);
562     }
563     if (inputCols[1] == inputCols[0] + 1) {
564       return loadPacketStandardFromTwoColumns(
565           patchId, rowIndex, colIndex, otherIndex, patchOffsets, colOffsets);
566     }
567     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
568   }
569 
570   EIGEN_DEVICE_FUNC
571   EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
572                                             Index colIndex,
573                                             Index otherIndex) const {
574     const Index packetSize = internal::unpacket_traits<Packet>::size;
575     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
576     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
577 
578     eigen_assert(!nonStandardPatches());
579     eigen_assert((patchDepth() % packetSize) == 0);
580     // Find the offset of the element wrt the location of the first element.
581     const Index patchOffset = patchId / m_fastDimZero;
582     eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
583 
584     const Index colOffset = patchOffset / m_fastColStride;
585     const Index rowOffset = patchOffset - colOffset * m_colStride;
586     const Index inputCol = colIndex + colOffset;
587     const Index inputRow = rowIndex + rowOffset;
588     if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
589         inputRow >= m_inputRows) {
590       // all zeros
591       return internal::pset1<Packet>(Scalar(0));
592     }
593     // no padding
594     const Index depth = patchId - patchOffset * patchDepth();
595     const Index inputIndex = depth + inputRow * m_rowInputStride +
596                              inputCol * m_colInputStride + otherIndex;
597     return m_impl.template packet<Unaligned>(inputIndex);
598   }
599 
600   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
601       Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
602     const int packetSize = internal::unpacket_traits<Packet>::size;
603     EIGEN_ALIGN_MAX
604     typename internal::remove_const<Scalar>::type values[packetSize];
605     for (int i = 0; i < packetSize; ++i) {
606       values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
607     }
608     Packet rslt = internal::pload<Packet>(values);
609     return rslt;
610   }
611 
612   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
613       Index patchIndex, Index& rowIndex, Index& colIndex,
614       Index& otherIndex) const {
615     const size_t NumInputDims = array_size<
616         typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
617     otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
618     const Index patch2DIndex = (NumInputDims == 3)
619                                    ? patchIndex
620                                    : (patchIndex - otherIndex * m_num_patches);
621     otherIndex *= m_patchInputStride;
622     colIndex = patch2DIndex / m_fastOutputRows;
623     rowIndex = patch2DIndex - colIndex * m_outputRows;
624     colIndex = colIndex * m_col_strides - m_colPaddingLeft;
625     rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
626   }
627 
628   Index m_patch_cols;   // number of columns in the patch
629   Index m_num_patches;  // number of patches to extract.
630 
631   // Strides for navigating through the single patch.
632   Index m_patch_row_stride;
633   Index m_patch_col_stride;
634   internal::TensorIntDivisor<Index> m_fastPatchRowStride;
635   internal::TensorIntDivisor<Index> m_fastPatchColStride;
636 
637   Index m_patch_row_inflate_strides;  // the strides for row inflation in the
638                                       // image patch
639   Index m_patch_col_inflate_strides;  // the strides for col inflation in the
640                                       // image patch
641   // Fast representation of inflation strides.
642   internal::TensorIntDivisor<Index> m_fastInputRowStride;
643   internal::TensorIntDivisor<Index> m_fastInputColStride;
644 
645   Index m_otherStride;
646   Index m_colStride;
647   internal::TensorIntDivisor<Index> m_fastNumPatches;
648   internal::TensorIntDivisor<Index> m_fastColStride;
649 
650   Index m_rowInputStride;    // row stride in the input tensor
651   Index m_colInputStride;    // col stride in the input tensor
652   Index m_patchInputStride;  // patch stride in the input tensor
653 
654   Index m_inputRows;  // Number of rows in the input tensor
655   Index m_inputCols;  // Number of cols in the input tensor
656 
657   Index m_outputRows;  // Number of convolution output rows
658   Index m_outputCols;  // Number of convolution output columns
659 
660   Index m_row_strides;  // User specified row stride
661   Index m_col_strides;  // User specified col stride
662 
663   Index m_in_row_strides;  // User specified input row stride
664   Index m_in_col_strides;  // User specified input col stride
665 
666   Index m_rowPaddingTop;   // Row padding
667   Index m_colPaddingLeft;  // Column padding
668 
669   internal::TensorIntDivisor<Index> m_fastOutputRows;
670   internal::TensorIntDivisor<Index> m_fastDimZero;
671 
672   const TensorEvaluator<ArgType, Device> m_impl;
673 };
674 
675 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
676           typename Device, typename Scalar, typename Index,
677           typename nocontract_t, typename contract_t, int Side, int packet_size,
678           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
679 class TensorContractionSubMapper<
680     Scalar, Index, Side,
681     TensorEvaluator<
682         const TensorReshapingOp<NewDimension,
683                                 const TensorImagePatchOp<Rows, Cols, ArgType> >,
684         Device>,
685     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
686     inner_dim_reordered, Alignment> {
687  public:
688   typedef typename packet_traits<Scalar>::type Packet;
689   typedef typename packet_traits<Scalar>::half HalfPacket;
690 
691   typedef TensorContractionInputMapper<
692       Scalar, Index, Side,
693       TensorEvaluator<
694           const TensorReshapingOp<
695               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
696           Device>,
697       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
698       inner_dim_reordered, Alignment>
699       ParentMapper;
700 
701   typedef TensorContractionSubMapper<
702       Scalar, Index, Side,
703       TensorEvaluator<
704           const TensorReshapingOp<
705               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
706           Device>,
707       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
708       inner_dim_reordered, Alignment>
709       Self;
710 
711   typedef Self LinearMapper;
712 
713   typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;
714 
715   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
716       const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
717       : m_depth_offset(vert_offset),
718         m_col_offset(horiz_offset),
719         m_base_mapper(base_mapper) {
720     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
721                                      m_otherIndex);
722   }
723   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
724       const Self& base_mapper, Index vert_offset, Index horiz_offset)
725       : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
726         m_col_offset(horiz_offset + base_mapper.m_col_offset),
727         m_base_mapper(base_mapper.m_base_mapper) {
728     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
729                                      m_otherIndex);
730   }
731   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
732     return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
733                                    m_otherIndex);
734   }
735   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
736                                                           Index j) const {
737     return m_base_mapper(i + m_depth_offset, j + m_col_offset);
738   }
739 
740   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
741     return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
742                                     m_otherIndex);
743   }
744   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
745                                                           Index j) const {
746     return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
747                                                         j + m_col_offset);
748   }
749   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
750   loadCoeffStandard(Index i) const {
751     return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
752                                            m_colIndex, m_otherIndex);
753   }
754 
755   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
756     return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
757                                         m_colIndex, m_otherIndex);
758   }
759   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
760   loadPacketStandard(Index i) const {
761     typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
762     return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
763         i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
764   }
765   template <typename Packet>
766   EIGEN_DEVICE_FUNC bool aligned(Index) const {
767     return false;
768   }
769 
770   EIGEN_DEVICE_FUNC
771   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
772     return m_base_mapper.nonStandardPatches();
773   }
774 
775   // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
776   // index respectively that fits into the peeled_k elements starting at
777   // m_depth_offset.
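  // Illustration (assumed sizes): with patchDepth() == 8, patchRows() == 3 and
  // m_depth_offset == 0, peeled_k == 32 covers linear offsets [0, 31]. Since
  // patchColStride() == 24, maxCol(32) == 1 + 31 / 24 == 2; within column 1
  // only row 0 is touched, so maxRow(32, 1) == 1 + (31 - 24) / 8 == 1.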
778 
779   EIGEN_DEVICE_FUNC
780   EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
781     const Index max_col =
782         (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) /
783         fastPatchColStride();
784     return std::min<Index>(1 + max_col, patchCols());
785   }
786 
787   EIGEN_DEVICE_FUNC
788   EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
789                                    const Index col) const {
790     const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) -
791                            col * patchColStride()) /
792                           fastPatchRowStride();
793     return std::min<Index>(1 + max_row, patchRows());
794   }
795 
796   EIGEN_DEVICE_FUNC
797   EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col,
798                                      Index row) const {
799     const Index max_depth = m_depth_offset + peeled_k -  //
800                             col * patchColStride() -     //
801                             row * patchRowStride();
802     return std::min<Index>(max_depth, patchDepth());
803   }
804 
805   // MaxDepth uses only the remaining number of elements in the peeled_k.
806   EIGEN_DEVICE_FUNC
807   EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
808                                      const Index start_depth) const {
809     return std::min<Index>(start_depth + num_elements, patchDepth());
810   }
811 
812   // Every register matters in this code, so sometimes to prevent register
813   // spilling, instead of the variable that you would expect to see, we use
814   // another one that is guaranteed to have the same value. E.g. patch depth is
815   // always the same as input depth, and it's also the same as input row stride.
816   // Several other parameters have similar relations.
817 
818   typedef internal::TensorIntDivisor<Index> IndexDivisor;
819 
820   EIGEN_DEVICE_FUNC
821   EIGEN_ALWAYS_INLINE Index patchDepth() const {
822     return m_base_mapper.m_rowInputStride;
823   }
824   EIGEN_DEVICE_FUNC
825   EIGEN_ALWAYS_INLINE Index patchRows() const {
826     return m_base_mapper.m_colStride;
827   }
828   EIGEN_DEVICE_FUNC
829   EIGEN_ALWAYS_INLINE Index patchCols() const {
830     return m_base_mapper.m_patch_cols;
831   }
832 
833   EIGEN_DEVICE_FUNC
834   EIGEN_ALWAYS_INLINE Index patchRowStride() const {
835     eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
836                  "Patch depth must be equal to patch row stride.");
837     return patchDepth();
838   }
839   EIGEN_DEVICE_FUNC
840   EIGEN_ALWAYS_INLINE Index patchColStride() const {
841     return m_base_mapper.m_patch_col_stride;
842   }
843 
844   EIGEN_DEVICE_FUNC
845   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
846     eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
847                  "Patch depth must be equal to patch row stride.");
848     return m_base_mapper.m_fastDimZero;  // patch_depth
849   }
850   EIGEN_DEVICE_FUNC
851   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
852     return m_base_mapper.m_fastPatchColStride;
853   }
854 
855   EIGEN_DEVICE_FUNC
856   EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
857                                              const Index baseIndex) const {
858     const Index inputIndex = depth + baseIndex;
859     return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
860   }
861   EIGEN_DEVICE_FUNC
862   EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
863                                             const Index baseIndex) const {
864     const Index inputIndex = depth + baseIndex;
865     return m_base_mapper.m_impl.coeff(inputIndex);
866   }
867   template <typename PacketT = Packet>
868   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
869       TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
870       PacketT>::type
871   partialPacketNoPadding(const Index depth, const Index baseIndex,
872                          Index num_coeffs) const {
873     const Index inputIndex = depth + baseIndex;
874     return m_base_mapper.m_impl.template partialPacket<PacketT>(
875         inputIndex, mask<PacketT>(0, num_coeffs));
876   }
877   EIGEN_DEVICE_FUNC
878   EIGEN_ALWAYS_INLINE bool hasPadding() const {
879     // TODO(ezhulenev): It does seem that for an inflated filter it's still
880     // possible to guarantee "no padding or skipping" for non-standard packing.
881     if (nonStandardPatches()) return true;
882 
883     // Non zero padding before.
884     if (m_base_mapper.m_rowPaddingTop > 0) return true;
885     if (m_base_mapper.m_colPaddingLeft > 0) return true;
886 
887     // Non zero padding after in rows.
888     const Index last_row =
889         (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
890     if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows) return true;
891 
892     // Non zero padding after in cols.
893     const Index last_col =
894         (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
895     if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols) return true;
896 
897     return false;
898   }
899   EIGEN_DEVICE_FUNC
900   EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
901     const Index r = m_rowIndex + row;
902     return r < 0 || r >= m_base_mapper.m_inputRows;
903   }
904   EIGEN_DEVICE_FUNC
905   EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row,
906                                      const Index last_row) const {
907     return m_rowIndex + first_row < 0 ||
908            m_rowIndex + last_row >= m_base_mapper.m_inputRows;
909   }
910   EIGEN_DEVICE_FUNC
911   EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row,
912                                         Index* orig_row) const {
913     eigen_assert(nonStandardPatches());
914 
915     const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
916     *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
917                     ? input_row
918                     : ((input_row >= 0)
919                            ? (input_row / m_base_mapper.m_fastInputRowStride)
920                            : 0);
921 
922     return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
923            (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
924   }
925   EIGEN_DEVICE_FUNC
926   EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
927     const Index c = m_colIndex + col;
928     return c < 0 || c >= m_base_mapper.m_inputCols;
929   }
930   EIGEN_DEVICE_FUNC
931   EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col,
932                                         Index* orig_col) const {
933     eigen_assert(nonStandardPatches());
934 
935     const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
936     *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
937                     ? input_col
938                     : ((input_col >= 0)
939                            ? (input_col / m_base_mapper.m_fastInputColStride)
940                            : 0);
941 
942     return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
943            (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
944   }
945   EIGEN_DEVICE_FUNC
946   EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
947     const Index r = m_rowIndex + row;
948     const Index c = m_colIndex + col;
949     return r * m_base_mapper.m_rowInputStride +
950            c * m_base_mapper.m_colInputStride + m_otherIndex;
951   }
952   // Compute a base index when original input row and column were precomputed
953   // using padOrSkipRow and padOrSkipCol. Used only for non-standard patches.
954   EIGEN_DEVICE_FUNC
955   EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row,
956                                           const Index orig_col) const {
957     return orig_row * m_base_mapper.m_rowInputStride +
958            orig_col * m_base_mapper.m_colInputStride + m_otherIndex;
959   }
960 
961   EIGEN_DEVICE_FUNC
962   EIGEN_ALWAYS_INLINE Index rowStride() const {
963     return m_base_mapper.m_row_strides;
964   }
965   EIGEN_DEVICE_FUNC
966   EIGEN_ALWAYS_INLINE Index colStride() const {
967     return m_base_mapper.m_col_strides;
968   }
969 
970   EIGEN_DEVICE_FUNC
971   EIGEN_ALWAYS_INLINE Index rowOffset() const {
972     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
973     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
974     return patchOffset - colOffset * m_base_mapper.m_colStride;
975   }
976 
977   EIGEN_DEVICE_FUNC
978   EIGEN_ALWAYS_INLINE Index colOffset() const {
979     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
980     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
981     return colOffset;
982   }
983 
984   EIGEN_DEVICE_FUNC
985   EIGEN_ALWAYS_INLINE Index depthOffset() const {
986     return m_depth_offset % patchDepth();
987   }
988 
989   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
990   getLinearMapper(Index i, Index j) const {
991     return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
992   }
993 
994  private:
995   Index m_depth_offset;  // First row in the input matrix
996   Index m_col_offset;    // First col in the input matrix
997 
998   // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
999   // indices for the first element in a patch specified by col_offset
1000   // (see computeBaseIndices(...) for details).
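  // Illustration (assumed sizes): with m_num_patches == 6 and m_outputRows ==
  // 2, col_offset == 9 maps to batch 9 / 6 == 1 and patch2DIndex == 3, i.e.
  // output column 1 and output row 1; the stored indices are then scaled by
  // the user strides and shifted by the padding.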
1001   Index m_rowIndex;
1002   Index m_colIndex;
1003   Index m_otherIndex;
1004 
1005   const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
1006                                      // performs better in benchmarks.
1007 };
1008 
1009 // Arrange a block of the right input matrix (in our case it's always a "virtual
1010 // matrix" constructed from extracted image patches) in contiguous memory.
1011 //
1012 // Given column major input (A0 beside A1 in memory):
1013 // A0 B0 C0 D0  E0 F0 G0 H0 ... Z0
1014 // A1 B1 C1 D1  E1 F1 G1 H1 ... Z1
1015 // A2 B2 C2 D2  E2 F2 G2 H2 ... Z2
1016 // A3 B3 C3 D3  E3 F3 G3 H3 ... Z3
1017 // A4 B4 C4 D4  E4 F4 G4 H4 ... Z4
1018 // A5 B5 C5 D5  E5 F5 G5 H5 ... Z5
1019 // A6 B6 C6 D6  E6 F6 G6 H6 ... Z6
1020 // A7 B7 C7 D7  E7 F7 G7 H7 ... Z7
1021 // A8 ...
1022 // ...
1023 //
1024 // *) A, B, C, ... - patches extracted from the original input.
1025 // *) A0, A1, A2 ... - values from the same patch at different offsets.
1026 //
1027 // The traversal (packed rhs memory) order (B0 beside A0 in memory):
1028 // A0 B0 C0 D0 A1 B1 C1 D1 ...
1029 // E0 F0 G0 H0 E1 F1 G1 H1 ...
1030 // ...
1031 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1032 //
1033 // This traversal order must be the same as in default gemm_pack_rhs defined in
1034 // GeneralBlockPanelKernel.h.
1035 //
1036 // *) nr - number of registers along the 'n' dimension.
1037 //    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1038 //    Multiplication" paper.
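//
// A minimal sketch of the reference packing semantics (assuming nr == 4 and
// ignoring the packet fast paths below): for every group of 4 columns j2 and
// every depth index k, the packed block stores rhs(k, j2 + 0..3) contiguously:
//
//   for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
//     for (Index k = 0; k < depth; ++k)
//       for (Index c = 0; c < 4; ++c)
//         *block++ = rhs(k, j2 + c);
//
// The specialization below produces the same layout, but reads whole packets
// straight from the underlying input tensor whenever padding and strides
// allow it.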
1039 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
1040           typename Device, typename Scalar, typename Index,
1041           typename nocontract_t, typename contract_t, int packet_size,
1042           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1043           int nr>
1044 struct gemm_pack_rhs<
1045     Scalar, Index,
1046     TensorContractionSubMapper<
1047         Scalar, Index, Rhs,
1048         TensorEvaluator<
1049             const TensorReshapingOp<
1050                 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1051             Device>,
1052         nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1053         inner_dim_reordered, Alignment>,
1054     nr, ColMajor, false, false> {
1055   typedef TensorContractionSubMapper<
1056       Scalar, Index, Rhs,
1057       TensorEvaluator<
1058           const TensorReshapingOp<
1059               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1060           Device>,
1061       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1062       inner_dim_reordered, Alignment>
1063       SubMapper;
1064   typedef SubMapper DataMapper;
1065   typedef typename packet_traits<Scalar>::type Packet;
1066 
1067   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1068 
1069   EIGEN_DEVICE_FUNC
1070   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1071                                     Index depth, Index cols, Index stride = 0,
1072                                     Index offset = 0) const {
1073     eigen_assert(stride == 0);
1074     eigen_assert(offset == 0);
1075 
1076     const Index packet_cols4 = (cols / 4) * 4;
1077     const Index peeled_k = (depth / packet_size) * packet_size;
1078     const bool non_standard_patches = rhs.nonStandardPatches();
1079 
1080     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1081       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1082       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1083       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1084       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1085 
1086       Index k = 0;
1087       if ((packet_size % 4) == 0 && !non_standard_patches) {
1088         // FAST PATH:
1089         // Iterate over patch columns and rows if we know that a single
1090         // packet does not span multiple rows or columns.
1091         if ((rhs.patchDepth() % packet_size) == 0) {
1092           const Index start_col = rhs.colOffset();
1093           const Index max_col = rhs.maxCol(peeled_k);
1094 
1095           for (Index c = start_col; c < max_col; ++c) {
1096             eigen_assert(k <= peeled_k);
1097 
1098             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1099             const Index max_row = rhs.maxRow(peeled_k, c);
1100 
1101             const bool pad_col0 = dm0.padCol(c);
1102             const bool pad_col1 = dm1.padCol(c);
1103             const bool pad_col2 = dm2.padCol(c);
1104             const bool pad_col3 = dm3.padCol(c);
1105 
1106             // Check if we can squeeze reads along the `row` and `depth`
1107             // dimensions (two innermost dimensions).
1108             if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
1109                 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
1110                 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
1111                 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
1112                 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
1113               // Compute how many elements we can squeeze read.
1114               const Index start_depth =
1115                   (c == start_col) ? rhs.depthOffset() : 0;
1116 
1117               // Upper bound for the number of elements in the depth dimension
1118               // that we can squeeze read.
1119               const Index squeeze_length =
1120                   (max_row - start_row) * rhs.patchDepth() - start_depth;
1121 
1122               // Do not overshoot beyond the block size.
1123               const Index max_depth =
1124                   start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1125               eigen_assert((max_depth - start_depth) % packet_size == 0);
1126 
1127               const Index idx0 = dm0.baseIndex(start_row, c);
1128               const Index idx1 = dm1.baseIndex(start_row, c);
1129               const Index idx2 = dm2.baseIndex(start_row, c);
1130               const Index idx3 = dm3.baseIndex(start_row, c);
1131 
1132               for (Index d = start_depth; d < max_depth; d += packet_size) {
1133                 eigen_assert(k < peeled_k);
1134                 PacketBlock<Packet, 4> kernel;
1135                 kernel.packet[0] = rhs.packetNoPadding(d, idx0);
1136                 kernel.packet[1] = rhs.packetNoPadding(d, idx1);
1137                 kernel.packet[2] = rhs.packetNoPadding(d, idx2);
1138                 kernel.packet[3] = rhs.packetNoPadding(d, idx3);
1139                 ptranspose(kernel);
1140                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1141                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1142                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1143                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1144                 block += 4 * packet_size;
1145                 k += packet_size;
1146               }
1147 
1148               // Go to the next column.
1149               continue;
1150             }
1151 
1152             // If we can't squeeze reads, process rows one by one.
1153             for (Index r = start_row; r < max_row; ++r) {
1154               eigen_assert(k <= peeled_k);
1155 
1156               const bool pad0 = pad_col0 || dm0.padRow(r);
1157               const bool pad1 = pad_col1 || dm1.padRow(r);
1158               const bool pad2 = pad_col2 || dm2.padRow(r);
1159               const bool pad3 = pad_col3 || dm3.padRow(r);
1160 
1161               const Index idx0 = dm0.baseIndex(r, c);
1162               const Index idx1 = dm1.baseIndex(r, c);
1163               const Index idx2 = dm2.baseIndex(r, c);
1164               const Index idx3 = dm3.baseIndex(r, c);
1165 
1166               const Index start_depth = ((c == start_col) && (r == start_row))
1167                                             ? rhs.depthOffset()
1168                                             : 0;
1169               const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1170               eigen_assert((max_depth - start_depth) % packet_size == 0);
1171 
1172               for (Index d = start_depth; d < max_depth; d += packet_size) {
1173                 eigen_assert(k < peeled_k);
1174                 PacketBlock<Packet, 4> kernel;
1175                 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1176                                         : rhs.packetNoPadding(d, idx0);
1177                 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1178                                         : rhs.packetNoPadding(d, idx1);
1179                 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
1180                                         : rhs.packetNoPadding(d, idx2);
1181                 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
1182                                         : rhs.packetNoPadding(d, idx3);
1183                 ptranspose(kernel);
1184                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1185                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1186                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1187                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1188                 block += 4 * packet_size;
1189                 k += packet_size;
1190               }
1191             }
1192           }
1193 
1194           // The loop above should fill peeled_k elements.
1195           eigen_assert(peeled_k == k);
1196 
1197         } else {
1198           for (; k < peeled_k; k += packet_size) {
1199             PacketBlock<Packet, 4> kernel;
1200             kernel.packet[0] = dm0.loadPacketStandard(k);
1201             kernel.packet[1] = dm1.loadPacketStandard(k);
1202             kernel.packet[2] = dm2.loadPacketStandard(k);
1203             kernel.packet[3] = dm3.loadPacketStandard(k);
1204             ptranspose(kernel);
1205             pstoreu(block + 0 * packet_size, kernel.packet[0]);
1206             pstoreu(block + 1 * packet_size, kernel.packet[1]);
1207             pstoreu(block + 2 * packet_size, kernel.packet[2]);
1208             pstoreu(block + 3 * packet_size, kernel.packet[3]);
1209             block += 4 * packet_size;
1210           }
1211         }
1212       }
1213 
      // Copy the remaining coefficients of the column block beyond peeled_k.
1215       if (!rhs.nonStandardPatches()) {
1216         for (; k < depth; k++) {
1217           block[0] = dm0.loadCoeffStandard(k);
1218           block[1] = dm1.loadCoeffStandard(k);
1219           block[2] = dm2.loadCoeffStandard(k);
1220           block[3] = dm3.loadCoeffStandard(k);
1221           block += 4;
1222         }
1223       } else {
1224         for (; k < depth; k++) {
1225           block[0] = dm0(k);
1226           block[1] = dm1(k);
1227           block[2] = dm2(k);
1228           block[3] = dm3(k);
1229           block += 4;
1230         }
1231       }
1232     }
1233 
    // Copy the remaining columns one at a time (nr==1).
1235     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1236       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1237       for (Index k = 0; k < depth; k++) {
1238         *block = dm0(k);
1239         block += 1;
1240       }
1241     }
1242   }
1243 };
1244 
1245 // Template specialization for packet_size = 2. We must special-case packet
1246 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
1247 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
1248           typename Device, typename Scalar, typename Index,
1249           typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1250           bool inner_dim_reordered, int Alignment, int nr>
1251 struct gemm_pack_rhs<
1252     Scalar, Index,
1253     TensorContractionSubMapper<
1254         Scalar, Index, Rhs,
1255         TensorEvaluator<
1256             const TensorReshapingOp<
1257                 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1258             Device>,
1259         nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
1260         Alignment>,
1261     nr, ColMajor, false, false> {
1262   typedef TensorContractionSubMapper<
1263       Scalar, Index, Rhs,
1264       TensorEvaluator<
1265           const TensorReshapingOp<
1266               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1267           Device>,
1268       nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
1269       Alignment>
1270       SubMapper;
1271   typedef SubMapper DataMapper;
1272   typedef typename packet_traits<Scalar>::type Packet;
1273 
1274   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1275 
1276   EIGEN_DEVICE_FUNC
1277   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1278                                     Index depth, Index cols, Index stride = 0,
1279                                     Index offset = 0) const {
1280     eigen_assert(stride == 0);
1281     eigen_assert(offset == 0);
1282 
1283     const int packet_size = 2;
1284     const Index packet_cols4 = (cols / 4) * 4;
1285     const Index peeled_k = (depth / packet_size) * packet_size;
1286     const bool non_standard_patches = rhs.nonStandardPatches();
1287 
1288     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1289       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1290       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1291       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1292       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1293 
1294       Index k = 0;
1295       if (!non_standard_patches) {
1296         // FAST PATH:
        // Iterate over patch columns and rows if we know that a single
        // packet does not span multiple rows or columns.
1299         if ((rhs.patchDepth() % packet_size) == 0) {
1300           const Index start_col = rhs.colOffset();
1301           const Index max_col = rhs.maxCol(peeled_k);
1302 
1303           for (Index c = start_col; c < max_col; ++c) {
1304             eigen_assert(k <= peeled_k);
1305 
1306             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1307             const Index max_row = rhs.maxRow(peeled_k, c);
1308 
1309             const bool pad_col0 = dm0.padCol(c);
1310             const bool pad_col1 = dm1.padCol(c);
1311             const bool pad_col2 = dm2.padCol(c);
1312             const bool pad_col3 = dm3.padCol(c);
1313 
            // We can squeeze reads along the `row` and `depth` dimensions if
            // the row stride is `1`, which means that the `row` and `depth`
            // dimensions are contiguous (the two innermost dimensions).
1317             if (rhs.rowStride() == 1 &&                                //
1318                 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
1319                 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
1320                 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
1321                 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
1322                 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
1323               // Compute how many elements we can squeeze read.
1324               const Index start_depth =
1325                   (c == start_col) ? rhs.depthOffset() : 0;
1326 
1327               // Upper bound for the number of elements in the depth dimension
1328               // that we can squeeze read.
1329               const Index squeeze_length =
1330                   (max_row - start_row) * rhs.patchDepth() - start_depth;
1331 
1332               // Do not overshoot beyond the block size.
1333               const Index max_depth =
1334                   start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1335               eigen_assert((max_depth - start_depth) % packet_size == 0);
1336 
1337               const Index idx0 = dm0.baseIndex(start_row, c);
1338               const Index idx1 = dm1.baseIndex(start_row, c);
1339               const Index idx2 = dm2.baseIndex(start_row, c);
1340               const Index idx3 = dm3.baseIndex(start_row, c);
1341 
1342               for (Index d = start_depth; d < max_depth; d += packet_size) {
1343                 PacketBlock<Packet, 2> kernel0;
1344                 PacketBlock<Packet, 2> kernel1;
1345                 kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
1346                 kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
1347                 kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
1348                 kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
1349                 ptranspose(kernel0);
1350                 ptranspose(kernel1);
1351                 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1352                 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1353                 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1354                 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1355                 block += 4 * packet_size;
1356                 k += packet_size;
1357               }
1358 
1359               // Go to the next column.
1360               continue;
1361             }
1362 
1363             // If we can't squeeze reads, process rows one by one.
1364             for (Index r = start_row; r < max_row; ++r) {
1365               eigen_assert(k <= peeled_k);
1366 
1367               const bool pad0 = pad_col0 || dm0.padRow(r);
1368               const bool pad1 = pad_col1 || dm1.padRow(r);
1369               const bool pad2 = pad_col2 || dm2.padRow(r);
1370               const bool pad3 = pad_col3 || dm3.padRow(r);
1371 
1372               const Index idx0 = dm0.baseIndex(r, c);
1373               const Index idx1 = dm1.baseIndex(r, c);
1374               const Index idx2 = dm2.baseIndex(r, c);
1375               const Index idx3 = dm3.baseIndex(r, c);
1376 
1377               const Index start_depth = ((c == start_col) && (r == start_row))
1378                                             ? rhs.depthOffset()
1379                                             : 0;
1380               const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1381               eigen_assert((max_depth - start_depth) % packet_size == 0);
1382 
1383               for (Index d = start_depth; d < max_depth; d += packet_size) {
1384                 eigen_assert(k < peeled_k);
1385                 PacketBlock<Packet, 2> kernel0;
1386                 PacketBlock<Packet, 2> kernel1;
1387                 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1388                                          : rhs.packetNoPadding(d, idx0);
1389                 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1390                                          : rhs.packetNoPadding(d, idx1);
1391                 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
1392                                          : rhs.packetNoPadding(d, idx2);
1393                 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
1394                                          : rhs.packetNoPadding(d, idx3);
1395                 ptranspose(kernel0);
1396                 ptranspose(kernel1);
1397                 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1398                 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1399                 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1400                 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1401                 block += 4 * packet_size;
1402                 k += packet_size;
1403               }
1404             }
1405           }
1406 
1407           // The loop above should fill peeled_k elements.
1408           eigen_assert(peeled_k == k);
1409 
1410         } else {
          // A packet can span multiple rows or columns, so we have to go
          // through the slower "standard" path.
1413           for (; k < peeled_k; k += packet_size) {
1414             PacketBlock<Packet, 2> kernel0;
1415             PacketBlock<Packet, 2> kernel1;
1416             kernel0.packet[0] = dm0.loadPacketStandard(k);
1417             kernel0.packet[1] = dm1.loadPacketStandard(k);
1418             kernel1.packet[0] = dm2.loadPacketStandard(k);
1419             kernel1.packet[1] = dm3.loadPacketStandard(k);
1420             ptranspose(kernel0);
1421             ptranspose(kernel1);
1422             pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1423             pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1424             pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1425             pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1426             block += 4 * packet_size;
1427           }
1428         }
1429       }
1430 
      // Copy the remaining coefficients of the column block beyond peeled_k.
1432       if (!non_standard_patches) {
1433         for (; k < depth; k++) {
1434           block[0] = dm0.loadCoeffStandard(k);
1435           block[1] = dm1.loadCoeffStandard(k);
1436           block[2] = dm2.loadCoeffStandard(k);
1437           block[3] = dm3.loadCoeffStandard(k);
1438           block += 4;
1439         }
1440       } else {
1441         for (; k < depth; k++) {
1442           block[0] = dm0(k);
1443           block[1] = dm1(k);
1444           block[2] = dm2(k);
1445           block[3] = dm3(k);
1446           block += 4;
1447         }
1448       }
1449     }
1450 
1451     // Copy the remaining columns one at a time (nr==1).
1452     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1453       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1454       for (Index k = 0; k < depth; k++) {
1455         *block = dm0(k);
1456         block += 1;
1457       }
1458     }
1459   }
1460 };
1461 
1462 // Special case for non-vectorized types such as float16.
1463 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
1464           typename Device, typename Scalar, typename Index,
1465           typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1466           bool inner_dim_reordered, int Alignment, int nr>
1467 struct gemm_pack_rhs<
1468     Scalar, Index,
1469     TensorContractionSubMapper<
1470         Scalar, Index, Rhs,
1471         TensorEvaluator<
1472             const TensorReshapingOp<
1473                 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1474             Device>,
1475         nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
1476         Alignment>,
1477     nr, ColMajor, false, false> {
1478   typedef TensorContractionSubMapper<
1479       Scalar, Index, Rhs,
1480       TensorEvaluator<
1481           const TensorReshapingOp<
1482               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1483           Device>,
1484       nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
1485       Alignment>
1486       SubMapper;
1487   typedef SubMapper DataMapper;
1488 
1489   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1490 
1491   EIGEN_DEVICE_FUNC
1492   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1493                                     Index depth, Index cols, Index stride = 0,
1494                                     Index offset = 0) const {
1495     eigen_assert(stride == 0);
1496     eigen_assert(offset == 0);
1497 
1498     const Index packet_cols4 = (cols / 4) * 4;
1499 
1500     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1501       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1502       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1503       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1504       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1505 
1506       if (!rhs.nonStandardPatches()) {
1507         for (Index k = 0; k < depth; k++) {
1508           block[0] = dm0.loadCoeffStandard(k);
1509           block[1] = dm1.loadCoeffStandard(k);
1510           block[2] = dm2.loadCoeffStandard(k);
1511           block[3] = dm3.loadCoeffStandard(k);
1512           block += 4;
1513         }
1514       } else {
1515         for (Index k = 0; k < depth; k++) {
1516           block[0] = dm0(k);
1517           block[1] = dm1(k);
1518           block[2] = dm2(k);
1519           block[3] = dm3(k);
1520           block += 4;
1521         }
1522       }
1523     }
1524 
1525     // Copy the remaining columns one at a time (nr==1).
1526     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1527       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1528       for (Index k = 0; k < depth; k++) {
1529         *block = dm0(k);
1530         block += 1;
1531       }
1532     }
1533   }
1534 };
1535 }  // end namespace internal
1536 
1537 /** SpatialConvolution
1538  * \ingroup CXX11_NeuralNetworks_Module
1539  *
1540  * \brief Applies a 2D convolution over a multichannel input image.
1541  *
 * The input parameter is expected to be a tensor with a rank of 3 or more
 * (channels, height, width, and optionally others).
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width).
1546  * The input and the kernel must both be in col-major layout. The result will
1547  * also be in col-major layout.
1548  *
 * If col_in_stride or row_in_stride is greater than 1, the convolution is
 * applied with holes (a.k.a. atrous or dilated convolution), sampling only
 * every col_in_stride-th input column and every row_in_stride-th input row.
1552  *
1553  * If padding_top, padding_bottom, padding_left, or padding_right is specified,
1554  * then those paddings will be used to pad the input, and padding_type must be
1555  * PADDING_VALID.
1556  *
1557  * The result can be assigned to a tensor of rank equal to the rank of the
1558  * input. The dimensions of the result will be filters, height, width (and
1559  * others if applicable).
1560  *
1561  * It is possible to swap the order of the width and height dimensions provided
1562  * that the same order is used in the input, the kernel, and the output.
1563  *
 * It is also possible to add an output kernel to the contraction; the output
 * kernel is called by Eigen when it "finalizes" a block of the output tensor.
1566  *
1567  */
1568 template <typename Input, typename Kernel,
1569           typename OutputKernel = const NoOpOutputKernel>
1570 EIGEN_DEVICE_FUNC
1571     EIGEN_ALWAYS_INLINE static const typename internal::conditional<
1572         internal::traits<Input>::Layout == ColMajor,
1573         TensorReshapingOp<
1574             const DSizes<typename internal::traits<Input>::Index,
1575                          internal::traits<Input>::NumDimensions>,
1576             const TensorContractionOp<
1577                 const array<IndexPair<typename internal::traits<Input>::Index>,
1578                             1>,
1579                 const TensorReshapingOp<
1580                     const DSizes<typename internal::traits<Input>::Index, 2>,
1581                     const Kernel>,
1582                 const TensorReshapingOp<
1583                     const DSizes<typename internal::traits<Input>::Index, 2>,
1584                     const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
1585                 const OutputKernel> >,
1586         TensorReshapingOp<
1587             const DSizes<typename internal::traits<Input>::Index,
1588                          internal::traits<Input>::NumDimensions>,
1589             const TensorContractionOp<
1590                 const array<IndexPair<typename internal::traits<Input>::Index>,
1591                             1>,
1592                 const TensorReshapingOp<
1593                     const DSizes<typename internal::traits<Input>::Index, 2>,
1594                     const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
1595                 const TensorReshapingOp<
1596                     const DSizes<typename internal::traits<Input>::Index, 2>,
1597                     const Kernel>,
1598                 const OutputKernel> > >::type
1599     SpatialConvolution(const Input& input, const Kernel& kernel,
1600                        const Index row_stride = 1, const Index col_stride = 1,
1601                        const PaddingType padding_type = PADDING_SAME,
1602                        const Index row_in_stride = 1,
1603                        const Index col_in_stride = 1,
1604                        const OutputKernel& output_kernel = OutputKernel(),
1605                        Index padding_top = 0, Index padding_bottom = 0,
1606                        Index padding_left = 0, Index padding_right = 0) {
1607   typedef typename internal::traits<Input>::Index TensorIndex;
1608   TensorRef<Tensor<typename internal::traits<Input>::Scalar,
1609                    internal::traits<Input>::NumDimensions,
1610                    internal::traits<Input>::Layout, TensorIndex> >
1611       in(input);
1612   TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
1613                    internal::traits<Kernel>::NumDimensions,
1614                    internal::traits<Kernel>::Layout, TensorIndex> >
1615       kern(kernel);
1616 
1617   EIGEN_STATIC_ASSERT(
1618       internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1619       YOU_MADE_A_PROGRAMMING_MISTAKE)
1620   const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1621 
1622   const int NumDims = internal::traits<Input>::NumDimensions;
1623 
1624   // Number of filters to apply. This is the same as the output depth of the
  // result.
1626   const TensorIndex kernelFilters =
1627       isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
1628   // Number of channels. This is the same as the input depth.
1629   const TensorIndex kernelChannels =
1630       isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
1631   const TensorIndex kernelRows =
1632       isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
1633   const TensorIndex kernelCols =
1634       isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
1635 
1636   const Index kernelRowsEff =
1637       kernelRows + (kernelRows - 1) * (row_in_stride - 1);
1638   const Index kernelColsEff =
1639       kernelCols + (kernelCols - 1) * (col_in_stride - 1);
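  // For example (illustrative numbers only): a 3-tap kernel with
  // row_in_stride == 2 effectively spans 3 + (3 - 1) * (2 - 1) = 5 input rows.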
1640 
1641   array<IndexPair<TensorIndex>, 1> contract_dims;
1642   contract_dims[0] = IndexPair<TensorIndex>(1, 0);
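  // In the ColMajor case this contracts dimension 1 of the reshaped kernel
  // (kernel_dims below: [filters, channels * kernel_rows * kernel_cols]) with
  // dimension 0 of the reshaped image patches (pre_contract_dims below); the
  // RowMajor case mirrors this with the two contraction operands swapped.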
1643 
1644   const TensorIndex InputRows =
1645       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1646   const TensorIndex InputCols =
1647       isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1648   const bool padding_explicit =
1649       (padding_top || padding_bottom || padding_left || padding_right);
1650 
1651   TensorIndex out_height;
1652   TensorIndex out_width;
1653   switch (padding_type) {
1654     case PADDING_VALID: {
1655       const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
1656       const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
1657       out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
1658       out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
1659       break;
1660     }
1661     case PADDING_SAME: {
1662       eigen_assert(!padding_explicit);
1663       out_height = divup(InputRows, row_stride);
1664       out_width = divup(InputCols, col_stride);
1665       break;
1666     }
1667     default: {
1668       // Initialize unused variables to avoid a compiler warning
1669       out_height = 0;
1670       out_width = 0;
1671       eigen_assert(false && "unexpected padding");
1672     }
1673   }
1674 
  // Molds the output of the patch extraction code into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
1679   DSizes<TensorIndex, 2> pre_contract_dims;
1680   if (isColMajor) {
1681     pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
1682     pre_contract_dims[1] = out_height * out_width;
1683     for (int i = 3; i < NumDims; ++i) {
1684       pre_contract_dims[1] *= in.dimension(i);
1685     }
1686   } else {
1687     pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
1688     pre_contract_dims[0] = out_height * out_width;
1689     for (int i = 0; i < NumDims - 3; ++i) {
1690       pre_contract_dims[0] *= in.dimension(i);
1691     }
1692   }
1693 
  // Molds the output of the contraction into the shape expected by the user
1695   // (assuming this is ColMajor):
1696   // - 1st dim: kernel filters
1697   // - 2nd dim: output height
1698   // - 3rd dim: output width
1699   // - 4th dim and beyond: everything else including batch size
1700   DSizes<TensorIndex, NumDims> post_contract_dims;
1701   if (isColMajor) {
1702     post_contract_dims[0] = kernelFilters;
1703     post_contract_dims[1] = out_height;
1704     post_contract_dims[2] = out_width;
1705     for (int i = 3; i < NumDims; ++i) {
1706       post_contract_dims[i] = in.dimension(i);
1707     }
1708   } else {
1709     post_contract_dims[NumDims - 1] = kernelFilters;
1710     post_contract_dims[NumDims - 2] = out_height;
1711     post_contract_dims[NumDims - 3] = out_width;
1712     for (int i = 0; i < NumDims - 3; ++i) {
1713       post_contract_dims[i] = in.dimension(i);
1714     }
1715   }
1716 
1717   DSizes<TensorIndex, 2> kernel_dims;
1718   if (isColMajor) {
1719     kernel_dims[0] = kernelFilters;
1720     kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
1721   } else {
1722     kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
1723     kernel_dims[1] = kernelFilters;
1724   }
1725   if (padding_explicit) {
1726     return choose(
1727         Cond<internal::traits<Input>::Layout == ColMajor>(),
1728         kernel.reshape(kernel_dims)
1729             .contract(input
1730                           .extract_image_patches(
1731                               kernelRows, kernelCols, row_stride, col_stride,
1732                               row_in_stride, col_in_stride,
1733                               /*row_inflate_stride=*/1,
1734                               /*col_inflate_stride=*/1, padding_top,
1735                               padding_bottom, padding_left, padding_right,
1736                               /*padding_value=*/0)
1737                           .reshape(pre_contract_dims),
1738                       contract_dims, output_kernel)
1739             .reshape(post_contract_dims),
1740         input
1741             .extract_image_patches(kernelRows, kernelCols, row_stride,
1742                                    col_stride, row_in_stride, col_in_stride,
1743                                    /*row_inflate_stride=*/1,
1744                                    /*col_inflate_stride=*/1, padding_top,
1745                                    padding_bottom, padding_left, padding_right,
1746                                    /*padding_value=*/0)
1747             .reshape(pre_contract_dims)
1748             .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1749             .reshape(post_contract_dims));
1750   } else {
1751     return choose(
1752         Cond<internal::traits<Input>::Layout == ColMajor>(),
1753         kernel.reshape(kernel_dims)
1754             .contract(input
1755                           .extract_image_patches(
1756                               kernelRows, kernelCols, row_stride, col_stride,
1757                               row_in_stride, col_in_stride, padding_type)
1758                           .reshape(pre_contract_dims),
1759                       contract_dims, output_kernel)
1760             .reshape(post_contract_dims),
1761         input
1762             .extract_image_patches(kernelRows, kernelCols, row_stride,
1763                                    col_stride, row_in_stride, col_in_stride,
1764                                    padding_type)
1765             .reshape(pre_contract_dims)
1766             .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1767             .reshape(post_contract_dims));
1768   }
1769 }
1770 
1771 }  // end namespace Eigen
1772 
1773 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
1774