1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for ArgMin and ArgMax Ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.compiler.tests import xla_test 24from tensorflow.python.framework import dtypes 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import math_ops 27from tensorflow.python.platform import test 28 29 30class ArgMinMaxTest(xla_test.XLATestCase): 31 32 def _assertOpOutputMatchesExpected(self, op, axis, output_type, op_input, 33 expected): 34 """Verifies that 'op' produces 'expected' when fed input 'op_input' . 35 36 Args: 37 op: argmin or argmax operator to test. 38 axis: integer axis to reduce across. 39 output_type: numpy datatype of the output to produce. 40 op_input: numpy input array to use as input to 'op'. 41 expected: numpy array representing the expected output of 'op'. 42 """ 43 with self.cached_session() as session: 44 with self.test_scope(): 45 pinp = array_ops.placeholder( 46 dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") 47 output = op(pinp, axis=axis, output_type=output_type) 48 result = session.run(output, {pinp: op_input}) 49 self.assertAllEqual(result, expected) 50 51 def testArgMinMax(self): 52 # Complex numbers do not support argmin/argmax. 53 minmax_types = self.all_types & {np.int32, np.int64} 54 for dtype in minmax_types: 55 # output_type is a numpy data type that is used to specify the desired 56 # output type of the op as well as to convert the Python number to the 57 # array scalar of the type. 58 for output_type in minmax_types: 59 self._assertOpOutputMatchesExpected( 60 math_ops.argmax, 61 axis=0, 62 output_type=output_type, 63 op_input=np.array([1, 10, 27, 3, 3, 4], dtype=dtype), 64 expected=output_type(2)) 65 self._assertOpOutputMatchesExpected( 66 math_ops.argmax, 67 axis=0, 68 output_type=output_type, 69 op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), 70 expected=np.array([0, 1, 0], dtype=output_type)) 71 self._assertOpOutputMatchesExpected( 72 math_ops.argmax, 73 axis=1, 74 output_type=output_type, 75 op_input=np.array([[4, 1], [3, 2]], dtype=dtype), 76 expected=np.array([0, 0], dtype=output_type)) 77 78 self._assertOpOutputMatchesExpected( 79 math_ops.argmin, 80 axis=0, 81 output_type=output_type, 82 op_input=np.array([3, 10, 27, 3, 2, 4], dtype=dtype), 83 expected=output_type(4)) 84 self._assertOpOutputMatchesExpected( 85 math_ops.argmin, 86 axis=0, 87 output_type=output_type, 88 op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), 89 expected=np.array([1, 0, 1], dtype=output_type)) 90 self._assertOpOutputMatchesExpected( 91 math_ops.argmin, 92 axis=1, 93 output_type=output_type, 94 op_input=np.array([[4, 1], [3, 2]], dtype=dtype), 95 expected=np.array([1, 1], dtype=output_type)) 96 97 98if __name__ == "__main__": 99 test.main() 100