1#
2# Copyright (C) 2020 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
# Model: z = (y + 10) if x else (y - 10)
# The IF operation's then-branch adds 10 to y and its else-branch subtracts
# 10 from y; the expected outputs for both branches are precomputed here.

input_data = list(range(1, 13))
output_add = list(map(lambda v: v + 10, input_data))
output_sub = list(map(lambda v: v - 10, input_data))

ValueType = ["TENSOR_FLOAT32", [3, 4]]  # 3 * 4 = 12 elements, one per input_data value
BoolType = ["TENSOR_BOOL8", [1]]  # single-element condition tensor
25
def MakeBranchModel(operation_name):
  """Return a one-operation branch model for use inside an IF.

  The returned model reads a "y" input of ValueType and applies
  `operation_name` (e.g. "ADD" or "SUB") to it together with the constant
  operand [10.0]. It writes the result to a "z" output of ValueType.
  """
  branch_input = Input("y", ValueType)
  branch_output = Output("z", ValueType)
  constant_operand = [10.0]
  # NOTE(review): presumably the fused-activation code (NONE) for ADD/SUB.
  # Confirm this against the NNAPI operation definition.
  fused_activation = 0
  branch = Model().Operation(
      operation_name, branch_input, constant_operand, fused_activation)
  return branch.To(branch_output)
30
def Test(x_data, y_data, z_data, name):
  """Build an IF model choosing between ADD and SUB branches, and register an example.

  The outer model feeds condition `x` and value `y` to an IF operation.
  The then-branch is built with MakeBranchModel("ADD") and the else-branch
  with MakeBranchModel("SUB"). The IF result is written to output `z`.

  Args:
    x_data: Boolean condition value. It is wrapped in a one-element list
      below to match BoolType's shape of [1].
    y_data: Values for the "y" input (ValueType).
    z_data: Expected values for the "z" output.
    name: Name given to the generated Example.
  """
  x = Input("x", BoolType)
  y = Input("y", ValueType)
  z = Output("z", ValueType)
  # The branch models are created before the outer model. NOTE(review): the
  # test generator may depend on Model creation order, e.g. by treating the
  # most recently created Model as the one the Example refers to. Keep this
  # order unless that has been confirmed.
  then_model = MakeBranchModel("ADD")
  else_model = MakeBranchModel("SUB")
  model = Model().Operation("IF", x, then_model, else_model, y).To(z)

  # With scale 1.0 and zero point 100, the values used at this file's call
  # sites (-9..22) presumably quantize to 91..122. That range fits both
  # unsigned and signed 8-bit storage, so one zero point serves both
  # converters. Re-check this if the test data changes.
  quant8 = DataTypeConverter("quant8", scale=1.0, zeroPoint=100)
  quant8_signed = DataTypeConverter("quant8_signed", scale=1.0, zeroPoint=100)

  example = Example({x: [x_data], y: y_data, z: z_data}, name=name)
  example.AddVariations("relaxed", "float16", "int32", quant8, quant8_signed)
  # NOTE(review): "Coverter" is spelled exactly as it is referenced here, and
  # it must match the name defined by the test generator. Do not correct the
  # spelling locally.
  example.AddVariations(AllOutputsAsInternalCoverter())
45
# Register one example per branch. A true condition selects the ADD branch
# (expected output_add); a false condition selects the SUB branch (expected
# output_sub).
for _cond, _expected, _label in ((True, output_add, "true"),
                                 (False, output_sub, "false")):
  Test(x_data=_cond, y_data=input_data, z_data=_expected, name=_label)
48