#
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Test for QUANTIZED_LSTM op.
import copy

model = Model()

batch_size = 2
input_size = 5
num_units = 4
output_size = 3

InputType = ("TENSOR_QUANT8_ASYMM_SIGNED", [batch_size, input_size], 0.0078125, 0)
input = Input("input", InputType)
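# Note: the input scale of 0.0078125 above is exactly 2^-7, so the signed 8-bit
# range [-128, 127] maps to real values in roughly [-1.0, 0.992] with a zero
# point of 0.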

InputWeightsType = ("TENSOR_QUANT8_SYMM", [num_units, input_size], 0.00784314, 0)
input_to_input_weights = Input("input_to_input_weights", InputWeightsType)
input_to_forget_weights = Input("input_to_forget_weights", InputWeightsType)
input_to_cell_weights = Input("input_to_cell_weights", InputWeightsType)
input_to_output_weights = Input("input_to_output_weights", InputWeightsType)

RecurrentWeightsType = ("TENSOR_QUANT8_SYMM", [num_units, output_size], 0.00784314, 0)
recurrent_to_input_weights = Input("recurrent_to_input_weights", RecurrentWeightsType)
recurrent_to_forget_weights = Input("recurrent_to_forget_weights", RecurrentWeightsType)
recurrent_to_cell_weights = Input("recurrent_to_cell_weights", RecurrentWeightsType)
recurrent_to_output_weights = Input("recurrent_to_output_weights", RecurrentWeightsType)

CellWeightsType = ("TENSOR_QUANT16_SYMM", [num_units], 1.0, 0)
cell_to_input_weights = Input("cell_to_input_weights", CellWeightsType)
cell_to_forget_weights = Input("cell_to_forget_weights", CellWeightsType)
cell_to_output_weights = Input("cell_to_output_weights", CellWeightsType)

# The bias scale value here is not used.
BiasType = ("TENSOR_INT32", [num_units], 0.0, 0)
input_gate_bias = Input("input_gate_bias", BiasType)
forget_gate_bias = Input("forget_gate_bias", BiasType)
cell_gate_bias = Input("cell_gate_bias", BiasType)
output_gate_bias = Input("output_gate_bias", BiasType)

projection_weights = Input("projection_weights",
                           ("TENSOR_QUANT8_SYMM", [output_size, num_units], 0.00392157, 0))
projection_bias = Input("projection_bias", ("TENSOR_INT32", [output_size]))

OutputStateType = ("TENSOR_QUANT8_ASYMM_SIGNED", [batch_size, output_size], 3.05176e-05, 0)
CellStateType = ("TENSOR_QUANT16_SYMM", [batch_size, num_units], 3.05176e-05, 0)
output_state_in = Input("output_state_in", OutputStateType)
cell_state_in = Input("cell_state_in", CellStateType)

LayerNormType = ("TENSOR_QUANT16_SYMM", [num_units], 3.05182e-05, 0)
input_layer_norm_weights = Input("input_layer_norm_weights", LayerNormType)
forget_layer_norm_weights = Input("forget_layer_norm_weights", LayerNormType)
cell_layer_norm_weights = Input("cell_layer_norm_weights", LayerNormType)
output_layer_norm_weights = Input("output_layer_norm_weights", LayerNormType)

cell_clip = Float32Scalar("cell_clip", 0.)
projection_clip = Float32Scalar("projection_clip", 0.)

input_intermediate_scale = Float32Scalar("input_intermediate_scale", 0.007059)
forget_intermediate_scale = Float32Scalar("forget_intermediate_scale", 0.007812)
cell_intermediate_scale = Float32Scalar("cell_intermediate_scale", 0.007059)
output_intermediate_scale = Float32Scalar("output_intermediate_scale", 0.007812)
hidden_state_zero_point = Int32Scalar("hidden_state_zero_point", 0)
hidden_state_scale = Float32Scalar("hidden_state_scale", 0.007)
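# The four *_intermediate_scale scalars above give the quantization scale of
# each gate's intermediate result (the input to layer normalization), and
# hidden_state_scale / hidden_state_zero_point quantize the hidden state that
# is fed into the projection layer. This commentary paraphrases the
# QUANTIZED_LSTM operand documentation as a reading aid.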

output_state_out = Output("output_state_out", OutputStateType)
cell_state_out = Output("cell_state_out", CellStateType)
output = Output("output", OutputStateType)

model = model.Operation(
    "QUANTIZED_LSTM", input, input_to_input_weights, input_to_forget_weights,
    input_to_cell_weights, input_to_output_weights, recurrent_to_input_weights,
    recurrent_to_forget_weights, recurrent_to_cell_weights,
    recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights,
    cell_to_output_weights, input_gate_bias, forget_gate_bias, cell_gate_bias,
    output_gate_bias, projection_weights, projection_bias, output_state_in,
    cell_state_in, input_layer_norm_weights, forget_layer_norm_weights,
    cell_layer_norm_weights, output_layer_norm_weights, cell_clip, projection_clip,
    input_intermediate_scale, forget_intermediate_scale, cell_intermediate_scale,
    output_intermediate_scale, hidden_state_zero_point, hidden_state_scale).To([output_state_out,
    cell_state_out, output])
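# The argument list above follows the QUANTIZED_LSTM signature: 32 inputs
# (the tensor operands, the two clip scalars, the four intermediate scales, and
# the hidden-state quantization parameters) and three outputs
# (output_state_out, cell_state_out, output).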

# Example 1. Layer Norm, Projection.
input0 = {
    input_to_input_weights: [
        64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13
    ],
    input_to_forget_weights: [
        -77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64
    ],
    input_to_cell_weights: [
        -51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77
    ],
    input_to_output_weights: [
        -102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51
    ],
    input_gate_bias: [644245, 3221226, 4724464, 8160438],
    forget_gate_bias: [2147484, -6442451, -4294968, 2147484],
    cell_gate_bias: [-1073742, 15461883, 5368709, 1717987],
    output_gate_bias: [1073742, -214748, 4294968, 2147484],
    recurrent_to_input_weights: [
        -25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77
    ],
    recurrent_to_forget_weights: [
        -64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25
    ],
    recurrent_to_cell_weights: [
        -38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25
    ],
    recurrent_to_output_weights: [
        38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25
    ],
    projection_weights: [
        -25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51
    ],
    projection_bias: [ 0 for _ in range(output_size) ],
    input_layer_norm_weights: [3277, 6553, 9830, 16384],
    forget_layer_norm_weights: [6553, 6553, 13107, 9830],
    cell_layer_norm_weights: [22937, 6553, 9830, 26214],
    output_layer_norm_weights: [19660, 6553, 6553, 16384],
    output_state_in: [ 0 for _ in range(batch_size * output_size) ],
    cell_state_in: [ 0 for _ in range(batch_size * num_units) ],
    cell_to_input_weights: [],
    cell_to_forget_weights: [],
    cell_to_output_weights: [],
}
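# Empty value lists mark optional operands that are left unset in this example:
# the cell_to_* (peephole) weights are omitted, so the gates are computed
# without peephole connections.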

test_input = [90, 102, 13, 26, 38, 102, 13, 26, 51, 64]

golden_output = [
    127, 127, -108, -67, 127, 127
]
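# For reference, dequantizing with the output scale 3.05176e-05 (about 2^-15)
# gives 127 * 3.05176e-05 ~ 0.00388 and -108 * 3.05176e-05 ~ -0.0033, so the
# saturated int8 entries correspond to small real-valued magnitudes.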

output0 = {
    output_state_out: golden_output,
    cell_state_out: [-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939],
    output: golden_output,
}

input0[input] = test_input

Example((input0, output0))
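# A minimal local consistency check, added as an illustrative sketch (plain
# Python over the dictionaries above; it does not contribute to the generated
# test vectors): flattened value lists must match the declared operand shapes.
assert len(input0[input_to_forget_weights]) == num_units * input_size
assert len(input0[recurrent_to_forget_weights]) == num_units * output_size
assert len(input0[projection_weights]) == output_size * num_units
assert len(input0[output_state_in]) == batch_size * output_size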

# Example 2. CIFG, Layer Norm, Projection.
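# In the CIFG (coupled input and forget gate) variant exercised here, the
# input-gate operands are left empty below; the input gate is derived from the
# forget gate rather than computed from its own weights and bias.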
input0 = {
    input_to_input_weights: [],
    input_to_forget_weights: [
        -77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64
    ],
    input_to_cell_weights: [
        -51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77
    ],
    input_to_output_weights: [
        -102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51
    ],
    input_gate_bias: [],
    forget_gate_bias: [2147484, -6442451, -4294968, 2147484],
    cell_gate_bias: [-1073742, 15461883, 5368709, 1717987],
    output_gate_bias: [1073742, -214748, 4294968, 2147484],
    recurrent_to_input_weights: [],
    recurrent_to_forget_weights: [
        -64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25
    ],
    recurrent_to_cell_weights: [
        -38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25
    ],
    recurrent_to_output_weights: [
        38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25
    ],
    projection_weights: [
        -25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51
    ],
    projection_bias: [ 0 for _ in range(output_size) ],
    input_layer_norm_weights: [],
    forget_layer_norm_weights: [6553, 6553, 13107, 9830],
    cell_layer_norm_weights: [22937, 6553, 9830, 26214],
    output_layer_norm_weights: [19660, 6553, 6553, 16384],
    output_state_in: [ 0 for _ in range(batch_size * output_size) ],
    cell_state_in: [ 0 for _ in range(batch_size * num_units) ],
    cell_to_input_weights: [],
    cell_to_forget_weights: [],
    cell_to_output_weights: [],
}

test_input = [90, 102, 13, 26, 38, 102, 13, 26, 51, 64]

golden_output = [
    127, 127, 127, -128, 127, 127
]

output0 = {
    output_state_out: golden_output,
    cell_state_out: [-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149],
    output: golden_output,
}

input0[input] = test_input

Example((input0, output0))