1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fields/scalar_field.h"
18 
19 #include "fields/fixed_scalar_field.h"
20 #include "fields/size_field.h"
21 #include "util.h"
22 
23 const std::string ScalarField::kFieldType = "ScalarField";
24 
ScalarField(std::string name,int size,ParseLocation loc)25 ScalarField::ScalarField(std::string name, int size, ParseLocation loc) : PacketField(name, loc), size_(size) {
26   if (size_ > 64 || size_ < 0) {
27     ERROR(this) << "Not implemented for size_ = " << size_;
28   }
29 }
30 
GetFieldType() const31 const std::string& ScalarField::GetFieldType() const {
32   return ScalarField::kFieldType;
33 }
34 
GetSize() const35 Size ScalarField::GetSize() const {
36   return size_;
37 }
38 
GetDataType() const39 std::string ScalarField::GetDataType() const {
40   return util::GetTypeForSize(size_);
41 }
42 
// For a non-negative bit position `i`, returns how many bits remain until the
// next byte boundary, or 0 when `i` is already byte-aligned.
// (For negative `i`, C++ remainder semantics make this return 8 - (i % 8).)
int GetShiftBits(int i) {
  const int remainder = i % 8;
  return remainder == 0 ? 0 : 8 - remainder;
}
51 
GenBounds(std::ostream & s,Size start_offset,Size end_offset,Size size) const52 int ScalarField::GenBounds(std::ostream& s, Size start_offset, Size end_offset, Size size) const {
53   int num_leading_bits = 0;
54 
55   if (!start_offset.empty()) {
56     // Default to start if available.
57     num_leading_bits = start_offset.bits() % 8;
58     s << "auto " << GetName() << "_it = to_bound + (" << start_offset << ") / 8;";
59   } else if (!end_offset.empty()) {
60     num_leading_bits = GetShiftBits(end_offset.bits() + size.bits());
61     Size byte_offset = Size(num_leading_bits + size.bits()) + end_offset;
62     s << "auto " << GetName() << "_it = to_bound + (to_bound.NumBytesRemaining() - (" << byte_offset << ") / 8);";
63   } else {
64     ERROR(this) << "Ambiguous offset for field.";
65   }
66   return num_leading_bits;
67 }
68 
GenExtractor(std::ostream & s,int num_leading_bits,bool) const69 void ScalarField::GenExtractor(std::ostream& s, int num_leading_bits, bool) const {
70   Size size = GetSize();
71   // Extract the correct number of bytes. The return type could be different
72   // from the extract type if an earlier field causes the beginning of the
73   // current field to start in the middle of a byte.
74   std::string extract_type = util::GetTypeForSize(size.bits() + num_leading_bits);
75   s << "auto extracted_value = " << GetName() << "_it.extract<" << extract_type << ">();";
76 
77   // Right shift the result to remove leading bits.
78   if (num_leading_bits != 0) {
79     s << "extracted_value >>= " << num_leading_bits << ";";
80   }
81   // Mask the result if necessary.
82   if (util::RoundSizeUp(size.bits()) != size.bits()) {
83     uint64_t mask = 0;
84     for (int i = 0; i < size.bits(); i++) {
85       mask <<= 1;
86       mask |= 1;
87     }
88     s << "extracted_value &= 0x" << std::hex << mask << std::dec << ";";
89   }
90   s << "*" << GetName() << "_ptr = static_cast<" << GetDataType() << ">(extracted_value);";
91 }
92 
GetGetterFunctionName() const93 std::string ScalarField::GetGetterFunctionName() const {
94   std::stringstream ss;
95   ss << "Get" << util::UnderscoreToCamelCase(GetName());
96   return ss.str();
97 }
98 
GenGetter(std::ostream & s,Size start_offset,Size end_offset) const99 void ScalarField::GenGetter(std::ostream& s, Size start_offset, Size end_offset) const {
100   s << GetDataType() << " " << GetGetterFunctionName() << "() const {";
101   s << "ASSERT(was_validated_);";
102   s << "auto to_bound = begin();";
103   int num_leading_bits = GenBounds(s, start_offset, end_offset, GetSize());
104   s << GetDataType() << " " << GetName() << "_value{};";
105   s << GetDataType() << "* " << GetName() << "_ptr = &" << GetName() << "_value;";
106   GenExtractor(s, num_leading_bits, false);
107   s << "return " << GetName() << "_value;";
108   s << "}";
109 }
110 
GetBuilderParameterType() const111 std::string ScalarField::GetBuilderParameterType() const {
112   return GetDataType();
113 }
114 
HasParameterValidator() const115 bool ScalarField::HasParameterValidator() const {
116   return util::RoundSizeUp(GetSize().bits()) != GetSize().bits();
117 }
118 
GenParameterValidator(std::ostream & s) const119 void ScalarField::GenParameterValidator(std::ostream& s) const {
120   s << "ASSERT(" << GetName() << " < (static_cast<uint64_t>(1) << " << GetSize().bits() << "));";
121 }
122 
GenInserter(std::ostream & s) const123 void ScalarField::GenInserter(std::ostream& s) const {
124   if (GetSize().bits() == 8) {
125     s << "i.insert_byte(" << GetName() << "_);";
126   } else {
127     s << "insert(" << GetName() << "_, i," << GetSize().bits() << ");";
128   }
129 }
130 
GenValidator(std::ostream &) const131 void ScalarField::GenValidator(std::ostream&) const {
132   // Do nothing
133 }
134 
GenStringRepresentation(std::ostream & s,std::string accessor) const135 void ScalarField::GenStringRepresentation(std::ostream& s, std::string accessor) const {
136   s << "+" << accessor;
137 }
138 
GetRustDataType() const139 std::string ScalarField::GetRustDataType() const {
140   return util::GetRustTypeForSize(size_);
141 }
142 
GetRustParseDataType() const143 std::string ScalarField::GetRustParseDataType() const {
144   return util::GetRustTypeForSize(size_);
145 }
146 
GetRustBitOffset(std::ostream &,Size start_offset,Size end_offset,Size size) const147 int ScalarField::GetRustBitOffset(
148     std::ostream&, Size start_offset, Size end_offset, Size size) const {
149   int num_leading_bits = 0;
150 
151   if (!start_offset.empty()) {
152     // Default to start if available.
153     num_leading_bits = start_offset.bits() % 8;
154   } else if (!end_offset.empty()) {
155     num_leading_bits = GetShiftBits(end_offset.bits() + size.bits());
156     Size byte_offset = Size(num_leading_bits + size.bits()) + end_offset;
157   } else {
158     ERROR(this) << "Ambiguous offset for field.";
159   }
160   return num_leading_bits;
161 }
162 
GenRustGetter(std::ostream & s,Size start_offset,Size end_offset) const163 void ScalarField::GenRustGetter(std::ostream& s, Size start_offset, Size end_offset) const {
164   Size size = GetSize();
165 
166   int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize());
167 
168   s << "let " << GetName() << " = ";
169   auto offset = num_leading_bits == 0 ? 0 : -1;
170   s << GetRustParseDataType() << "::from_le_bytes([";
171   int total_bytes;
172   if (size_ <= 8) {
173     total_bytes = 1;
174   } else if (size_ <= 16) {
175     total_bytes = 2;
176   } else if (size_ <= 32) {
177     total_bytes = 4;
178   } else {
179     total_bytes = 8;
180   }
181   for (int i = 0; i < total_bytes; i++) {
182     if (i > 0) {
183       s << ",";
184     }
185     if (i < size.bytes()) {
186       s << "bytes[" << start_offset.bytes() + i + offset << "]";
187     } else {
188       s << 0;
189     }
190   }
191   s << "]);";
192 
193   if (num_leading_bits != 0) {
194     s << "let " << GetName() << " = " << GetName() << " >> " << num_leading_bits << ";";
195   }
196 
197   if (util::RoundSizeUp(size.bits()) != size.bits()) {
198     uint64_t mask = 0;
199     for (int i = 0; i < size.bits(); i++) {
200       mask <<= 1;
201       mask |= 1;
202     }
203     s << "let " << GetName() << " = ";
204     s << GetName() << " & 0x" << std::hex << mask << std::dec << ";";
205   }
206 
207   // needs casting from primitive
208   if (GetRustParseDataType() != GetRustDataType()) {
209     s << "let " << GetName() << " = ";
210     s << GetRustDataType() << "::from_" << GetRustParseDataType() << "(" << GetName() << ").unwrap();";
211   }
212 }
213 
GenRustWriter(std::ostream & s,Size start_offset,Size end_offset) const214 void ScalarField::GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const {
215   Size size = GetSize();
216   int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize());
217 
218   if (GetFieldType() == SizeField::kFieldType || GetFieldType() == FixedScalarField::kFieldType) {
219     // Do nothing, the field access has already happened
220   } else if (GetRustParseDataType() != GetRustDataType()) {
221     // needs casting to primitive
222     s << "let " << GetName() << " = self." << GetName() << ".to_" << GetRustParseDataType() << "().unwrap();";
223   } else {
224     s << "let " << GetName() << " = self." << GetName() << ";";
225   }
226   if (util::RoundSizeUp(size.bits()) != size.bits()) {
227     uint64_t mask = 0;
228     for (int i = 0; i < size.bits(); i++) {
229       mask <<= 1;
230       mask |= 1;
231     }
232     s << "let " << GetName() << " = ";
233     s << GetName() << " & 0x" << std::hex << mask << std::dec << ";";
234   }
235 
236   int access_offset = 0;
237   if (num_leading_bits != 0) {
238     access_offset = -1;
239     uint64_t mask = 0;
240     for (int i = 0; i < num_leading_bits; i++) {
241       mask <<= 1;
242       mask |= 1;
243     }
244     s << "let " << GetName() << " = (" << GetName() << " << " << num_leading_bits << ") | ("
245       << "(buffer[" << start_offset.bytes() + access_offset << "] as " << GetRustParseDataType() << ") & 0x" << std::hex
246       << mask << std::dec << ");";
247   }
248 
249   s << "buffer[" << start_offset.bytes() + access_offset << ".."
250     << start_offset.bytes() + GetSize().bytes() + access_offset << "].copy_from_slice(&" << GetName()
251     << ".to_le_bytes()[0.." << size.bytes() << "]);";
252 }
253