1#!/usr/bin/env python
2#
3# Copyright (C) 2021 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""gen_sdk is a command line tool for managing the sdk extension proto db.
17
18Example usages:
19# Print a binary representation of the proto database.
20$ gen_sdk --action print_binary
21
22# Validate the database
23$ gen_sdk --action validate
24
25# Create a new SDK
26$ gen_sdk --action new_sdk --sdk 1 --modules=IPSEC,SDK_EXTENSIONS
27"""
28
29import argparse
30import google.protobuf.text_format
31import pathlib
32import sys
33
34from sdk_pb2 import ExtensionVersion
35from sdk_pb2 import ExtensionDatabase
36from sdk_pb2 import SdkModule
37from sdk_pb2 import SdkVersion
38
39
def ParseArgs(argv):
  """Parse the gen_sdk command line.

  Args:
    argv: The command-line arguments, excluding the program name.

  Returns:
    An argparse.Namespace with attributes database (pathlib.Path), action
    (str), sdk (int or None) and modules (str or None).

  Exits via argparse with a usage error when the arguments are malformed, or
  when --action=new_sdk is given without both --sdk and --modules.
  """
  # The first positional parameter of ArgumentParser is `prog`; pass the
  # description by keyword so it does not replace the program name in usage.
  parser = argparse.ArgumentParser(
    description='Manage the extension SDK database')
  parser.add_argument(
    '--database',
    type=pathlib.Path,
    metavar='PATH',
    default='extensions_db.textpb',
    help='The existing text-proto database to use. (default: extensions_db.textpb)'
  )
  parser.add_argument(
    '--action',
    choices=['print_binary', 'new_sdk', 'validate'],
    metavar='ACTION',
    required=True,
    help='Which action to take (print_binary|new_sdk|validate).'
  )
  parser.add_argument(
    '--sdk',
    type=int,
    metavar='SDK',
    help='The extension SDK level to deal with (int)'
  )
  parser.add_argument(
    '--modules',
    metavar='MODULES',
    help='Comma-separated list of modules providing new APIs. Required for action=new_sdk.'
  )
  args = parser.parse_args(argv)
  # new_sdk cannot run without a level number and at least one module; reject
  # it here with a usage message instead of failing later with a traceback.
  if args.action == 'new_sdk' and (args.sdk is None or not args.modules):
    parser.error('--sdk and --modules are required for --action=new_sdk')
  return args
68
69
def PrintBinary(database):
  """Print the binary representation of the db proto to stdout.

  Args:
    database: The ExtensionDatabase message to serialize.

  The serialized bytes are written to the binary buffer underlying
  sys.stdout, so no text encoding is applied to them.
  """
  sys.stdout.buffer.write(database.SerializeToString())
73
74
def ValidateDatabase(database, dbname):
  """Check the extension database for consistency, exiting on the first error.

  The checks are:
    * no two extension versions share the same version number;
    * no single version lists the same module more than once;
    * visiting versions in ascending order, no module's required version
      ever decreases.

  Args:
    database: The ExtensionDatabase message to check.
    dbname: Human-readable label for the database, used in the error message.

  Returns:
    None when the database is valid. Otherwise prints a description of the
    problem to stderr and terminates the process with exit status 1.
  """
  def find_duplicate(items):
    # Return the first value that occurs more than once in items, else None.
    seen = set()
    for item in items:
      if item in seen:
        return item
      seen.add(item)
    return None

  def find_bug():
    # Compare with None rather than truthiness: a duplicated value of 0 is
    # falsy but is still a duplicate.
    dupe = find_duplicate([v.version for v in database.versions])
    if dupe is not None:
      return 'Found duplicate extension version: %d' % dupe

    for version in database.versions:
      dupe = find_duplicate([r.module for r in version.requirements])
      if dupe is not None:
        return 'Found duplicate module requirement for %s in single version %s' % (dupe, version)

    # Most recent required version seen so far, keyed by module.
    prev_requirements = {}
    for version in sorted(database.versions, key=lambda v: v.version):
      for requirement in version.requirements:
        prev = prev_requirements.get(requirement.module)
        if prev is not None and prev.version > requirement.version.version:
          return 'Found module requirement moving backwards: %s in %s' % (requirement, version)
        prev_requirements[requirement.module] = requirement.version
    return None

  err = find_bug()
  if err is not None:
    # stderr keeps diagnostics out of stdout, which print_binary uses for data.
    print('%s not valid, aborting:\n  %s' % (dbname, err), file=sys.stderr)
    sys.exit(1)
108
109
def NewSdk(database, new_version, modules):
  """Append a new extension SDK level to the database in place.

  For every module mentioned by an existing level, the new level requires the
  version given by the highest existing level that mentions it. Each module
  in `modules` instead requires `new_version` itself.

  Args:
    database: The ExtensionDatabase message to extend.
    new_version: Integer number of the new extension SDK level.
    modules: SdkModule enum values of the modules providing new APIs.
  """
  latest = {}
  # Walk existing levels from lowest to highest so that later levels
  # overwrite earlier ones, leaving each module's most recent requirement.
  for existing in sorted(database.versions, key=lambda ev: ev.version):
    latest.update((req.module, req.version) for req in existing.requirements)

  # Modules shipping new APIs are bumped to the new level.
  latest.update((module, SdkVersion(version=new_version)) for module in modules)

  requirements = [
    ExtensionVersion.ModuleRequirement(module=module, version=required)
    for module, required in latest.items()
  ]
  database.versions.append(
    ExtensionVersion(version=new_version, requirements=requirements))

  names = ','.join(SdkModule.Name(module) for module in modules)
  print('Created a new extension SDK level %d with modules %s' % (new_version, names))
129
130
def main(argv):
  """Load the database, validate it, and run the requested action.

  Args:
    argv: The command-line arguments, excluding the program name.

  The file named by --database is parsed as a text proto and validated
  before the action runs. For --action=new_sdk the modified database is
  validated again and written back to the same file.
  """
  args = ParseArgs(argv)
  with args.database.open('r', encoding='utf-8') as f:
    database = google.protobuf.text_format.Parse(f.read(), ExtensionDatabase())

  # Always bind `modules` (the original left it unbound without --modules).
  # Accept "A, B" as well as "A,B", and ignore empty entries such as a
  # trailing comma.
  modules = []
  if args.modules:
    for name in args.modules.split(','):
      name = name.strip()
      if not name:
        continue
      try:
        modules.append(SdkModule.Value(name))
      except ValueError:
        sys.exit('Unknown module %r; expected one of: %s'
                 % (name, ', '.join(SdkModule.keys())))

  if args.action == 'new_sdk' and (args.sdk is None or not modules):
    sys.exit('--sdk and --modules are required for --action=new_sdk')

  ValidateDatabase(database, 'Input database')

  actions = {
    'validate': lambda: print('Validated database'),
    'print_binary': lambda: PrintBinary(database),
    'new_sdk': lambda: NewSdk(database, args.sdk, modules),
  }
  actions[args.action]()

  if args.action == 'new_sdk':
    ValidateDatabase(database, 'Post-modification database')
    # Serialize before opening for write: mode 'w' truncates the file, so a
    # failure during serialization would otherwise destroy the database.
    text = google.protobuf.text_format.MessageToString(database)
    with args.database.open('w', encoding='utf-8') as f:
      f.write(text)
151
if __name__ == '__main__':
  # sys.argv[0] is the program name; main/ParseArgs receive only the arguments.
  main(sys.argv[1:])
154