1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test cases for segment reduction ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23import numpy as np
24
25from tensorflow.compiler.tests import xla_test
26from tensorflow.python.framework import dtypes
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.platform import googletest
30
31
class SegmentReductionOpsTest(xla_test.XLATestCase):
  """Checks unsorted segment sum/prod/min/max ops inside the test scope."""

  def _segmentReduction(self, op, data, indices, num_segments):
    """Evaluates `op(data, indices, num_segments)` through placeholders.

    Args:
      op: Callable invoked as `op(data_placeholder, ids_placeholder,
        num_segments)`.
      data: NumPy array fed as the data operand.
      indices: Either a Python int (fed through a scalar int32 placeholder) or
        a NumPy array of segment ids (placeholder mirrors its dtype and shape).
      num_segments: Forwarded unchanged to `op`.

    Returns:
      The value produced by `sess.run` for the op's output.
    """
    with self.session() as sess, self.test_scope():
      data_ph = array_ops.placeholder(data.dtype, shape=data.shape)
      # A bare Python int has no dtype/shape attributes, so describe it as a
      # scalar int32 explicitly.
      if isinstance(indices, int):
        ids_dtype, ids_shape = np.int32, []
      else:
        ids_dtype, ids_shape = indices.dtype, indices.shape
      ids_ph = array_ops.placeholder(ids_dtype, shape=ids_shape)
      reduced = op(data_ph, ids_ph, num_segments)
      return sess.run(reduced, {data_ph: data, ids_ph: indices})

  def _unsortedSegmentSum(self, data, indices, num_segments):
    """Runs `math_ops.unsorted_segment_sum` via `_segmentReduction`."""
    reducer = math_ops.unsorted_segment_sum
    return self._segmentReduction(reducer, data, indices, num_segments)

  def _unsortedSegmentProd(self, data, indices, num_segments):
    """Runs `math_ops.unsorted_segment_prod` via `_segmentReduction`."""
    reducer = math_ops.unsorted_segment_prod
    return self._segmentReduction(reducer, data, indices, num_segments)

  def _unsortedSegmentMin(self, data, indices, num_segments):
    """Runs `math_ops.unsorted_segment_min` via `_segmentReduction`."""
    reducer = math_ops.unsorted_segment_min
    return self._segmentReduction(reducer, data, indices, num_segments)

  def _unsortedSegmentMax(self, data, indices, num_segments):
    """Runs `math_ops.unsorted_segment_max` via `_segmentReduction`."""
    reducer = math_ops.unsorted_segment_max
    return self._segmentReduction(reducer, data, indices, num_segments)

  def testUnsortedSegmentSum0DIndices1DData(self):
    """A scalar segment id puts the whole 1-D input into that output row."""
    for dtype in self.numeric_types:
      expected = np.array(
          [
              [0, 0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0, 0],
              [0, 1, 2, 3, 4, 5],
              [0, 0, 0, 0, 0, 0],
          ],
          dtype=dtype)
      data = np.array([0, 1, 2, 3, 4, 5], dtype=dtype)
      actual = self._unsortedSegmentSum(data, 2, 4)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum1DIndices1DData(self):
    """1-D data with 1-D ids: values sharing an id are added together."""
    for dtype in self.numeric_types:
      expected = np.array([1, 3, 2, 9], dtype=dtype)
      data = np.array([0, 1, 2, 3, 4, 5], dtype=dtype)
      segment_ids = np.array([3, 0, 2, 1, 3, 3], dtype=np.int32)
      actual = self._unsortedSegmentSum(data, segment_ids, 4)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self):
    """Entries whose id is -1 are expected to land in no output segment."""
    for dtype in self.numeric_types:
      expected = np.array([6, 3, 0, 6], dtype=dtype)
      data = np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype)
      segment_ids = np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32)
      actual = self._unsortedSegmentSum(data, segment_ids, 4)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
    """Each 2-D data row has a unique id; unused segments stay zero."""
    for dtype in self.numeric_types:
      data = np.array(
          [
              [0, 1, 2, 3],
              [20, 21, 22, 23],
              [30, 31, 32, 33],
              [40, 41, 42, 43],
              [50, 51, 52, 53],
          ],
          dtype=dtype)
      segment_ids = np.array([8, 1, 0, 3, 7], dtype=np.int32)
      actual = self._unsortedSegmentSum(data, segment_ids, 10)
      expected = np.array(
          [
              [30, 31, 32, 33],
              [20, 21, 22, 23],
              [0, 0, 0, 0],
              [40, 41, 42, 43],
              [0, 0, 0, 0],
              [0, 0, 0, 0],
              [0, 0, 0, 0],
              [50, 51, 52, 53],
              [0, 1, 2, 3],
              [0, 0, 0, 0],
          ],
          dtype=dtype)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self):
    """2-D data rows that share an id are summed element-wise."""
    for dtype in self.numeric_types:
      data = np.array(
          [
              [0, 1, 2, 3],
              [20, 21, 22, 23],
              [30, 31, 32, 33],
              [40, 41, 42, 43],
              [50, 51, 52, 53],
          ],
          dtype=dtype)
      segment_ids = np.array([0, 1, 2, 0, 1], dtype=np.int32)
      actual = self._unsortedSegmentSum(data, segment_ids, 4)
      expected = np.array(
          [
              [40, 42, 44, 46],
              [70, 72, 74, 76],
              [30, 31, 32, 33],
              [0, 0, 0, 0],
          ],
          dtype=dtype)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum2DIndices3DData(self):
    """2-D ids over 3-D data: each innermost vector goes to its own segment."""
    for dtype in self.numeric_types:
      data = np.array(
          [
              [[0, 1, 2], [10, 11, 12]],
              [[100, 101, 102], [110, 111, 112]],
              [[200, 201, 202], [210, 211, 212]],
              [[300, 301, 302], [310, 311, 312]],
          ],
          dtype=dtype)
      segment_ids = np.array(
          [[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32)
      actual = self._unsortedSegmentSum(data, segment_ids, 8)
      expected = np.array(
          [
              [210, 211, 212],
              [110, 111, 112],
              [310, 311, 312],
              [100, 102, 104],
              [0, 0, 0.],
              [210, 212, 214],
              [300, 301, 302],
              [0, 0, 0],
          ],
          dtype=dtype)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSum1DIndices3DData(self):
    """1-D ids over 3-D data: whole 2-D slices are routed to segments."""
    for dtype in self.numeric_types:
      data = np.array(
          [
              [[0, 1, 2], [10, 11, 12]],
              [[100, 101, 102], [110, 111, 112]],
              [[200, 201, 202], [210, 211, 212]],
              [[300, 301, 302], [310, 311, 312]],
          ],
          dtype=dtype)
      segment_ids = np.array([3, 0, 2, 5], dtype=np.int32)
      actual = self._unsortedSegmentSum(data, segment_ids, 6)
      expected = np.array(
          [
              [[100, 101, 102.], [110, 111, 112]],
              [[0, 0, 0], [0, 0, 0]],
              [[200, 201, 202], [210, 211, 212]],
              [[0, 1, 2.], [10, 11, 12]],
              [[0, 0, 0], [0, 0, 0]],
              [[300, 301, 302], [310, 311, 312]],
          ],
          dtype=dtype)
      self.assertAllClose(expected, actual)

  def testUnsortedSegmentSumShapeError(self):
    """Ids of shape (3, 2) against data of shape (4, 8, 7) raise ValueError."""
    for dtype in self.numeric_types:
      data = np.ones((4, 8, 7), dtype=dtype)
      segment_ids = np.ones((3, 2), dtype=np.int32)
      failing_call = functools.partial(
          self._segmentReduction, math_ops.unsorted_segment_sum, data,
          segment_ids, 4)
      self.assertRaises(ValueError, failing_call)

  def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self):
    """Covers the prod, min, and max variants.

    Most of their implementation is shared with sum, so only basic
    functionality is checked here.
    """
    # The same ids are used for every op; segment 2 receives no entries and
    # the two -1 ids are expected to be skipped.
    segment_ids = np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32)

    for dtype in self.numeric_types:
      expected_prod = np.array([8, 3, 1, 0], dtype=dtype)
      data = np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype)
      self.assertAllClose(
          expected_prod, self._unsortedSegmentProd(data, segment_ids, 4))

    for dtype in self.int_types | self.float_types:
      tf_dtype = dtypes.as_dtype(dtype)
      minval = tf_dtype.min
      maxval = tf_dtype.max
      data = np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype)

      # The empty segment is expected to hold the dtype's max for min...
      expected_min = np.array([2, 3, maxval, 0], dtype=dtype)
      self.assertAllClose(
          expected_min, self._unsortedSegmentMin(data, segment_ids, 4))

      # ...and the dtype's min for max.
      expected_max = np.array([4, 3, minval, 6], dtype=dtype)
      self.assertAllClose(
          expected_max, self._unsortedSegmentMax(data, segment_ids, 4))
190
# When this module is executed directly (rather than imported), hand control
# to the googletest entry point imported from tensorflow.python.platform.
if __name__ == "__main__":
  googletest.main()
193