#
# Copyright (C) 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Test specification for a single "RNN" operation.
# NOTE(review): Model, Input, Output, IgnoredOutput, Int32Scalar and Example
# are not imported here; presumably they are injected by the test generator
# that executes this spec -- confirm.

batches = 2
units = 16
input_size = 8

model = Model()

input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (batches, input_size))
weights = Input("weights", "TENSOR_FLOAT32", "{%d, %d}" % (units, input_size))
recurrent_weights = Input("recurrent_weights", "TENSOR_FLOAT32", "{%d, %d}" % (units, units))
bias = Input("bias", "TENSOR_FLOAT32", "{%d}" % (units))
hidden_state_in = Input("hidden_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))

activation_param = Int32Scalar("activation_param", 1)  # Relu

hidden_state_out = IgnoredOutput("hidden_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))
output = Output("output", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))

model = model.Operation("RNN", input, weights, recurrent_weights, bias, hidden_state_in,
                        activation_param).To([hidden_state_out, output])

# Operand values shared by every generated example (the per-step input and
# hidden state are added inside the loop below).
input0 = {
  # Flattened {units, input_size} weight matrix (units * input_size values).
  weights: [
    0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
    0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
    0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
    -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
    -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
    -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
    -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
    0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
    0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
    0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
    -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
    0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
    -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
    -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
    0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
    0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
    0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
    -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
    0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
    0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
    -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
    0.277308, 0.415818
  ],
  # 0.1 * identity, flattened {units, units}: 0.1 on the diagonal, 0 elsewhere.
  # Generated from `units` so it stays consistent with the declared shape.
  recurrent_weights: [
    0.1 if row == col else 0 for row in range(units) for col in range(units)
  ],
  bias: [
    0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
    -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
    0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
    -0.37609905
  ],
}


# Consecutive input vectors of length input_size, one per time step.
test_inputs = [
  0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133,
  0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471,
  -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222,
  0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933,
  0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103,
  0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
  -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007,
  -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154,
  0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584,
  0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144,
  0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351,
  -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
  0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567,
  -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881,
  -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032,
  -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374,
  0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071,
  -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
  -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682,
  0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493,
  -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265,
  0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539,
  0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446,
  0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
  -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563,
  0.93455386, -0.6324693, -0.083922029
]

# Expected output vectors of length `units`, one group per time step.
golden_outputs = [
  0.496726, 0, 0.965996, 0, 0.0584254, 0,
  0, 0.12315, 0, 0, 0.612266, 0.456601,
  0, 0.52286, 1.16099, 0.0291232,

  0, 0, 0.524901, 0, 0, 0,
  0, 1.02116, 0, 1.35762, 0, 0.356909,
  0.436415, 0.0355727, 0, 0,

  0, 0, 0, 0.262335, 0, 0,
  0, 1.33992, 0, 2.9739, 0, 0,
  1.31914, 2.66147, 0, 0,

  0.942568, 0, 0, 0, 0.025507, 0,
  0, 0, 0.321429, 0.569141, 1.25274, 1.57719,
  0.8158, 1.21805, 0.586239, 0.25427,

  1.04436, 0, 0.630725, 0, 0.133801, 0.210693,
  0.363026, 0, 0.533426, 0, 1.25926, 0.722707,
  0, 1.22031, 1.30117, 0.495867,

  0.222187, 0, 0.72725, 0, 0.767003, 0,
  0, 0.147835, 0, 0, 0, 0.608758,
  0.469394, 0.00720298, 0.927537, 0,

  0.856974, 0.424257, 0, 0, 0.937329, 0,
  0, 0, 0.476425, 0, 0.566017, 0.418462,
  0.141911, 0.996214, 1.13063, 0,

  0.967899, 0, 0, 0, 0.0831304, 0,
  0, 1.00378, 0, 0, 0, 1.44818,
  1.01768, 0.943891, 0.502745, 0,

  0.940135, 0, 0, 0, 0, 0,
  0, 2.13243, 0, 0.71208, 0.123918, 1.53907,
  1.30225, 1.59644, 0.70222, 0,

  0.804329, 0, 0.430576, 0, 0.505872, 0.509603,
  0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311,
  0.0454298, 0.300267, 0.562784, 0.395095,

  0.228154, 0, 0.675323, 0, 1.70536, 0.766217,
  0, 0, 0, 0.735363, 0.0759267, 1.91017,
  0.941888, 0, 0, 0,

  0, 0, 1.5909, 0, 0, 0,
  0, 0.5755, 0, 0.184687, 0, 1.56296,
  0.625285, 0, 0, 0,

  0, 0, 0.0857888, 0, 0, 0,
  0, 0.488383, 0.252786, 0, 0, 0,
  1.02817, 1.85665, 0, 0,

  0.00981836, 0, 1.06371, 0, 0, 0,
  0, 0, 0, 0.290445, 0.316406, 0,
  0.304161, 1.25079, 0.0707152, 0,

  0.986264, 0.309201, 0, 0, 0, 0,
  0, 1.64896, 0.346248, 0, 0.918175, 0.78884,
  0.524981, 1.92076, 2.07013, 0.333244,

  0.415153, 0.210318, 0, 0, 0, 0,
  0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
  0.628881, 3.58099, 1.49974, 0
]

# Number of time steps, computed as len(test_inputs) / (input_size * batches).
# NOTE(review): test_inputs holds 16 vectors of input_size and golden_outputs
# holds 16 vectors of units, so also dividing by `batches` yields only 8 of
# them. Confirm this is intended before enabling the loop below.
input_sequence_size = len(test_inputs) // (input_size * batches)

# TODO: enable the other data points after fixing reference issues.
# Each example starts from a zero hidden state, so only step 0 is emitted.
# for i in range(input_sequence_size):
for i in range(1):
  input_begin = i * input_size
  input_end = input_begin + input_size
  golden_begin = i * units
  golden_end = golden_begin + units

  # Build a fresh input map per example instead of mutating input0 in place.
  # If Example() keeps a reference rather than a copy, reusing one dict would
  # make every example alias the last step's data once more steps are enabled.
  step_inputs = dict(input0)
  # Every batch entry receives the same input vector for this step...
  step_inputs[input] = test_inputs[input_begin:input_end] * batches
  step_inputs[hidden_state_in] = [0] * (batches * units)

  output0 = {
    hidden_state_out: [0] * (batches * units),
    # ...so every batch entry expects the same golden output vector.
    output: golden_outputs[golden_begin:golden_end] * batches,
  }
  Example((step_inputs, output0))