1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/protozero/filtering/filter_bytecode_parser.h"
18 
19 #include "perfetto/base/logging.h"
20 #include "perfetto/ext/base/hash.h"
21 #include "perfetto/protozero/packed_repeated_fields.h"
22 #include "perfetto/protozero/proto_decoder.h"
23 #include "perfetto/protozero/proto_utils.h"
24 #include "src/protozero/filtering/filter_bytecode_common.h"
25 
26 namespace protozero {
27 
Reset()28 void FilterBytecodeParser::Reset() {
29   bool suppress = suppress_logs_for_fuzzer_;
30   *this = FilterBytecodeParser();
31   suppress_logs_for_fuzzer_ = suppress;
32 }
33 
Load(const void * filter_data,size_t len)34 bool FilterBytecodeParser::Load(const void* filter_data, size_t len) {
35   Reset();
36   bool res = LoadInternal(static_cast<const uint8_t*>(filter_data), len);
37   // If load fails, don't leave the parser in a half broken state.
38   if (!res)
39     Reset();
40   return res;
41 }
42 
LoadInternal(const uint8_t * bytecode_data,size_t len)43 bool FilterBytecodeParser::LoadInternal(const uint8_t* bytecode_data,
44                                         size_t len) {
45   // First unpack the varints into a plain uint32 vector, so it's easy to
46   // iterate through them and look ahead.
47   std::vector<uint32_t> words;
48   bool packed_parse_err = false;
49   words.reserve(len);  // An overestimation, but avoids reallocations.
50   using BytecodeDecoder =
51       PackedRepeatedFieldIterator<proto_utils::ProtoWireType::kVarInt,
52                                   uint32_t>;
53   for (BytecodeDecoder it(bytecode_data, len, &packed_parse_err); it; ++it)
54     words.emplace_back(*it);
55 
56   if (packed_parse_err || words.empty())
57     return false;
58 
59   perfetto::base::Hash hasher;
60   for (size_t i = 0; i < words.size() - 1; ++i)
61     hasher.Update(words[i]);
62 
63   uint32_t expected_csum = static_cast<uint32_t>(hasher.digest());
64   if (expected_csum != words.back()) {
65     if (!suppress_logs_for_fuzzer_) {
66       PERFETTO_ELOG("Filter bytecode checksum failed. Expected: %x, actual: %x",
67                     expected_csum, words.back());
68     }
69     return false;
70   }
71 
72   words.pop_back();  // Pop the checksum.
73 
74   // Temporay storage for each message. Cleared on every END_OF_MESSAGE.
75   std::vector<uint32_t> direct_indexed_fields;
76   std::vector<uint32_t> ranges;
77   uint32_t max_msg_index = 0;
78 
79   auto add_directly_indexed_field = [&](uint32_t field_id, uint32_t msg_id) {
80     PERFETTO_DCHECK(field_id > 0 && field_id < kDirectlyIndexLimit);
81     direct_indexed_fields.resize(std::max(direct_indexed_fields.size(),
82                                           static_cast<size_t>(field_id) + 1));
83     direct_indexed_fields[field_id] = kAllowed | msg_id;
84   };
85 
86   auto add_range = [&](uint32_t id_start, uint32_t id_end, uint32_t msg_id) {
87     PERFETTO_DCHECK(id_end > id_start);
88     PERFETTO_DCHECK(id_start >= kDirectlyIndexLimit);
89     ranges.emplace_back(id_start);
90     ranges.emplace_back(id_end);
91     ranges.emplace_back(kAllowed | msg_id);
92   };
93 
94   for (size_t i = 0; i < words.size(); ++i) {
95     const uint32_t word = words[i];
96     const bool has_next_word = i < words.size() - 1;
97     const uint32_t opcode = word & 0x7u;
98     const uint32_t field_id = word >> 3;
99 
100     if (field_id == 0 && opcode != kFilterOpcode_EndOfMessage) {
101       PERFETTO_DLOG("bytecode error @ word %zu, invalid field id (0)", i);
102       return false;
103     }
104 
105     if (opcode == kFilterOpcode_SimpleField ||
106         opcode == kFilterOpcode_NestedField) {
107       // Field words are organized as follow:
108       // MSB: 1 if allowed, 0 if not allowed.
109       // Remaining bits:
110       //   Message index in the case of nested (non-simple) messages.
111       //   0x7f..f in the case of simple messages.
112       uint32_t msg_id;
113       if (opcode == kFilterOpcode_SimpleField) {
114         msg_id = kSimpleField;
115       } else {  // FILTER_OPCODE_NESTED_FIELD
116         // The next word in the bytecode contains the message index.
117         if (!has_next_word) {
118           PERFETTO_DLOG("bytecode error @ word %zu: unterminated nested field",
119                         i);
120           return false;
121         }
122         msg_id = words[++i];
123         max_msg_index = std::max(max_msg_index, msg_id);
124       }
125 
126       if (field_id < kDirectlyIndexLimit) {
127         add_directly_indexed_field(field_id, msg_id);
128       } else {
129         // In the case of a large field id (rare) we waste an extra word and
130         // represent it as a range. Doesn't make sense to introduce extra
131         // complexity to deal with rare cases like this.
132         add_range(field_id, field_id + 1, msg_id);
133       }
134     } else if (opcode == kFilterOpcode_SimpleFieldRange) {
135       if (!has_next_word) {
136         PERFETTO_DLOG("bytecode error @ word %zu: unterminated range", i);
137         return false;
138       }
139       const uint32_t range_len = words[++i];
140       const uint32_t range_end = field_id + range_len;  // STL-style, excl.
141       uint32_t id = field_id;
142 
143       // Here's the subtle complexity: at the bytecode level, we don't know
144       // anything about the kDirectlyIndexLimit. It is legit to define a range
145       // that spans across the direct-indexing threshold (e.g. 126-132). In that
146       // case we want to add all the elements < the indexing to the O(1) bucket
147       // and add only the remaining range as a non-indexed range.
148       for (; id < range_end && id < kDirectlyIndexLimit; ++id)
149         add_directly_indexed_field(id, kAllowed | kSimpleField);
150       PERFETTO_DCHECK(id >= kDirectlyIndexLimit || id == range_end);
151       if (id < range_end)
152         add_range(id, range_end, kSimpleField);
153     } else if (opcode == kFilterOpcode_EndOfMessage) {
154       // For each message append:
155       // 1. The "header" word telling how many directly indexed fields there
156       //    are.
157       // 2. The words for the directly indexed fields (id < 128).
158       // 3. The rest of the fields, encoded as ranges.
159       // Also update the |message_offset_| index to remember the word offset for
160       // the current message.
161       message_offset_.emplace_back(static_cast<uint32_t>(words_.size()));
162       words_.emplace_back(static_cast<uint32_t>(direct_indexed_fields.size()));
163       words_.insert(words_.end(), direct_indexed_fields.begin(),
164                     direct_indexed_fields.end());
165       words_.insert(words_.end(), ranges.begin(), ranges.end());
166       direct_indexed_fields.clear();
167       ranges.clear();
168     } else {
169       PERFETTO_DLOG("bytecode error @ word %zu: invalid opcode (%x)", i, word);
170       return false;
171     }
172   }  // (for word in bytecode).
173 
174   if (max_msg_index > 0 && max_msg_index >= message_offset_.size()) {
175     PERFETTO_DLOG(
176         "bytecode error: a message index (%u) is out of range "
177         "(num_messages=%zu)",
178         max_msg_index, message_offset_.size());
179     return false;
180   }
181 
182   // Add a final entry to |message_offset_| so we can tell where the last
183   // message ends without an extra branch in the Query() hotpath.
184   message_offset_.emplace_back(static_cast<uint32_t>(words_.size()));
185 
186   return true;
187 }
188 
Query(uint32_t msg_index,uint32_t field_id)189 FilterBytecodeParser::QueryResult FilterBytecodeParser::Query(
190     uint32_t msg_index,
191     uint32_t field_id) {
192   FilterBytecodeParser::QueryResult res{false, 0u};
193   if (static_cast<uint64_t>(msg_index) + 1 >=
194       static_cast<uint64_t>(message_offset_.size())) {
195     return res;
196   }
197   const uint32_t start_offset = message_offset_[msg_index];
198   // These are DCHECKs and not just CHECKS because the |words_| is populated
199   // by the LoadInternal call above. These cannot be violated with a malformed
200   // bytecode.
201   PERFETTO_DCHECK(start_offset < words_.size());
202   const uint32_t* word = &words_[start_offset];
203   const uint32_t end_off = message_offset_[msg_index + 1];
204   const uint32_t* const end = words_.data() + end_off;
205   PERFETTO_DCHECK(end > word && end <= words_.data() + words_.size());
206   const uint32_t num_directly_indexed = *(word++);
207   PERFETTO_DCHECK(num_directly_indexed <= kDirectlyIndexLimit);
208   PERFETTO_DCHECK(word + num_directly_indexed <= end);
209   uint32_t field_state = 0;
210   if (PERFETTO_LIKELY(field_id < num_directly_indexed)) {
211     PERFETTO_DCHECK(&word[field_id] < end);
212     field_state = word[field_id];
213   } else {
214     for (word = word + num_directly_indexed; word + 2 < end;) {
215       const uint32_t range_start = *(word++);
216       const uint32_t range_end = *(word++);
217       const uint32_t range_state = *(word++);
218       if (field_id >= range_start && field_id < range_end) {
219         field_state = range_state;
220         break;
221       }
222     }  // for (word in ranges)
223   }    // if (field_id >= num_directly_indexed)
224 
225   res.allowed = (field_state & kAllowed) != 0;
226   res.nested_msg_index = field_state & ~kAllowed;
227   PERFETTO_DCHECK(res.simple_field() ||
228                   res.nested_msg_index < message_offset_.size() - 1);
229   return res;
230 }
231 
232 }  // namespace protozero
233