1#
2# Copyright (C) 2018 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
def test(input0, axis, indices, output0, input_data, output_data):
  """Registers one GATHER test case together with its data-type variations.

  Builds a model applying GATHER to `input0` with the given `axis` and
  `indices` and writing `output0`. It then attaches an example mapping
  `input_data` to the expected `output_data`. The example is expanded with
  the "relaxed", quant8, int32 and float16 variations.

  Args:
    input0: input operand of the GATHER operation.
    axis: axis argument forwarded unchanged to GATHER.
    indices: indices argument forwarded unchanged to GATHER.
    output0: output operand of the GATHER operation.
    input_data: values fed to `input0` in the example.
    output_data: values expected from `output0` in the example.
  """
  gather_model = Model().Operation("GATHER", input0, axis, indices).To(output0)

  def retype_both(*operand_type):
    # Every variation gives input0 and output0 the same operand type; each
    # operand gets its own fresh list, as in a literal per-operand mapping.
    return DataTypeConverter().Identify({
        input0: list(operand_type),
        output0: list(operand_type),
    })

  # NOTE(review): (0.5, 127) is presumably the scale / zero point of the
  # asymmetric quantized type -- confirm against the test generator.
  quant8 = retype_both("TENSOR_QUANT8_ASYMM", 0.5, 127)
  int32 = retype_both("TENSOR_INT32")
  float16 = retype_both("TENSOR_FLOAT16")

  example = Example({input0: input_data, output0: output_data},
                    model=gather_model)
  example.AddVariations("relaxed", quant8, int32, float16)
39
# Basic case: axis 0 of a 2x2 tensor with indices [1, 0], so the two rows come
# out swapped. In every case below, the output shape is the input shape with
# dimension `axis` replaced by the number of indices.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{2, 2}"),
    axis=0,
    indices=[1, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{2, 2}"),
    input_data=[-2.0, 0.2,
                 0.7, 0.8],
    output_data=[0.7, 0.8,
                -2.0, 0.2],
)
50
# Single index along axis 0: selects row 1 of the 2x2 input, giving a {1, 2}
# output rather than a {2} one.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{2, 2}"),
    axis=0,
    indices=[1],  # A 1-element list, not a scalar: unlike TensorFlow, 0-D arguments and outputs are not supported.
    output0=Output("output0", "TENSOR_FLOAT32", "{1, 2}"),
    input_data=[-2.0, 0.2,
                 0.7, 0.8],
    output_data=[0.7, 0.8],
)
60
# 1-D input with a single index: element 1 of [1, 2, 3] is 2.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{3}"),
    axis=0,
    indices=[1],
    output0=Output("output0", "TENSOR_FLOAT32", "{1}"),
    input_data=[1, 2, 3],
    output_data=[2],
)
69
# 1-D input with indices in non-ascending order: elements 1 then 0, and the
# unreferenced element 3 is dropped.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{3}"),
    axis=0,
    indices=[1, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{2}"),
    input_data=[1, 2, 3],
    output_data=[2, 1],
)
78
# Repeated index: gathering slice 0 twice along axis 0 duplicates the only
# {2, 2} slice, so dimension 0 grows from 1 to 2.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{1, 2, 2}"),
    axis=0,
    indices=[0, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{2, 2, 2}"),
    input_data=[-2.0, 0.2,
                 0.7, 0.8],
    output_data=[-2.0, 0.2,
                  0.7, 0.8,
                 -2.0, 0.2,
                  0.7, 0.8],
)
91
# Non-adjacent indices on a {4, 1} column: picks rows 1 and 3 (0.2 and 0.8),
# skipping rows 0 and 2.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{4, 1}"),
    axis=0,
    indices=[1, 3],
    output0=Output("output0", "TENSOR_FLOAT32", "{2, 1}"),
    input_data=[-2.0, 0.2, 0.7, 0.8],
    output_data=[0.2, 0.8],
)
100
# Gather along a non-leading axis (1): swaps the two rows inside the single
# batch. The shape is unchanged because len(indices) equals dimension 1 (2).
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{1, 2, 3}"),
    axis=1,
    indices=[1, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{1, 2, 3}"),
    input_data=[1, 2, 3,
                4, 5, 6],
    output_data=[4, 5, 6,
                 1, 2, 3],
)
111
# Negative axis: -1 gathers along the last dimension (size 3). Indices [2, 0]
# take elements 2 and 0 of each row, so {1, 2, 3} becomes {1, 2, 2}.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{1, 2, 3}"),
    axis=-1,
    indices=[2, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{1, 2, 2}"),
    input_data=[1, 2, 3,
                4, 5, 6],
    output_data=[3, 1,
                 6, 4],
)
122