1#!/usr/bin/env python
2"""Command registry for apitools."""
3
4import logging
5import textwrap
6
7from protorpc import descriptor
8from protorpc import messages
9
10from apitools.gen import extended_descriptor
11
12# This is a code generator; we're purposely verbose.
13# pylint:disable=too-many-statements
14
# Maps a protorpc field variant to the flag type stored in FlagInfo.type.
# The value is used as the suffix of the generated `flags.DEFINE_<type>(`
# call, so each value must name an existing gflags DEFINE_* function.
#
# INT64/UINT64 are declared as string flags; the generated code converts
# them with int(...) (see CommandRegistry.__GetConversion). MESSAGE values
# are likewise taken as strings and parsed with apitools_base.JsonToMessage.
#
# NOTE(review): SINT64 maps to 'integer' while INT64/UINT64 map to
# 'string' -- confirm that this difference is intentional.
_VARIANT_TO_FLAG_TYPE_MAP = {
    messages.Variant.DOUBLE: 'float',
    messages.Variant.FLOAT: 'float',
    messages.Variant.INT64: 'string',
    messages.Variant.UINT64: 'string',
    messages.Variant.INT32: 'integer',
    messages.Variant.BOOL: 'boolean',
    messages.Variant.STRING: 'string',
    messages.Variant.MESSAGE: 'string',
    messages.Variant.BYTES: 'string',
    messages.Variant.UINT32: 'integer',
    messages.Variant.ENUM: 'enum',
    messages.Variant.SINT32: 'integer',
    messages.Variant.SINT64: 'integer',
}
30
31
class FlagInfo(messages.Message):

    """Information about a flag and conversion to a message.

    Fields:
      name: name of this flag.
      type: type of the flag; used as the suffix of the generated
          `flags.DEFINE_<type>(` call (e.g. 'string', 'integer', 'enum').
      description: description of the flag, used as its help text.
      default: default value for this flag, stored as a string and
          emitted into the generated code with %r.
      enum_values: if this flag is an enum, the list of possible
          values.
      required: whether or not this flag is required; if True the
          generated code also calls flags.MarkFlagAsRequired.
      fv: name of the flag_values object where this flag should
          be registered. If empty, no flag_values argument is emitted.
      conversion: template for type conversion: a %-format string with a
          single %s placeholder that receives the expression holding the
          raw flag value. Empty if no conversion is needed.
      special: (boolean, default: False) If True, this flag doesn't
          correspond to an attribute on the request, so it is skipped
          when the request message is populated from flags.
    """
    name = messages.StringField(1)
    type = messages.StringField(2)
    description = messages.StringField(3)
    default = messages.StringField(4)
    enum_values = messages.StringField(5, repeated=True)
    required = messages.BooleanField(6, default=False)
    fv = messages.StringField(7)
    conversion = messages.StringField(8)
    special = messages.BooleanField(9, default=False)
59
60
class ArgInfo(messages.Message):

    """Information about a single positional command argument.

    Fields:
      name: argument name; also used as the parameter name of the
          generated RunWithArgs method and as the request attribute name.
      description: description of this argument.
      conversion: template for type conversion: a %-format string with a
          single %s placeholder that receives the argument name. Empty if
          the argument is passed to the request unchanged.
    """
    name = messages.StringField(1)
    description = messages.StringField(2)
    conversion = messages.StringField(3)
73
74
class CommandInfo(messages.Message):

    """Information about a single command.

    Fields:
      name: name of this command.
      class_name: name of the generated apitools_base_cli.NewCmd subclass
          for this command.
      description: description of this command.
      flags: list of FlagInfo messages for the command-specific flags.
      args: list of ArgInfo messages for the positional args.
      request_type: name of the request type for this command.
      client_method_path: path from the client object to the method
          this command is wrapping.
      has_upload: (boolean, default: False) whether the wrapped method
          has an upload configuration; if True the generated command
          builds an upload from the --upload_filename flag.
      has_download: (boolean, default: False) whether the wrapped method
          supports download; if True the generated command builds a
          download from the --download_filename flag.
    """
    name = messages.StringField(1)
    class_name = messages.StringField(2)
    description = messages.StringField(3)
    flags = messages.MessageField(FlagInfo, 4, repeated=True)
    args = messages.MessageField(ArgInfo, 5, repeated=True)
    request_type = messages.StringField(6)
    client_method_path = messages.StringField(7)
    has_upload = messages.BooleanField(8, default=False)
    has_download = messages.BooleanField(9, default=False)
98
99
class CommandRegistry(object):

    """Registry for CLI commands.

    Accumulates the global flags and one CommandInfo per API method, then
    renders them via WriteFile() as the source of a command line client
    built on gflags and google.apputils.appcommands.
    """

    def __init__(self, package, version, client_info, message_registry,
                 root_package, base_files_package, base_url, names):
        """Create an empty registry.

        Args:
          package: API package name; stripped as a prefix from method ids
              and used in generated flag and history file names.
          version: API version string, used in the generated module
              docstring and history file name.
          client_info: object providing `package`, `client_class_name`,
              `client_rule_name` and `messages_rule_name`.
          message_registry: registry providing Validate() and
              LookupDescriptor() for the API's message types.
          root_package: stored, but not otherwise used by this class.
          base_files_package: module path that the generated CLI imports
              as `apitools_base`.
          base_url: default value of the generated api_endpoint flag.
          names: naming helper providing ClassName().
        """
        self.__package = package
        self.__version = version
        self.__client_info = client_info
        self.__names = names
        self.__message_registry = message_registry
        self.__root_package = root_package
        self.__base_files_package = base_files_package
        self.__base_url = base_url
        self.__command_list = []
        self.__global_flags = []

    def Validate(self):
        """Validate the underlying message registry."""
        self.__message_registry.Validate()

    def AddGlobalParameters(self, schema):
        """Register every field of `schema` as a global flag.

        Args:
          schema: extended message descriptor whose `fields` describe the
              global parameters.
        """
        for field in schema.fields:
            self.__global_flags.append(self.__FlagInfoFromField(field, schema))

    def AddCommandForMethod(self, service_name, method_name, method_info,
                            request, _):
        """Add the given method as a command.

        Fields named in method_info.ordered_params become positional
        arguments (in that order); every other request field becomes a
        command-specific flag, sorted by name. Upload and download flags
        are appended when the method supports them.

        Args:
          service_name: name of the service on the client object.
          method_name: name of the method on that service.
          method_info: method description providing `method_id`,
              `description`, `ordered_params`, `upload_config` and
              `supports_download`.
          request: name of the request type, resolved through the message
              registry.
          _: unused.

        Raises:
          KeyError: if an ordered parameter is not a field of the request.
        """
        command_name = self.__GetCommandName(method_info.method_id)
        calling_path = '%s.%s' % (service_name, method_name)
        request_type = self.__message_registry.LookupDescriptor(request)
        description = method_info.description
        if not description:
            description = 'Call the %s method.' % method_info.method_id
        field_map = dict((f.name, f) for f in request_type.fields)
        args = []
        arg_names = set()
        for field_name in method_info.ordered_params:
            extended_field = field_map[field_name]
            name = extended_field.name
            args.append(ArgInfo(
                name=name,
                description=extended_field.description,
                conversion=self.__GetConversion(extended_field, request_type),
            ))
            arg_names.add(name)
        flags = []
        for extended_field in sorted(request_type.fields,
                                     key=lambda x: x.name):
            # Positional arguments are not repeated as flags.
            if extended_field.name in arg_names:
                continue
            if self.__FieldIsRequired(extended_field.field_descriptor):
                logging.warning(
                    'Required field %s not in ordered_params for command %s',
                    extended_field.name, command_name)
            flags.append(self.__FlagInfoFromField(
                extended_field, request_type, fv='fv'))
        if method_info.upload_config:
            flags.extend(self.__UploadFlags())
        if method_info.supports_download:
            flags.extend(self.__DownloadFlags())
        command_info = CommandInfo(
            name=command_name,
            class_name=self.__names.ClassName(command_name),
            description=description,
            flags=flags,
            args=args,
            request_type=request_type.full_name,
            client_method_path=calling_path,
            has_upload=bool(method_info.upload_config),
            has_download=bool(method_info.supports_download)
        )
        self.__command_list.append(command_info)

    @staticmethod
    def __UploadFlags():
        """Return the special flags added to commands that can upload."""
        # TODO(craigcitro): Consider adding additional flags to allow
        # determining the filename from the object metadata.
        mime_description = (
            'MIME type to use for the upload. Only needed if '
            'the extension on --upload_filename does not determine '
            'the correct (or any) MIME type.')
        return [
            FlagInfo(
                name='upload_filename', type='string', default='',
                description='Filename to use for upload.', fv='fv',
                special=True),
            FlagInfo(
                name='upload_mime_type', type='string', default='',
                description=mime_description, fv='fv', special=True),
        ]

    @staticmethod
    def __DownloadFlags():
        """Return the special flags added to commands that can download."""
        overwrite_description = (
            'If True, overwrite the existing file when downloading.')
        return [
            FlagInfo(
                name='download_filename', type='string', default='',
                description='Filename to use for download.', fv='fv',
                special=True),
            FlagInfo(
                name='overwrite', type='boolean', default='False',
                description=overwrite_description, fv='fv', special=True),
        ]

    def __LookupMessage(self, message, field):
        """Look up the descriptor for the type of `field`.

        The type name is first resolved as nested inside `message`; if
        that lookup yields None, the bare type name is tried.

        Returns:
          The result of the message registry lookup (None if the first
          lookup fails and the second one also yields None).
        """
        message_type = self.__message_registry.LookupDescriptor(
            '%s.%s' % (message.name, field.type_name))
        if message_type is None:
            message_type = self.__message_registry.LookupDescriptor(
                field.type_name)
        return message_type

    def __GetCommandName(self, method_id):
        """Derive a command name from a method id.

        Strips a leading '<package>.' if present and replaces the
        remaining dots with underscores.
        """
        command_name = method_id
        prefix = '%s.' % self.__package
        if command_name.startswith(prefix):
            command_name = command_name[len(prefix):]
        return command_name.replace('.', '_')

    def __GetConversion(self, extended_field, extended_message):
        """Build the conversion template for a field.

        Args:
          extended_field: extended field descriptor to convert.
          extended_message: extended descriptor of the message that
              contains the field; used to resolve nested type names.

        Returns:
          A %-format template with a single %s placeholder for the raw
          value expression, or '' if the value needs no conversion. For
          repeated fields the template is a list comprehension applying
          the per-element conversion.

        Raises:
          ValueError: if the type of a message or enum field cannot be
              found.
        """
        field = extended_field.field_descriptor
        variant = field.variant

        type_name = ''
        if variant in (messages.Variant.MESSAGE, messages.Variant.ENUM):
            if field.type_name.startswith('protorpc.'):
                type_name = field.type_name
            else:
                field_message = self.__LookupMessage(extended_message, field)
                if field_message is None:
                    raise ValueError(
                        'Could not find type for field %s' % field.name)
                type_name = 'messages.%s' % field_message.full_name

        template = ''
        if variant in (messages.Variant.INT64, messages.Variant.UINT64):
            # 64-bit values are declared as string flags.
            template = 'int(%s)'
        elif variant == messages.Variant.MESSAGE:
            template = 'apitools_base.JsonToMessage(%s, %%s)' % type_name
        elif variant == messages.Variant.ENUM:
            template = '%s(%%s)' % type_name
        elif variant == messages.Variant.STRING:
            template = "%s.decode('utf8')"

        if template and self.__FieldIsRepeated(field):
            template = '[%s for x in %%s]' % (template % 'x')

        return template

    def __FieldIsRequired(self, field):
        """Return True if the field descriptor is labelled REQUIRED."""
        return field.label == descriptor.FieldDescriptor.Label.REQUIRED

    def __FieldIsRepeated(self, field):
        """Return True if the field descriptor is labelled REPEATED."""
        return field.label == descriptor.FieldDescriptor.Label.REPEATED

    def __FlagInfoFromField(self, extended_field, extended_message, fv=''):
        """Build a FlagInfo describing the flag for a message field.

        Args:
          extended_field: extended field descriptor to expose as a flag.
          extended_message: extended descriptor of the containing message.
          fv: name of the flag_values object to register the flag on, or
              '' to emit no flag_values argument.

        Returns:
          The populated FlagInfo.

        Raises:
          ValueError: if the field is an enum whose type cannot be found.
        """
        field = extended_field.field_descriptor
        flag_info = FlagInfo()
        flag_info.name = str(field.name)
        # TODO(craigcitro): We should key by variant.
        flag_info.type = _VARIANT_TO_FLAG_TYPE_MAP[field.variant]
        flag_info.description = extended_field.description
        if field.default_value:
            # TODO(craigcitro): Formatting?
            flag_info.default = field.default_value
        if flag_info.type == 'enum':
            # TODO(craigcitro): Does protorpc do this for us?
            enum_type = self.__LookupMessage(extended_message, field)
            if enum_type is None:
                raise ValueError(
                    'Cannot find enum type %s' % field.type_name)
            flag_info.enum_values = [x.name for x in enum_type.values]
            # Note that this choice is completely arbitrary -- but we only
            # push the value through if the user specifies it, so this
            # doesn't hurt anything. An enum without values keeps a None
            # default instead of failing.
            if flag_info.default is None and flag_info.enum_values:
                flag_info.default = flag_info.enum_values[0]
        if self.__FieldIsRequired(field):
            flag_info.required = True
        flag_info.fv = fv
        flag_info.conversion = self.__GetConversion(
            extended_field, extended_message)
        return flag_info

    def __PrintFlagDeclarations(self, printer):
        """Print the function declaring the global flags, and its call."""
        package = self.__client_info.package
        # Slicing (rather than indexing) tolerates an empty package name.
        function_name = '_Declare%sFlags' % (
            package[:1].upper() + package[1:])
        printer()
        printer()
        printer('def %s():', function_name)
        with printer.Indent():
            printer('"""Declare global flags in an idempotent way."""')
            printer("if 'api_endpoint' in flags.FLAGS:")
            with printer.Indent():
                printer('return')
            printer('flags.DEFINE_string(')
            with printer.Indent('    '):
                printer("'api_endpoint',")
                printer('%r,', self.__base_url)
                printer("'URL of the API endpoint to use.',")
                printer("short_name='%s_url')", self.__package)
            printer('flags.DEFINE_string(')
            with printer.Indent('    '):
                printer("'history_file',")
                printer('%r,', '~/.%s.%s.history' %
                        (self.__package, self.__version))
                printer("'File with interactive shell history.')")
            printer('flags.DEFINE_multistring(')
            with printer.Indent('    '):
                printer("'add_header', [],")
                printer("'Additional http headers (as key=value strings). '")
                printer("'Can be specified multiple times.')")
            printer('flags.DEFINE_string(')
            with printer.Indent('    '):
                printer("'service_account_json_keyfile', '',")
                printer("'Filename for a JSON service account key downloaded'")
                printer("' from the Developer Console.')")
            for flag_info in self.__global_flags:
                self.__PrintFlag(printer, flag_info)
        printer()
        printer()
        printer('FLAGS = flags.FLAGS')
        printer('apitools_base_cli.DeclareBaseFlags()')
        printer('%s()', function_name)

    def __PrintGetGlobalParams(self, printer):
        """Print GetGlobalParamsFromFlags for the generated CLI."""
        printer('def GetGlobalParamsFromFlags():')
        with printer.Indent():
            printer('"""Return a StandardQueryParameters based on flags."""')
            printer('result = messages.StandardQueryParameters()')

            for flag_info in self.__global_flags:
                rhs = 'FLAGS.%s' % flag_info.name
                if flag_info.conversion:
                    rhs = flag_info.conversion % rhs
                # Only flags given on the command line are copied over.
                printer('if FLAGS[%r].present:', flag_info.name)
                with printer.Indent():
                    printer('result.%s = %s', flag_info.name, rhs)
            printer('return result')
        printer()
        printer()

    def __PrintGetClient(self, printer):
        """Print GetClientFromFlags for the generated CLI."""
        printer('def GetClientFromFlags():')
        with printer.Indent():
            printer('"""Return a client object, configured from flags."""')
            printer('log_request = FLAGS.log_request or '
                    'FLAGS.log_request_response')
            printer('log_response = FLAGS.log_response or '
                    'FLAGS.log_request_response')
            printer('api_endpoint = apitools_base.NormalizeApiEndpoint('
                    'FLAGS.api_endpoint)')
            printer("additional_http_headers = dict(x.split('=', 1) for x in "
                    "FLAGS.add_header)")
            printer('credentials_args = {')
            with printer.Indent('    '):
                printer("'service_account_json_keyfile': os.path.expanduser("
                        'FLAGS.service_account_json_keyfile)')
            printer('}')
            printer('try:')
            with printer.Indent():
                printer('client = client_lib.%s(',
                        self.__client_info.client_class_name)
                with printer.Indent(indent='    '):
                    printer('api_endpoint, log_request=log_request,')
                    printer('log_response=log_response,')
                    printer('credentials_args=credentials_args,')
                    printer('additional_http_headers=additional_http_headers)')
            printer('except apitools_base.CredentialsError as e:')
            with printer.Indent():
                # The percent signs are doubled because this line is
                # passed to the printer as a format string.
                printer("print 'Error creating credentials: %%s' %% e")
                printer('sys.exit(1)')
            printer('return client')
        printer()
        printer()

    def __PrintCommandDocstring(self, printer, command_info):
        """Print the docstring of a generated RunWithArgs method."""
        with printer.CommentContext():
            for line in textwrap.wrap('"""%s' % command_info.description,
                                      printer.CalculateWidth()):
                printer(line)
            extended_descriptor.PrintIndentedDescriptions(
                printer, command_info.args, 'Args')
            extended_descriptor.PrintIndentedDescriptions(
                printer, command_info.flags, 'Flags')
            printer('"""')

    def __PrintFlag(self, printer, flag_info):
        """Print the flags.DEFINE_* call declaring `flag_info`.

        The description is wrapped and emitted as adjacent string
        literals. A missing description is emitted as an empty string. A
        flag_values argument is added when flag_info.fv is set, and a
        MarkFlagAsRequired call follows for required flags.
        """
        printer('flags.DEFINE_%s(', flag_info.type)
        with printer.Indent(indent='    '):
            printer('%r,', flag_info.name)
            printer('%r,', flag_info.default)
            if flag_info.type == 'enum':
                printer('%r,', flag_info.enum_values)

            # TODO(craigcitro): Consider using 'drop_whitespace' elsewhere.
            # The description is optional; textwrap.wrap needs a string.
            description_lines = textwrap.wrap(
                flag_info.description or '', 75 - len(printer.indent),
                drop_whitespace=False)
            for line in description_lines[:-1]:
                printer('%r', line)
            last_line = description_lines[-1] if description_lines else ''
            printer('%r%s', last_line, ',' if flag_info.fv else ')')
            if flag_info.fv:
                printer('flag_values=%s)', flag_info.fv)
        if flag_info.required:
            printer('flags.MarkFlagAsRequired(%r)', flag_info.name)

    def __PrintPyShell(self, printer):
        """Print the PyShell command class for the generated CLI."""
        printer('class PyShell(appcommands.Cmd):')
        printer()
        with printer.Indent():
            printer('def Run(self, _):')
            with printer.Indent():
                printer(
                    '"""Run an interactive python shell with the client."""')
                printer('client = GetClientFromFlags()')
                printer('params = GetGlobalParamsFromFlags()')
                printer('for field in params.all_fields():')
                with printer.Indent():
                    printer('value = params.get_assigned_value(field.name)')
                    printer('if value != field.default:')
                    with printer.Indent():
                        printer('client.AddGlobalParam(field.name, value)')
                printer('banner = """')
                printer('       == %s interactive console ==' % (
                    self.__client_info.package))
                printer('             client: a %s client' %
                        self.__client_info.package)
                printer('      apitools_base: base apitools module')
                printer('     messages: the generated messages module')
                printer('"""')
                printer('local_vars = {')
                with printer.Indent(indent='    '):
                    printer("'apitools_base': apitools_base,")
                    printer("'client': client,")
                    printer("'client_lib': client_lib,")
                    printer("'messages': messages,")
                printer('}')
                printer("if platform.system() == 'Linux':")
                with printer.Indent():
                    printer('console = apitools_base_cli.ConsoleWithReadline(')
                    with printer.Indent(indent='    '):
                        printer('local_vars, histfile=FLAGS.history_file)')
                printer('else:')
                with printer.Indent():
                    printer('console = code.InteractiveConsole(local_vars)')
                printer('try:')
                with printer.Indent():
                    printer('console.interact(banner)')
                printer('except SystemExit as e:')
                with printer.Indent():
                    printer('return e.code')
        printer()
        printer()

    def WriteFile(self, printer):
        """Write the complete generated CLI module to `printer`.

        Emits, in order: the module header and imports, the global flag
        declarations, the flag-based helper functions, the PyShell
        command, one command class per registered method, and the main
        entry point.

        Args:
          printer: callable pretty-printer taking a format string plus
              arguments, and providing Indent() among other helpers.
        """
        printer('#!/usr/bin/env python')
        printer('"""CLI for %s, version %s."""',
                self.__package, self.__version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        # TODO(craigcitro): Add a build stamp, along with some other
        # information.
        printer()
        printer('import code')
        printer('import os')
        printer('import platform')
        printer('import sys')
        printer()
        printer('import protorpc')
        printer('from protorpc import message_types')
        printer('from protorpc import messages')
        printer()
        printer('from google.apputils import appcommands')
        printer('import gflags as flags')
        printer()
        printer('import %s as apitools_base', self.__base_files_package)
        printer('from %s import cli as apitools_base_cli',
                self.__base_files_package)
        printer('import %s as client_lib',
                self.__client_info.client_rule_name)
        printer('import %s as messages',
                self.__client_info.messages_rule_name)
        self.__PrintFlagDeclarations(printer)
        printer()
        printer()
        self.__PrintGetGlobalParams(printer)
        self.__PrintGetClient(printer)
        self.__PrintPyShell(printer)
        self.__PrintCommands(printer)
        printer('def main(_):')
        with printer.Indent():
            printer("appcommands.AddCmd('pyshell', PyShell)")
            for command_info in self.__command_list:
                printer("appcommands.AddCmd('%s', %s)",
                        command_info.name, command_info.class_name)
            printer()
            printer('apitools_base_cli.SetupLogger()')
            # TODO(craigcitro): Just call SetDefaultCommand as soon as
            # another appcommands release happens and this exists
            # externally.
            printer("if hasattr(appcommands, 'SetDefaultCommand'):")
            with printer.Indent():
                printer("appcommands.SetDefaultCommand('pyshell')")
        printer()
        printer()
        printer('run_main = apitools_base_cli.run_main')
        printer()
        printer("if __name__ == '__main__':")
        with printer.Indent():
            printer('appcommands.Run()')

    def __PrintCommands(self, printer):
        """Print all commands in this registry using printer."""
        for command_info in self.__command_list:
            arg_list = [arg_info.name for arg_info in command_info.args]
            printer(
                'class %s(apitools_base_cli.NewCmd):', command_info.class_name)
            with printer.Indent():
                printer('"""Command wrapping %s."""',
                        command_info.client_method_path)
                printer()
                printer('usage = """%s%s%s"""',
                        command_info.name,
                        ' ' if arg_list else '',
                        ' '.join('<%s>' % argname for argname in arg_list))
                printer()
                printer('def __init__(self, name, fv):')
                with printer.Indent():
                    printer('super(%s, self).__init__(name, fv)',
                            command_info.class_name)
                    for flag in command_info.flags:
                        self.__PrintFlag(printer, flag)
                printer()
                printer('def RunWithArgs(%s):', ', '.join(['self'] + arg_list))
                with printer.Indent():
                    self.__PrintCommandDocstring(printer, command_info)
                    printer('client = GetClientFromFlags()')
                    printer('global_params = GetGlobalParamsFromFlags()')
                    printer(
                        'request = messages.%s(', command_info.request_type)
                    with printer.Indent(indent='    '):
                        # Positional args go into the request constructor.
                        for arg in command_info.args:
                            rhs = arg.name
                            if arg.conversion:
                                rhs = arg.conversion % arg.name
                            printer('%s=%s,', arg.name, rhs)
                        printer(')')
                    for flag_info in command_info.flags:
                        # Special flags have no matching request attribute.
                        if flag_info.special:
                            continue
                        rhs = 'FLAGS.%s' % flag_info.name
                        if flag_info.conversion:
                            rhs = flag_info.conversion % rhs
                        printer('if FLAGS[%r].present:', flag_info.name)
                        with printer.Indent():
                            printer('request.%s = %s', flag_info.name, rhs)
                    call_args = ['request', 'global_params=global_params']
                    if command_info.has_upload:
                        call_args.append('upload=upload')
                        printer('upload = None')
                        printer('if FLAGS.upload_filename:')
                        with printer.Indent():
                            printer('upload = apitools_base.Upload.FromFile(')
                            printer('    FLAGS.upload_filename, '
                                    'FLAGS.upload_mime_type,')
                            printer('    progress_callback='
                                    'apitools_base.UploadProgressPrinter,')
                            printer('    finish_callback='
                                    'apitools_base.UploadCompletePrinter)')
                    if command_info.has_download:
                        call_args.append('download=download')
                        printer('download = None')
                        printer('if FLAGS.download_filename:')
                        with printer.Indent():
                            printer('download = apitools_base.Download.'
                                    'FromFile(FLAGS.download_filename, '
                                    'overwrite=FLAGS.overwrite,')
                            printer('    progress_callback='
                                    'apitools_base.DownloadProgressPrinter,')
                            printer('    finish_callback='
                                    'apitools_base.DownloadCompletePrinter)')
                    printer(
                        'result = client.%s(', command_info.client_method_path)
                    with printer.Indent(indent='    '):
                        printer('%s)', ', '.join(call_args))
                    printer('print apitools_base_cli.FormatOutput(result)')
            printer()
            printer()
589