1from optparse import OptionParser
2from optparse import Option, OptionValueError
3import os
4import mini_parser
5import policy
6from policy import MatchPathPrefix
7import re
8import sys
9
DEBUG = False

# Use file_contexts and policy to verify Treble requirements
# are not violated.

# Domains that GetCoreDomains() skips when classifying entrypoints as
# /system or /vendor.
coredomainAllowlist = set([
    # TODO: how do we make sure vendor_init doesn't have bad coupling with
    # /vendor? It is the only system process which is not coredomain.
    "vendor_init",
    # TODO(b/152813275): need to avoid allowlist for rootdir
    "modprobe",
    "slideshow",
    "healthd",
])
25
class scontext:
    """Facts collected about one SELinux domain.

    Instances are filled in by the Get* setup helpers and read by the
    Test* checks.
    """

    def __init__(self):
        # Classification flags, all initially unset.
        self.coredomain = False
        self.appdomain = False
        self.fromSystem = False
        self.fromVendor = False
        # Attribute names, entrypoint types and entrypoint file paths.
        self.attributes = set()
        self.entrypoints = []
        self.entrypointpaths = []
        # Newline-terminated error messages accumulated for this domain.
        self.error = ""
36
def PrintScontexts():
    """Dump every collected domain and its scontext fields to stdout.

    Debug aid; the entry point calls it when DEBUG is True. Domains are
    printed in sorted name order, one field per tab-indented line.
    """
    for d in sorted(alldomains.keys()):
        sctx = alldomains[d]
        # Single-argument print(...) produces identical output under
        # Python 2 and Python 3; the old print statement is a SyntaxError
        # on Python 3.
        print(d)
        print("\tcoredomain=" + str(sctx.coredomain))
        print("\tappdomain=" + str(sctx.appdomain))
        print("\tfromSystem=" + str(sctx.fromSystem))
        print("\tfromVendor=" + str(sctx.fromVendor))
        print("\tattributes=" + str(sctx.attributes))
        print("\tentrypoints=" + str(sctx.entrypoints))
        print("\tentrypointpaths=")
        if sctx.entrypointpaths is not None:
            for path in sctx.entrypointpaths:
                print("\t\t" + str(path))
51
# Per-domain scontext objects keyed by domain name (filled in by setup()).
alldomains = dict()
# Names of domains carrying the "coredomain" / "appdomain" attributes.
coredomains = set()
appdomains = set()
# NOTE(review): vendordomains is never populated in this file; confirm it is
# still needed.
vendordomains = set()
# Current monolithic policy; assigned by the entry point before tests run.
pol = None

# State used only by the TrebleCompatMapping tests (see compatSetup()).
alltypes = set()
oldalltypes = set()
compatMapping = None
pubtypes = set()

# Set by --fake-treble to distinguish PRODUCT_FULL_TREBLE from
# PRODUCT_FULL_TREBLE_OVERRIDE; when True the violator-attribute checks
# report nothing.
FakeTreble = False
66
def GetAllDomains(pol):
    """Create a fresh scontext in alldomains for every domain type.

    Args:
        pol: policy object whose QueryTypeAttribute("domain", True) result
            supplies the domain names.
    """
    global alldomains
    domainNames = pol.QueryTypeAttribute("domain", True)
    alldomains.update((name, scontext()) for name in domainNames)
71
def GetAppDomains():
    """Mark domains that carry the "appdomain" attribute and collect their names
    into appdomains."""
    global appdomains
    global alldomains
    # Trusting the "appdomain" attribute is safe: core SELinux policy has
    # neverallow rules ensuring only zygote- and runas-spawned processes can
    # transition into domains with that attribute.
    for name, ctx in alldomains.items():
        if "appdomain" not in ctx.attributes:
            continue
        ctx.appdomain = True
        appdomains.add(name)
83
def GetCoreDomains():
    """Record coredomain membership and classify each domain's entrypoints.

    Every domain with the "coredomain" attribute gets scontext.coredomain set
    and is added to coredomains. Next, a domain is skipped if it is
    allowlisted, is an app domain, or has no entrypoint paths. For every
    other domain, each entrypoint path is classified as system and/or vendor
    by prefix. The results are OR-ed into fromSystem / fromVendor. A path
    that is neither is recorded in the domain's error string.
    """
    global alldomains
    global coredomains
    vendorPrefixes = ("/vendor", "/odm")
    systemPrefixes = ("/init", "/system_ext", "/product")

    for name, ctx in alldomains.items():
        # Whether "coredomain" was applied correctly is verified later by
        # TestCoredomainViolations.
        if "coredomain" in ctx.attributes:
            ctx.coredomain = True
            coredomains.add(name)

        # Domains whose entrypoint origin is not judged here:
        # - allowlisted domains;
        # - TODO(b/153112003): app domains, which have no entrypoints because
        #   zygote always transitions into them dynamically (they can still be
        #   mislabeled as coredomain);
        # - TODO(b/153112747): domains reached only by dynamic transition, or
        #   whose entrypoint has no context in the AOSP files.
        if name in coredomainAllowlist or name in appdomains:
            continue
        if not ctx.entrypointpaths:
            continue

        for path in ctx.entrypointpaths:
            isVendor = any(MatchPathPrefix(path, p) for p in vendorPrefixes)
            isSystem = any(MatchPathPrefix(path, p) for p in systemPrefixes)

            # Legacy /system/vendor counts as vendor; the rest of /system
            # counts as system.
            if MatchPathPrefix(path, "/system/vendor"):
                isVendor = True
            elif MatchPathPrefix(path, "/system"):
                isSystem = True

            if not (isVendor or isSystem):
                ctx.error += "Unrecognized entrypoint for " + name + " at " + path + "\n"

            ctx.fromSystem = ctx.fromSystem or isSystem
            ctx.fromVendor = ctx.fromVendor or isVendor
124
125###
126# Add the entrypoint type and path(s) to each domain.
127#
128def GetDomainEntrypoints(pol):
129    global alldomains
130    for x in pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])):
131        if not x.sctx in alldomains:
132            continue
133        alldomains[x.sctx].entrypoints.append(str(x.tctx))
134        # postinstall_file represents a special case specific to A/B OTAs.
135        # Update_engine mounts a partition and relabels it postinstall_file.
136        # There is no file_contexts entry associated with postinstall_file
137        # so skip the lookup.
138        if x.tctx == "postinstall_file":
139            continue
140        entrypointpath = pol.QueryFc(x.tctx)
141        if not entrypointpath:
142            continue
143        alldomains[x.sctx].entrypointpaths.extend(entrypointpath)
144###
145# Get attributes associated with each domain
146#
147def GetAttributes(pol):
148    global alldomains
149    for domain in alldomains:
150        for result in pol.QueryTypeAttribute(domain, False):
151            alldomains[domain].attributes.add(result)
152
def GetAllTypes(pol, oldpol):
    """Cache the type sets of the current and previous policies.

    Stores pol.GetAllTypes(False) in alltypes, then
    oldpol.GetAllTypes(False) in oldalltypes.
    """
    global alltypes
    global oldalltypes
    current = pol.GetAllTypes(False)
    alltypes = current
    previous = oldpol.GetAllTypes(False)
    oldalltypes = previous
158
159def setup(pol):
160    GetAllDomains(pol)
161    GetAttributes(pol)
162    GetDomainEntrypoints(pol)
163    GetAppDomains()
164    GetCoreDomains()
165
# setup for the policy compatibility tests
def compatSetup(pol, oldpol, mapping, types):
    """Prepare module state for the TrebleCompatMapping tests.

    Args:
        pol: current platform-only policy (passed to GetAllTypes).
        oldpol: previous monolithic policy (passed to GetAllTypes).
        mapping: parsed compatibility mapping; stored in compatMapping.
        types: public platform types; stored in pubtypes.
    """
    global compatMapping
    global pubtypes

    GetAllTypes(pol, oldpol)
    compatMapping, pubtypes = mapping, types
174
def DomainsWithAttribute(attr):
    """Return a list of domain names whose attributes include attr, in
    alldomains iteration order."""
    global alldomains
    return [name for name in alldomains if attr in alldomains[name].attributes]
182
183#############################################################
184# Tests
185#############################################################
186def TestCoredomainViolations():
187    global alldomains
188    # verify that all domains launched from /system have the coredomain
189    # attribute
190    ret = ""
191
192    for d in alldomains:
193        domain = alldomains[d]
194        if domain.fromSystem and domain.fromVendor:
195            ret += "The following domain is system and vendor: " + d + "\n"
196
197    for domain in alldomains.values():
198        ret += domain.error
199
200    violators = []
201    for d in alldomains:
202        domain = alldomains[d]
203        if domain.fromSystem and "coredomain" not in domain.attributes:
204                violators.append(d);
205    if len(violators) > 0:
206        ret += "The following domain(s) must be associated with the "
207        ret += "\"coredomain\" attribute because they are executed off of "
208        ret += "/system:\n"
209        ret += " ".join(str(x) for x in sorted(violators)) + "\n"
210
211    # verify that all domains launched form /vendor do not have the coredomain
212    # attribute
213    violators = []
214    for d in alldomains:
215        domain = alldomains[d]
216        if domain.fromVendor and "coredomain" in domain.attributes:
217            violators.append(d)
218    if len(violators) > 0:
219        ret += "The following domains must not be associated with the "
220        ret += "\"coredomain\" attribute because they are executed off of "
221        ret += "/vendor or /system/vendor:\n"
222        ret += " ".join(str(x) for x in sorted(violators)) + "\n"
223
224    return ret
225
226###
227# Make sure that any new public type introduced in the new policy that was not
228# present in the old policy has been recorded in the mapping file.
229def TestNoUnmappedNewTypes():
230    global alltypes
231    global oldalltypes
232    global compatMapping
233    global pubtypes
234    newt = alltypes - oldalltypes
235    ret = ""
236    violators = []
237
238    for n in newt:
239        if n in pubtypes and compatMapping.rTypeattributesets.get(n) is None:
240            violators.append(n)
241
242    if len(violators) > 0:
243        ret += "SELinux: The following public types were found added to the "
244        ret += "policy without an entry into the compatibility mapping file(s) "
245        ret += "found in private/compat/V.v/V.v[.ignore].cil, where V.v is the "
246        ret += "latest API level.\n"
247        ret += " ".join(str(x) for x in sorted(violators)) + "\n\n"
248        ret += "See examples of how to fix this:\n"
249        ret += "https://android-review.googlesource.com/c/platform/system/sepolicy/+/781036\n"
250        ret += "https://android-review.googlesource.com/c/platform/system/sepolicy/+/852612\n"
251    return ret
252
253###
254# Make sure that any public type removed in the current policy has its
255# declaration added to the mapping file for use in non-platform policy
256def TestNoUnmappedRmTypes():
257    global alltypes
258    global oldalltypes
259    global compatMapping
260    rmt = oldalltypes - alltypes
261    ret = ""
262    violators = []
263
264    for o in rmt:
265        if o in compatMapping.pubtypes and not o in compatMapping.types:
266            violators.append(o)
267
268    if len(violators) > 0:
269        ret += "SELinux: The following formerly public types were removed from "
270        ret += "policy without a declaration in the compatibility mapping "
271        ret += "found in private/compat/V.v/V.v[.ignore].cil, where V.v is the "
272        ret += "latest API level.\n"
273        ret += " ".join(str(x) for x in sorted(violators)) + "\n\n"
274        ret += "See examples of how to fix this:\n"
275        ret += "https://android-review.googlesource.com/c/platform/system/sepolicy/+/822743\n"
276    return ret
277
def TestTrebleCompatMapping():
    """Run the added-type and removed-type mapping checks, in that order, and
    concatenate their reports."""
    return "".join([TestNoUnmappedNewTypes(), TestNoUnmappedRmTypes()])
282
def TestViolatorAttribute(attribute):
    """Report domains that carry the Treble-banned `attribute`.

    Returns "" when FakeTreble is set or no domain has the attribute.
    Otherwise returns one line listing the offending domains in sorted order.
    """
    global FakeTreble
    if FakeTreble:
        return ""

    offenders = DomainsWithAttribute(attribute)
    if not offenders:
        return ""
    return ("SELinux: The following domains violate the Treble ban "
            "against use of the " + attribute + " attribute: "
            + " ".join(str(x) for x in sorted(offenders)) + "\n")
295
def TestViolatorAttributes():
    """Run TestViolatorAttribute for each Treble violator attribute and
    concatenate the reports."""
    bannedAttributes = ("socket_between_core_and_vendor_violators",
                        "vendor_executes_system_violators")
    return "".join(TestViolatorAttribute(a) for a in bannedAttributes)
301
# TODO move this to sepolicy_tests
def TestCoreDataTypeViolations():
    """Check the vendor data directories against the core_data_file_type attribute.

    Calls pol.AssertPathTypesDoNotHaveAttr on /data/vendor/, /data/vendor_ce/
    and /data/vendor_de/ with an empty second list and returns its result.
    The result is concatenated into the test report, so it is presumably a
    string that is "" when clean (confirm against policy.py).
    """
    global pol
    vendorDataPrefixes = ["/data/vendor/", "/data/vendor_ce/", "/data/vendor_de/"]
    return pol.AssertPathTypesDoNotHaveAttr(vendorDataPrefixes, [],
                                            "core_data_file_type")
307
308###
309# extend OptionParser to allow the same option flag to be used multiple times.
310# This is used to allow multiple file_contexts files and tests to be
311# specified.
312#
313class MultipleOption(Option):
314    ACTIONS = Option.ACTIONS + ("extend",)
315    STORE_ACTIONS = Option.STORE_ACTIONS + ("extend",)
316    TYPED_ACTIONS = Option.TYPED_ACTIONS + ("extend",)
317    ALWAYS_TYPED_ACTIONS = Option.ALWAYS_TYPED_ACTIONS + ("extend",)
318
319    def take_action(self, action, dest, opt, value, values, parser):
320        if action == "extend":
321            values.ensure_value(dest, []).append(value)
322        else:
323            Option.take_action(self, action, dest, opt, value, values, parser)
324
# Test name -> test function. Select tests with -t/--test; when none are
# selected, all of them run.
Tests = dict(
    CoredomainViolations=TestCoredomainViolations,
    CoreDatatypeViolations=TestCoreDataTypeViolations,
    TrebleCompatMapping=TestTrebleCompatMapping,
    ViolatorAttributes=TestViolatorAttributes,
)
329
if __name__ == '__main__':
    usage = "treble_sepolicy_tests -l $(ANDROID_HOST_OUT)/lib64/libsepolwrap.so "
    usage += "-f nonplat_file_contexts -f plat_file_contexts "
    usage += "-p curr_policy -b base_policy -o old_policy "
    usage += "-m mapping file [--test test] [--help]"
    parser = OptionParser(option_class=MultipleOption, usage=usage)
    parser.add_option("-b", "--basepolicy", dest="basepolicy", metavar="FILE")
    parser.add_option("-u", "--base-pub-policy", dest="base_pub_policy",
                      metavar="FILE")
    parser.add_option("-f", "--file_contexts", dest="file_contexts",
                      metavar="FILE", action="extend", type="string")
    parser.add_option("-l", "--library-path", dest="libpath", metavar="FILE")
    parser.add_option("-m", "--mapping", dest="mapping", metavar="FILE")
    parser.add_option("-o", "--oldpolicy", dest="oldpolicy", metavar="FILE")
    parser.add_option("-p", "--policy", dest="policy", metavar="FILE")
    parser.add_option("-t", "--test", dest="tests", action="extend",
                      help="Test options include " + str(Tests))
    parser.add_option("--fake-treble", action="store_true", dest="faketreble",
                      default=False)

    (options, args) = parser.parse_args()

    if not options.libpath:
        sys.exit("Must specify path to libsepolwrap library\n" + parser.usage)
    if not os.path.exists(options.libpath):
        sys.exit("Error: library-path " + options.libpath + " does not exist\n"
                 + parser.usage)
    if not options.policy:
        sys.exit("Must specify current monolithic policy file\n" + parser.usage)
    if not os.path.exists(options.policy):
        sys.exit("Error: policy file " + options.policy + " does not exist\n"
                 + parser.usage)
    if not options.file_contexts:
        sys.exit("Error: Must specify file_contexts file(s)\n" + parser.usage)
    for f in options.file_contexts:
        if not os.path.exists(f):
            sys.exit("Error: File_contexts file " + f + " does not exist\n" +
                     parser.usage)

    # Reject unknown test names before doing any expensive policy loading.
    # Previously this was only noticed after the earlier tests had already run,
    # and their results were discarded anyway.
    if options.tests is not None:
        for tn in options.tests:
            if tn not in Tests:
                err = "Error: unknown test: " + tn + "\n"
                err += "Available tests:\n"
                for name in sorted(Tests):
                    err += name + "\n"
                sys.exit(err)

    # Mapping files and public platform policy are only necessary for the
    # TrebleCompatMapping test, which runs when no test is selected or when it
    # is explicitly requested. options.tests is a list (the -t option uses the
    # "extend" action), so check membership; the former identity comparison
    # against a string literal was always False whenever -t was given.
    if options.tests is None or "TrebleCompatMapping" in options.tests:
        if not options.basepolicy:
            sys.exit("Must specify the current platform-only policy file\n"
                     + parser.usage)
        if not options.mapping:
            sys.exit("Must specify a compatibility mapping file\n"
                     + parser.usage)
        if not options.oldpolicy:
            sys.exit("Must specify the previous monolithic policy file\n"
                     + parser.usage)
        if not options.base_pub_policy:
            sys.exit("Must specify the current platform-only public policy "
                     + ".cil file\n" + parser.usage)
        basepol = policy.Policy(options.basepolicy, None, options.libpath)
        oldpol = policy.Policy(options.oldpolicy, None, options.libpath)
        mapping = mini_parser.MiniCilParser(options.mapping)
        pubpol = mini_parser.MiniCilParser(options.base_pub_policy)
        compatSetup(basepol, oldpol, mapping, pubpol.types)

    # Module scope: these assignments rebind the globals read by the tests.
    if options.faketreble:
        FakeTreble = True

    pol = policy.Policy(options.policy, options.file_contexts, options.libpath)
    setup(pol)

    if DEBUG:
        PrintScontexts()

    # If an individual test is not specified, run all tests.
    if options.tests is None:
        testsToRun = list(Tests.values())
    else:
        testsToRun = [Tests[tn] for tn in options.tests]

    results = ""
    for t in testsToRun:
        results += t()

    if results:
        sys.exit(results)
418