#!/usr/bin/env python
#
# This script generates a BPF program with structure inspired by trace.py. The
# generated program operates on PID-indexed stacks. Generally speaking,
# bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
# the goal of "fail iff this call chain and these predicates".
#
# Top level functions (the ones at the end of the call chain) are responsible
# for creating the pid_struct and deleting it from the map in kprobe and
# kretprobe respectively.
#
# Intermediate functions (between should_fail_whatever and the top level
# functions) are responsible for updating the stack to indicate "I have been
# called and one of my predicate(s) passed" in their entry probes. In their exit
# probes, they do the opposite, popping their stack to maintain correctness.
# This implementation aims to ensure correctness in edge cases like recursive
# calls, so there's some additional information stored in pid_struct for that.
#
# At the bottom level function (should_fail_whatever), we do a simple check to
# ensure all necessary calls/predicates have passed before error injection.
#
# Note: presently there are a few hacks to get around various rewriter/verifier
# issues.
#
# Note: this tool requires:
# - CONFIG_BPF_KPROBE_OVERRIDE
#
# USAGE: inject [-h] [-I header] [-P probability] [-v] mode spec
#
# Copyright (c) 2018 Facebook, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 16-Mar-2018   Howard McLauchlan   Created this.
import argparse
import re
from bcc import BPF


class Probe:
    """Generator for one kprobe or kretprobe handler of the BPF program.

    A probe is built from a function signature string, the list of
    (predicate, position) tuples attached to that function, the total length
    of the call chain, and a flag selecting the entry (kprobe) or exit
    (kretprobe) variant.
    """

    # errno returned by the overridden function, keyed by injection mode
    errno_mapping = {
        "kmalloc": "-ENOMEM",
        "bio": "-EIO",
    }

    @classmethod
    def configure(cls, mode, probability):
        """Store the injection mode and failure probability on the class."""
        cls.mode = mode
        cls.probability = probability

    def __init__(self, func, preds, length, entry):
        """Create a probe.

        func   -- function signature string, e.g. "foo(int a)"
        preds  -- list of (predicate, position) tuples for this function
        length -- length of the call chain
        entry  -- True for the kprobe handler, False for the kretprobe one
        """
        # length of call chain
        self.length = length
        self.func = func
        # _prepare_pred() rewrites entries in place; work on a private copy so
        # probes created from the same list do not affect each other
        self.preds = list(preds)
        self.is_entry = entry

    def _bail(self, err):
        """Raise ValueError describing a problem with this probe."""
        raise ValueError("error in probe '%s': %s" %
                         (self.func, err))

    def _get_err(self):
        """Return the errno expression for the configured mode."""
        return Probe.errno_mapping[Probe.mode]

    def _get_if_top(self):
        """Return the extra C code a top level function needs, or "".

        The entry variant initialises the per-pid map entry (after an optional
        probabilistic early exit); the exit variant deletes it.
        """
        # the first tuple carries the lowest position; position 0 marks the
        # top level function
        if self.preds[0][1] != 0:
            return ""

        if not self.is_entry:
            # kill the entry
            return """
        /*
         * Top level function clean up map
         */
        m.delete(&pid);
        """

        if Probe.probability == 1:
            early_pred = "false"
        else:
            early_pred = "bpf_get_prandom_u32() > %s" % str(
                int((1 << 32) * Probe.probability))
        # init the map
        # dont do an early exit here so the singular case works automatically
        # have an early exit for probability option
        return """
        /*
         * Early exit for probability case
         */
        if (%s)
               return 0;
        /*
         * Top level function init map
         */
        struct pid_struct p_struct = {0, 0};
        m.insert(&pid, &p_struct);
        """ % early_pred

    def _get_heading(self):
        """Return the C function heading and record event/func_name.

        Sets self.event (text before the first "(") and self.func_name
        (event plus "_entry" or "_exit"), which attach() needs later.
        """
        # we need to insert identifier and ctx into self.func
        # gonna make a lot of formatting assumptions to make this work
        left = self.func.find("(")
        right = self.func.rfind(")")

        # self.event and self.func_name need to be accessible
        self.event = self.func[0:left]
        self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
        func_sig = "struct pt_regs *ctx"

        # assume theres something in there, no guarantee its well formed
        if right > left + 1 and self.is_entry:
            func_sig += ", " + self.func[left + 1:right]

        return "int %s(%s)" % (self.func_name, func_sig)

    def _get_entry_logic(self):
        """Return the C code that pushes onto the per-pid stack on entry."""
        # there is at least one tup(pred, place) for this function
        first_pred, first_place = self.preds[0]
        text = """

        if (p->conds_met >= %s)
                return 0;
        if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }""" % (self.length, first_place, first_pred, first_place)

        # for each additional pred
        for pred, place in self.preds[1:]:
            text += """
        else if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }
        """ % (place, pred, place)
        return text

    def _generate_entry(self):
        """Return the complete C source of the kprobe handler."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();
        %s

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * Generate entry logic
         */
        %s

        p->curr_call++;

        return 0;
}"""

        return prog % (self._get_if_top(), self.prep, self._get_entry_logic())

    # only need to check top of stack
    def _get_exit_logic(self):
        """Return the C code that pops the per-pid stack on exit."""
        text = """
        if (p->conds_met < 1 || p->conds_met >= %s)
                return 0;

        if (p->stack[p->conds_met - 1] == p->curr_call)
                p->conds_met--;
        """
        return text % str(self.length + 1)

    def _generate_exit(self):
        """Return the complete C source of the kretprobe handler."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        p->curr_call--;

        /*
         * Generate exit logic
         */
        %s
        %s
        return 0;
}"""

        return prog % (self._get_exit_logic(), self._get_if_top())

    # Special case for should_fail_whatever
    def _generate_bottom(self):
        """Return the C source of the handler that performs the override."""
        pred = self.preds[0][0]
        text = self._get_heading() + """
{
        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * If this is the only call in the chain and predicate passes
         */
        if (%s == 1 && %s) {
                bpf_override_return(ctx, %s);
                return 0;
        }
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * If all conds have been met and predicate passes
         */
        if (p->conds_met == %s && %s)
                bpf_override_return(ctx, %s);
        return 0;
}"""
        return text % (self.prep, self.length, pred, self._get_err(),
                       self.length - 1, pred, self._get_err())

    # presently parses and replaces STRCMP
    # STRCMP exists because string comparison is inconvenient and somewhat buggy
    # https://github.com/iovisor/bcc/issues/1617
    def _prepare_pred(self):
        """Expand every STRCMP(ptr, literal) in the predicates.

        Each occurrence is replaced in the predicate by a unique boolean
        variable, and the C statements computing that boolean (a character by
        character comparison ending with the NUL terminator) are accumulated
        in self.prep.

        Raises ValueError when a STRCMP call has no closing parenthesis or no
        comma separating its two arguments.
        """
        marker = "STRCMP("
        self.prep = ""
        for i, (pred, place) in enumerate(self.preds):
            new_pred = ""
            start = 0
            while start < len(pred):
                ind = pred.find(marker, start)
                if ind == -1:
                    break
                new_pred += pred[start:ind]

                # the arguments run from just after the marker to the first
                # ")" that follows it; searching from the marker (not from the
                # previous scan position) keeps an earlier ")" in the
                # predicate from being mistaken for the end of this call
                args_start = ind + len(marker)
                close = pred.find(")", args_start)
                if close == -1:
                    self._bail("unterminated STRCMP in predicate '%s'" % pred)
                start = close + 1

                # split on the first comma only, so the literal itself may
                # contain commas
                ptr, sep, literal = pred[args_start:close].partition(",")
                if not sep:
                    self._bail("STRCMP expects 'pointer, literal' in "
                               "predicate '%s'" % pred)
                literal = literal.strip()

                # x->y->z, some string literal
                # we make unique id with place_ind
                uuid = "%s_%s" % (place, ind)
                unique_bool = "is_true_%s" % uuid
                self.prep += """
        char *str_%s = %s;
        bool %s = true;\n""" % (uuid, ptr.strip(), unique_bool)

                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)

                for ch in literal:
                    self.prep += check % ch
                self.prep += check % r'\0'
                new_pred += unique_bool

            new_pred += pred[start:]
            self.preds[i] = (new_pred, place)

    def generate_program(self):
        """Return the C source for this probe."""
        # generate code to work around various rewriter issues
        self._prepare_pred()

        # special case for bottom
        if self.preds[-1][1] == self.length - 1:
            return self._generate_bottom()

        return self._generate_entry() if self.is_entry else self._generate_exit()

    def attach(self, bpf):
        """Attach the generated handler to its kernel function.

        Relies on self.event and self.func_name, which are set while the
        program text is generated.
        """
        if self.is_entry:
            bpf.attach_kprobe(event=self.event,
                              fn_name=self.func_name)
        else:
            bpf.attach_kretprobe(event=self.event,
                                 fn_name=self.func_name)


class Tool:
    """Command line driver: parse the spec, build the program, attach it."""

    examples = """
EXAMPLES:
# ./inject.py kmalloc -v 'SyS_mount()'
    Fails all calls to syscall mount
# ./inject.py kmalloc -v '(true) => SyS_mount()(true)'
    Explicit rewriting of above
# ./inject.py kmalloc -v 'mount_subtree() => btrfs_mount()'
    Fails btrfs mounts only
# ./inject.py kmalloc -v 'd_alloc_parallel(struct dentry *parent, const struct \\
    qstr *name)(STRCMP(name->name, 'bananas'))'
    Fails dentry allocations of files named 'bananas'
# ./inject.py kmalloc -v -P 0.01 'SyS_mount()'
    Fails calls to syscall mount with 1% probability
    """
    # add cases as necessary
    error_injection_mapping = {
        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
        "bio": "should_fail_bio(struct bio *bio)",
    }

    def __init__(self):
        """Parse the command line and initialise empty tool state."""
        parser = argparse.ArgumentParser(
            description="Fail specified kernel" +
            " functionality when call chain and predicates are met",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=Tool.examples)
        parser.add_argument(dest="mode", choices=['kmalloc', 'bio'],
                            help="indicate which base kernel function to fail")
        parser.add_argument(metavar="spec", dest="spec",
                            help="specify call chain")
        parser.add_argument("-I", "--include", action="append",
                            metavar="header",
                            help="additional header files to include in the BPF program")
        parser.add_argument("-P", "--probability", default=1,
                            metavar="probability", type=float,
                            help="probability that this call chain will fail")
        parser.add_argument("-v", "--verbose", action="store_true",
                            help="print BPF program")
        self.args = parser.parse_args()

        # the probability is turned into a threshold for a 32-bit random
        # number, which only makes sense for values in [0, 1]
        if not 0 <= self.args.probability <= 1:
            parser.error("probability must be between 0 and 1")

        self.program = ""
        self.spec = self.args.spec
        self.map = {}
        self.probes = []
        self.key = Tool.error_injection_mapping[self.args.mode]

    # create_probes and associated stuff
    def _create_probes(self):
        """Parse the spec and create the entry/exit probes it calls for."""
        self._parse_spec()
        Probe.configure(self.args.mode, self.args.probability)
        # self, func, preds, total, entry

        # create all the pair probes
        for fx, preds in self.map.items():

            # do the enter
            self.probes.append(Probe(fx, preds, self.length, True))

            # the error injection function itself gets no exit probe
            if self.key == fx:
                continue

            # do the exit
            self.probes.append(Probe(fx, preds, self.length, False))

    def _parse_frames(self):
        """Split the spec on "=>" into (function, predicate) frames.

        A frame made of a single parenthesised predicate is attached to the
        error injection function; a frame with only a function gets the
        predicate "(true)".

        Raises Exception on unbalanced parentheses, on an empty frame, on a
        frame with more than two top level groups, or on trailing characters
        after the last frame.
        """
        # sentinel
        data = self.spec + '\0'
        start, count = 0, 0

        frames = []
        cur_frame = []
        i = 0
        last_frame_added = 0

        while i < len(data):
            # improper input
            if count < 0:
                raise Exception("Check your parentheses")
            c = data[i]
            if c == '(':
                count += 1
            elif c == ')':
                count -= 1
            if not count:
                if c == '\0' or (c == '=' and data[i + 1] == '>'):
                    # This block is closing a chunk. This means cur_frame must
                    # have something in it.
                    if not cur_frame:
                        raise Exception("Cannot parse spec, missing parens")
                    # anything beyond "function(args)(predicate)" would be
                    # dropped silently below, so reject it instead
                    if len(cur_frame) > 2:
                        raise Exception("Cannot parse spec, too many "
                                        "parenthesized groups in one frame")
                    if len(cur_frame) == 2:
                        frame = tuple(cur_frame)
                    elif cur_frame[0][0] == '(':
                        frame = self.key, cur_frame[0]
                    else:
                        frame = cur_frame[0], '(true)'
                    frames.append(frame)
                    del cur_frame[:]
                    # skip over the '>' of "=>"
                    i += 1
                    start = i + 1
                elif c == ')':
                    cur_frame.append(data[start:i + 1].strip())
                    start = i + 1
                    last_frame_added = start
            i += 1

        # We only permit spaces after the last frame
        if self.spec[last_frame_added:].strip():
            raise Exception("Invalid characters found after last frame")
        # improper input
        if count:
            raise Exception("Check your parentheses")
        return frames

    def _parse_spec(self):
        """Fill self.map and self.length from the spec.

        Frames are numbered in reverse spec order; self.map maps each
        function string to its list of (predicate, position) tuples. The
        error injection function is appended with "(true)" when the spec does
        not mention it.

        Raises Exception on an invalid predicate or function identifier.
        """
        frames = self._parse_frames()
        frames.reverse()

        absolute_order = 0
        for func, pred in frames:
            if not self._validate_predicate(pred):
                raise Exception("Invalid predicate")
            if not self._validate_identifier(func):
                raise Exception("Invalid function identifier")

            self.map.setdefault(func, []).append((pred, absolute_order))
            absolute_order += 1

        if self.key not in self.map:
            self.map[self.key] = [('(true)', absolute_order)]
            absolute_order += 1

        self.length = absolute_order

    def _validate_identifier(self, func):
        """Return True if the text before the first "(" is a C identifier."""
        # We've already established paren balancing. We will only look for
        # identifier validity here.
        paren_index = func.find("(")
        potential_id = func[:paren_index]
        # "A-Z", not "A-z": the latter range also covers [ \ ] ^ _ and `
        pattern = r'[_a-zA-Z][_a-zA-Z0-9]*$'
        return re.match(pattern, potential_id) is not None

    def _validate_predicate(self, pred):
        """Return False if a predicate starting with "(" is unbalanced."""
        if len(pred) > 0 and pred[0] == "(":
            depth = 1
            for ch in pred[1:]:
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
            if depth != 0:
                # not well formed, break
                return False

        return True

    def _def_pid_struct(self):
        """Return the C definition of struct pid_struct."""
        text = """
struct pid_struct {
    u64 curr_call; /* book keeping to handle recursion */
    u64 conds_met; /* stack pointer */
    u64 stack[%s];
};
""" % self.length
        return text

    def _attach_probes(self):
        """Compile the program and attach every probe."""
        self.bpf = BPF(text=self.program)
        for p in self.probes:
            p.attach(self.bpf)

    def _generate_program(self):
        """Assemble the full BPF program text in self.program."""
        # leave out auto includes for now
        self.program += '#include <linux/mm.h>\n'
        for include in (self.args.include or []):
            self.program += "#include <%s>\n" % include

        self.program += self._def_pid_struct()
        self.program += "BPF_HASH(m, u32, struct pid_struct);\n"
        for p in self.probes:
            self.program += p.generate_program() + "\n"

        if self.args.verbose:
            print(self.program)

    def _main_loop(self):
        """Keep the process (and the attached probes) alive until Ctrl-C."""
        try:
            while True:
                self.bpf.perf_buffer_poll()
        except KeyboardInterrupt:
            # normal way to stop the tool; leave without a traceback
            pass

    def run(self):
        """Run the tool: create probes, generate, attach, then wait."""
        self._create_probes()
        self._generate_program()
        self._attach_probes()
        self._main_loop()


if __name__ == "__main__":
    Tool().run()