1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test cases for segment reduction ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22 23import numpy as np 24 25from tensorflow.compiler.tests import xla_test 26from tensorflow.python.framework import dtypes 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.platform import googletest 30 31 32class SegmentReductionOpsTest(xla_test.XLATestCase): 33 """Test cases for segment reduction ops.""" 34 35 def _segmentReduction(self, op, data, indices, num_segments): 36 with self.session() as sess, self.test_scope(): 37 d = array_ops.placeholder(data.dtype, shape=data.shape) 38 if isinstance(indices, int): 39 i = array_ops.placeholder(np.int32, shape=[]) 40 else: 41 i = array_ops.placeholder(indices.dtype, shape=indices.shape) 42 return sess.run(op(d, i, num_segments), {d: data, i: indices}) 43 44 def _unsortedSegmentSum(self, data, indices, num_segments): 45 return self._segmentReduction(math_ops.unsorted_segment_sum, data, indices, 46 num_segments) 47 48 def _unsortedSegmentProd(self, data, indices, num_segments): 49 return self._segmentReduction(math_ops.unsorted_segment_prod, data, indices, 50 num_segments) 51 52 def _unsortedSegmentMin(self, data, indices, num_segments): 53 return self._segmentReduction(math_ops.unsorted_segment_min, data, indices, 54 num_segments) 55 56 def _unsortedSegmentMax(self, data, indices, num_segments): 57 return self._segmentReduction(math_ops.unsorted_segment_max, data, indices, 58 num_segments) 59 60 def testUnsortedSegmentSum0DIndices1DData(self): 61 for dtype in self.numeric_types: 62 self.assertAllClose( 63 np.array( 64 [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5], 65 [0, 0, 0, 0, 0, 0]], 66 dtype=dtype), 67 self._unsortedSegmentSum( 68 np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) 69 70 def testUnsortedSegmentSum1DIndices1DData(self): 71 for dtype in self.numeric_types: 72 self.assertAllClose( 73 np.array([1, 3, 2, 9], dtype=dtype), 74 self._unsortedSegmentSum( 75 np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 76 np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) 77 78 def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self): 79 for dtype in self.numeric_types: 80 self.assertAllClose( 81 np.array([6, 3, 0, 6], dtype=dtype), 82 self._unsortedSegmentSum( 83 np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), 84 np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) 85 86 def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): 87 for dtype in self.numeric_types: 88 data = np.array( 89 [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43], 90 [50, 51, 52, 53]], 91 dtype=dtype) 92 indices = np.array([8, 1, 0, 3, 7], dtype=np.int32) 93 num_segments = 10 94 y = self._unsortedSegmentSum(data, indices, num_segments) 95 self.assertAllClose( 96 np.array( 97 [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0], 98 [40, 41, 42, 43], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], 99 [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]], 100 dtype=dtype), y) 101 102 def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self): 103 for dtype in self.numeric_types: 104 data = np.array( 105 [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43], 106 [50, 51, 52, 53]], 107 dtype=dtype) 108 indices = np.array([0, 1, 2, 0, 1], dtype=np.int32) 109 num_segments = 4 110 y = self._unsortedSegmentSum(data, indices, num_segments) 111 self.assertAllClose( 112 np.array( 113 [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33], 114 [0, 0, 0, 0]], 115 dtype=dtype), y) 116 117 def testUnsortedSegmentSum2DIndices3DData(self): 118 for dtype in self.numeric_types: 119 data = np.array( 120 [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ 121 200, 201, 202 122 ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], 123 dtype=dtype) 124 indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) 125 num_segments = 8 126 y = self._unsortedSegmentSum(data, indices, num_segments) 127 self.assertAllClose( 128 np.array( 129 [[210, 211, 212], [110, 111, 112], [310, 311, 312], [ 130 100, 102, 104 131 ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]], 132 dtype=dtype), y) 133 134 def testUnsortedSegmentSum1DIndices3DData(self): 135 for dtype in self.numeric_types: 136 data = np.array( 137 [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ 138 200, 201, 202 139 ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], 140 dtype=dtype) 141 indices = np.array([3, 0, 2, 5], dtype=np.int32) 142 num_segments = 6 143 y = self._unsortedSegmentSum(data, indices, num_segments) 144 self.assertAllClose( 145 np.array( 146 [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], 147 [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]], 148 [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]], 149 dtype=dtype), y) 150 151 def testUnsortedSegmentSumShapeError(self): 152 for dtype in self.numeric_types: 153 data = np.ones((4, 8, 7), dtype=dtype) 154 indices = np.ones((3, 2), dtype=np.int32) 155 num_segments = 4 156 self.assertRaises( 157 ValueError, 158 functools.partial(self._segmentReduction, 159 math_ops.unsorted_segment_sum, data, indices, 160 num_segments)) 161 162 def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self): 163 """Tests for min, max, and prod ops. 164 165 These share most of their implementation with sum, so we only test basic 166 functionality. 167 """ 168 for dtype in self.numeric_types: 169 self.assertAllClose( 170 np.array([8, 3, 1, 0], dtype=dtype), 171 self._unsortedSegmentProd( 172 np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), 173 np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) 174 175 for dtype in self.int_types | self.float_types: 176 minval = dtypes.as_dtype(dtype).min 177 maxval = dtypes.as_dtype(dtype).max 178 179 self.assertAllClose( 180 np.array([2, 3, maxval, 0], dtype=dtype), 181 self._unsortedSegmentMin( 182 np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), 183 np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) 184 self.assertAllClose( 185 np.array([4, 3, minval, 6], dtype=dtype), 186 self._unsortedSegmentMax( 187 np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), 188 np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) 189 190 191if __name__ == "__main__": 192 googletest.main() 193