#!/usr/bin/env python
"""Command registry for apitools."""

import logging
import textwrap

from protorpc import descriptor
from protorpc import messages

from apitools.gen import extended_descriptor

# This is a code generator; we're purposely verbose.
# pylint:disable=too-many-statements

# Maps a protorpc field variant to the suffix of the gflags DEFINE_* call
# emitted for that field (see CommandRegistry.__PrintFlag). INT64/UINT64 are
# declared as string flags; CommandRegistry.__GetConversion wraps them in
# int() when the value is copied onto the request.
_VARIANT_TO_FLAG_TYPE_MAP = {
    messages.Variant.DOUBLE: 'float',
    messages.Variant.FLOAT: 'float',
    messages.Variant.INT64: 'string',
    messages.Variant.UINT64: 'string',
    messages.Variant.INT32: 'integer',
    messages.Variant.BOOL: 'boolean',
    messages.Variant.STRING: 'string',
    messages.Variant.MESSAGE: 'string',
    messages.Variant.BYTES: 'string',
    messages.Variant.UINT32: 'integer',
    messages.Variant.ENUM: 'enum',
    messages.Variant.SINT32: 'integer',
    messages.Variant.SINT64: 'integer',
}


class FlagInfo(messages.Message):

    """Information about a flag and conversion to a message.

    Fields:
      name: name of this flag.
      type: type of the flag.
      description: description of the flag.
      default: default value for this flag.
      enum_values: if this flag is an enum, the list of possible
          values.
      required: whether or not this flag is required.
      fv: name of the flag_values object where this flag should
          be registered.
      conversion: template for type conversion.
      special: (boolean, default: False) If True, this flag doesn't
          correspond to an attribute on the request.
    """
    name = messages.StringField(1)
    type = messages.StringField(2)
    description = messages.StringField(3)
    default = messages.StringField(4)
    enum_values = messages.StringField(5, repeated=True)
    required = messages.BooleanField(6, default=False)
    fv = messages.StringField(7)
    conversion = messages.StringField(8)
    special = messages.BooleanField(9, default=False)


class ArgInfo(messages.Message):

    """Information about a single positional command argument.

    Fields:
      name: argument name.
      description: description of this argument.
      conversion: template for type conversion.
    """
    name = messages.StringField(1)
    description = messages.StringField(2)
    conversion = messages.StringField(3)


class CommandInfo(messages.Message):

    """Information about a single command.

    Fields:
      name: name of this command.
      class_name: name of the apitools_base.NewCmd class for this command.
      description: description of this command.
      flags: list of FlagInfo messages for the command-specific flags.
      args: list of ArgInfo messages for the positional args.
      request_type: name of the request type for this command.
      client_method_path: path from the client object to the method
          this command is wrapping.
      has_upload: whether the generated command builds an upload object
          from the --upload_filename / --upload_mime_type flags.
      has_download: whether the generated command builds a download
          object from the --download_filename / --overwrite flags.
    """
    name = messages.StringField(1)
    class_name = messages.StringField(2)
    description = messages.StringField(3)
    flags = messages.MessageField(FlagInfo, 4, repeated=True)
    args = messages.MessageField(ArgInfo, 5, repeated=True)
    request_type = messages.StringField(6)
    client_method_path = messages.StringField(7)
    has_upload = messages.BooleanField(8, default=False)
    has_download = messages.BooleanField(9, default=False)


class CommandRegistry(object):

    """Registry for CLI commands.

    Commands and global flags are accumulated via AddCommandForMethod and
    AddGlobalParameters; WriteFile then emits a complete CLI module
    through the supplied printer.
    """

    def __init__(self, package, version, client_info, message_registry,
                 root_package, base_files_package, base_url, names):
        self.__package = package
        self.__version = version
        self.__client_info = client_info
        self.__names = names
        self.__message_registry = message_registry
        self.__root_package = root_package
        self.__base_files_package = base_files_package
        self.__base_url = base_url
        self.__command_list = []
        self.__global_flags = []

    def Validate(self):
        """Validate the underlying message registry."""
        self.__message_registry.Validate()

    def AddGlobalParameters(self, schema):
        """Register one global flag per field of the given message schema."""
        for field in schema.fields:
            self.__global_flags.append(self.__FlagInfoFromField(field, schema))

    def AddCommandForMethod(self, service_name, method_name, method_info,
                           request, _):
        """Add the given method as a command.

        Args:
          service_name: attribute name of the service on the client.
          method_name: name of the method on that service.
          method_info: method metadata; this reads method_id, description,
              ordered_params, upload_config and supports_download.
          request: name of the request message, resolved through the
              message registry.
          _: unused.
        """
        command_name = self.__GetCommandName(method_info.method_id)
        calling_path = '%s.%s' % (service_name, method_name)
        request_type = self.__message_registry.LookupDescriptor(request)
        description = method_info.description
        if not description:
            description = 'Call the %s method.' % method_info.method_id
        field_map = dict((f.name, f) for f in request_type.fields)

        # Ordered params become positional args, in order.
        args = []
        arg_names = set()
        for field_name in method_info.ordered_params:
            extended_field = field_map[field_name]
            name = extended_field.name
            args.append(ArgInfo(
                name=name,
                description=extended_field.description,
                conversion=self.__GetConversion(extended_field, request_type),
            ))
            arg_names.add(name)

        # Every other request field becomes a per-command flag.
        flags = []
        for extended_field in sorted(request_type.fields,
                                     key=lambda x: x.name):
            field = extended_field.field_descriptor
            if extended_field.name in arg_names:
                continue
            if self.__FieldIsRequired(field):
                logging.warning(
                    'Required field %s not in ordered_params for command %s',
                    extended_field.name, command_name)
            flags.append(self.__FlagInfoFromField(
                extended_field, request_type, fv='fv'))

        # "special" flags do not map onto request attributes; they drive the
        # upload/download objects built in __PrintCommands.
        if method_info.upload_config:
            # TODO(craigcitro): Consider adding additional flags to allow
            # determining the filename from the object metadata.
            upload_flag_info = FlagInfo(
                name='upload_filename', type='string', default='',
                description='Filename to use for upload.', fv='fv',
                special=True)
            flags.append(upload_flag_info)
            mime_description = (
                'MIME type to use for the upload. Only needed if '
                'the extension on --upload_filename does not determine '
                'the correct (or any) MIME type.')
            mime_type_flag_info = FlagInfo(
                name='upload_mime_type', type='string', default='',
                description=mime_description, fv='fv', special=True)
            flags.append(mime_type_flag_info)
        if method_info.supports_download:
            download_flag_info = FlagInfo(
                name='download_filename', type='string', default='',
                description='Filename to use for download.', fv='fv',
                special=True)
            flags.append(download_flag_info)
            overwrite_description = (
                'If True, overwrite the existing file when downloading.')
            overwrite_flag_info = FlagInfo(
                name='overwrite', type='boolean', default='False',
                description=overwrite_description, fv='fv', special=True)
            flags.append(overwrite_flag_info)
        command_info = CommandInfo(
            name=command_name,
            class_name=self.__names.ClassName(command_name),
            description=description,
            flags=flags,
            args=args,
            request_type=request_type.full_name,
            client_method_path=calling_path,
            has_upload=bool(method_info.upload_config),
            has_download=bool(method_info.supports_download)
        )
        self.__command_list.append(command_info)

    def __LookupMessage(self, message, field):
        """Resolve field.type_name relative to message, then globally.

        The name is first looked up as nested inside message; if that
        yields None, the result of a lookup of the bare type name is
        returned instead.
        """
        message_type = self.__message_registry.LookupDescriptor(
            '%s.%s' % (message.name, field.type_name))
        if message_type is None:
            message_type = self.__message_registry.LookupDescriptor(
                field.type_name)
        return message_type

    def __GetCommandName(self, method_id):
        """Strip the '<package>.' prefix from method_id and join with '_'."""
        command_name = method_id
        prefix = '%s.' % self.__package
        if command_name.startswith(prefix):
            command_name = command_name[len(prefix):]
        command_name = command_name.replace('.', '_')
        return command_name

    def __GetConversion(self, extended_field, extended_message):
        """Return a %-template converting a raw flag/arg value for a field.

        The template has a single %s placeholder for the source expression;
        an empty string means the value is used as-is. For repeated fields
        a non-empty template is applied to each element via a list
        comprehension.

        Raises:
          ValueError: if a message/enum field's type cannot be resolved.
        """
        field = extended_field.field_descriptor

        type_name = ''
        if field.variant in (messages.Variant.MESSAGE, messages.Variant.ENUM):
            if field.type_name.startswith('protorpc.'):
                type_name = field.type_name
            else:
                field_message = self.__LookupMessage(extended_message, field)
                if field_message is None:
                    raise ValueError(
                        'Could not find type for field %s' % field.name)
                type_name = 'messages.%s' % field_message.full_name

        template = ''
        if field.variant in (messages.Variant.INT64, messages.Variant.UINT64):
            template = 'int(%s)'
        elif field.variant == messages.Variant.MESSAGE:
            template = 'apitools_base.JsonToMessage(%s, %%s)' % type_name
        elif field.variant == messages.Variant.ENUM:
            template = '%s(%%s)' % type_name
        elif field.variant == messages.Variant.STRING:
            template = "%s.decode('utf8')"

        if self.__FieldIsRepeated(extended_field.field_descriptor):
            if template:
                template = '[%s for x in %%s]' % (template % 'x')

        return template

    def __FieldIsRequired(self, field):
        """Return True if the field descriptor is labelled REQUIRED."""
        return field.label == descriptor.FieldDescriptor.Label.REQUIRED

    def __FieldIsRepeated(self, field):
        """Return True if the field descriptor is labelled REPEATED."""
        return field.label == descriptor.FieldDescriptor.Label.REPEATED

    def __FlagInfoFromField(self, extended_field, extended_message, fv=''):
        """Build a FlagInfo describing the flag for one message field.

        Args:
          extended_field: the field to expose as a flag.
          extended_message: the message containing the field; used to
              resolve nested message/enum type names.
          fv: name of the flag_values object the flag is registered in
              ('' means no flag_values argument is emitted).

        Raises:
          ValueError: if an enum field's type cannot be resolved.
        """
        field = extended_field.field_descriptor
        flag_info = FlagInfo()
        flag_info.name = str(field.name)
        # TODO(craigcitro): We should key by variant.
        flag_info.type = _VARIANT_TO_FLAG_TYPE_MAP[field.variant]
        flag_info.description = extended_field.description
        if field.default_value:
            # TODO(craigcitro): Formatting?
            flag_info.default = field.default_value
        if flag_info.type == 'enum':
            # TODO(craigcitro): Does protorpc do this for us?
            enum_type = self.__LookupMessage(extended_message, field)
            if enum_type is None:
                # Format the name into the message (previously it was passed
                # as a separate exception argument, leaving a literal '%s').
                raise ValueError(
                    'Cannot find enum type %s' % field.type_name)
            flag_info.enum_values = [x.name for x in enum_type.values]
            # Note that this choice is completely arbitrary -- but we only
            # push the value through if the user specifies it, so this
            # doesn't hurt anything. Guard against enums with no values.
            if flag_info.default is None and flag_info.enum_values:
                flag_info.default = flag_info.enum_values[0]
        if self.__FieldIsRequired(field):
            flag_info.required = True
        flag_info.fv = fv
        flag_info.conversion = self.__GetConversion(
            extended_field, extended_message)
        return flag_info

    def __PrintFlagDeclarations(self, printer):
        """Emit the idempotent global-flag declaration function and call it."""
        package = self.__client_info.package
        function_name = '_Declare%sFlags' % (package[0].upper() + package[1:])
        printer()
        printer()
        printer('def %s():', function_name)
        with printer.Indent():
            printer('"""Declare global flags in an idempotent way."""')
            printer("if 'api_endpoint' in flags.FLAGS:")
            with printer.Indent():
                printer('return')
            printer('flags.DEFINE_string(')
            with printer.Indent(' '):
                printer("'api_endpoint',")
                printer('%r,', self.__base_url)
                printer("'URL of the API endpoint to use.',")
                printer("short_name='%s_url')", self.__package)
            printer('flags.DEFINE_string(')
            with printer.Indent(' '):
                printer("'history_file',")
                printer('%r,', '~/.%s.%s.history' %
                        (self.__package, self.__version))
                printer("'File with interactive shell history.')")
            printer('flags.DEFINE_multistring(')
            with printer.Indent(' '):
                printer("'add_header', [],")
                printer("'Additional http headers (as key=value strings). '")
                printer("'Can be specified multiple times.')")
            printer('flags.DEFINE_string(')
            with printer.Indent(' '):
                printer("'service_account_json_keyfile', '',")
                printer("'Filename for a JSON service account key downloaded'")
                printer("' from the Developer Console.')")
            for flag_info in self.__global_flags:
                self.__PrintFlag(printer, flag_info)
        printer()
        printer()
        printer('FLAGS = flags.FLAGS')
        printer('apitools_base_cli.DeclareBaseFlags()')
        printer('%s()', function_name)

    def __PrintGetGlobalParams(self, printer):
        """Emit GetGlobalParamsFromFlags, copying set global flags."""
        printer('def GetGlobalParamsFromFlags():')
        with printer.Indent():
            printer('"""Return a StandardQueryParameters based on flags."""')
            printer('result = messages.StandardQueryParameters()')

            for flag_info in self.__global_flags:
                rhs = 'FLAGS.%s' % flag_info.name
                if flag_info.conversion:
                    rhs = flag_info.conversion % rhs
                printer('if FLAGS[%r].present:', flag_info.name)
                with printer.Indent():
                    printer('result.%s = %s', flag_info.name, rhs)
            printer('return result')
        printer()
        printer()

    def __PrintGetClient(self, printer):
        """Emit GetClientFromFlags, which builds the API client from flags."""
        printer('def GetClientFromFlags():')
        with printer.Indent():
            printer('"""Return a client object, configured from flags."""')
            printer('log_request = FLAGS.log_request or '
                    'FLAGS.log_request_response')
            printer('log_response = FLAGS.log_response or '
                    'FLAGS.log_request_response')
            printer('api_endpoint = apitools_base.NormalizeApiEndpoint('
                    'FLAGS.api_endpoint)')
            printer("additional_http_headers = dict(x.split('=', 1) for x in "
                    "FLAGS.add_header)")
            printer('credentials_args = {')
            with printer.Indent(' '):
                printer("'service_account_json_keyfile': os.path.expanduser("
                        'FLAGS.service_account_json_keyfile)')
            printer('}')
            printer('try:')
            with printer.Indent():
                printer('client = client_lib.%s(',
                        self.__client_info.client_class_name)
                with printer.Indent(indent=' '):
                    printer('api_endpoint, log_request=log_request,')
                    printer('log_response=log_response,')
                    printer('credentials_args=credentials_args,')
                    printer('additional_http_headers=additional_http_headers)')
            printer('except apitools_base.CredentialsError as e:')
            with printer.Indent():
                printer("print 'Error creating credentials: %%s' %% e")
                printer('sys.exit(1)')
            printer('return client')
        printer()
        printer()

    def __PrintCommandDocstring(self, printer, command_info):
        """Emit the docstring for a command's RunWithArgs method."""
        with printer.CommentContext():
            for line in textwrap.wrap('"""%s' % command_info.description,
                                      printer.CalculateWidth()):
                printer(line)
            extended_descriptor.PrintIndentedDescriptions(
                printer, command_info.args, 'Args')
            extended_descriptor.PrintIndentedDescriptions(
                printer, command_info.flags, 'Flags')
            printer('"""')

    def __PrintFlag(self, printer, flag_info):
        """Emit a flags.DEFINE_<type>(...) call (plus required marker)."""
        printer('flags.DEFINE_%s(', flag_info.type)
        with printer.Indent(indent=' '):
            printer('%r,', flag_info.name)
            printer('%r,', flag_info.default)
            if flag_info.type == 'enum':
                printer('%r,', flag_info.enum_values)

            # TODO(craigcitro): Consider using 'drop_whitespace' elsewhere.
            # description may be unset (None); textwrap.wrap needs a string.
            description_lines = textwrap.wrap(
                flag_info.description or '', 75 - len(printer.indent),
                drop_whitespace=False)
            for line in description_lines[:-1]:
                printer('%r', line)
            last_line = description_lines[-1] if description_lines else ''
            printer('%r%s', last_line, ',' if flag_info.fv else ')')
            if flag_info.fv:
                printer('flag_values=%s)', flag_info.fv)
        if flag_info.required:
            printer('flags.MarkFlagAsRequired(%r)', flag_info.name)

    def __PrintPyShell(self, printer):
        """Emit the PyShell command that opens an interactive console."""
        printer('class PyShell(appcommands.Cmd):')
        printer()
        with printer.Indent():
            printer('def Run(self, _):')
            with printer.Indent():
                printer(
                    '"""Run an interactive python shell with the client."""')
                printer('client = GetClientFromFlags()')
                printer('params = GetGlobalParamsFromFlags()')
                printer('for field in params.all_fields():')
                with printer.Indent():
                    printer('value = params.get_assigned_value(field.name)')
                    printer('if value != field.default:')
                    with printer.Indent():
                        printer('client.AddGlobalParam(field.name, value)')
                printer('banner = """')
                printer(' == %s interactive console ==' % (
                    self.__client_info.package))
                printer(' client: a %s client' %
                        self.__client_info.package)
                printer(' apitools_base: base apitools module')
                printer(' messages: the generated messages module')
                printer('"""')
                printer('local_vars = {')
                with printer.Indent(indent=' '):
                    printer("'apitools_base': apitools_base,")
                    printer("'client': client,")
                    printer("'client_lib': client_lib,")
                    printer("'messages': messages,")
                printer('}')
                printer("if platform.system() == 'Linux':")
                with printer.Indent():
                    printer('console = apitools_base_cli.ConsoleWithReadline(')
                    with printer.Indent(indent=' '):
                        printer('local_vars, histfile=FLAGS.history_file)')
                printer('else:')
                with printer.Indent():
                    printer('console = code.InteractiveConsole(local_vars)')
                printer('try:')
                with printer.Indent():
                    printer('console.interact(banner)')
                printer('except SystemExit as e:')
                with printer.Indent():
                    printer('return e.code')
        printer()
        printer()

    def WriteFile(self, printer):
        """Write a simple CLI (currently just a stub)."""
        printer('#!/usr/bin/env python')
        printer('"""CLI for %s, version %s."""',
                self.__package, self.__version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        # TODO(craigcitro): Add a build stamp, along with some other
        # information.
        printer()
        printer('import code')
        printer('import os')
        printer('import platform')
        printer('import sys')
        printer()
        printer('import protorpc')
        printer('from protorpc import message_types')
        printer('from protorpc import messages')
        printer()
        printer('from google.apputils import appcommands')
        printer('import gflags as flags')
        printer()
        printer('import %s as apitools_base', self.__base_files_package)
        printer('from %s import cli as apitools_base_cli',
                self.__base_files_package)
        printer('import %s as client_lib',
                self.__client_info.client_rule_name)
        printer('import %s as messages',
                self.__client_info.messages_rule_name)
        self.__PrintFlagDeclarations(printer)
        printer()
        printer()
        self.__PrintGetGlobalParams(printer)
        self.__PrintGetClient(printer)
        self.__PrintPyShell(printer)
        self.__PrintCommands(printer)
        printer('def main(_):')
        with printer.Indent():
            printer("appcommands.AddCmd('pyshell', PyShell)")
            for command_info in self.__command_list:
                printer("appcommands.AddCmd('%s', %s)",
                        command_info.name, command_info.class_name)
            printer()
            printer('apitools_base_cli.SetupLogger()')
            # TODO(craigcitro): Just call SetDefaultCommand as soon as
            # another appcommands release happens and this exists
            # externally.
            printer("if hasattr(appcommands, 'SetDefaultCommand'):")
            with printer.Indent():
                printer("appcommands.SetDefaultCommand('pyshell')")
        printer()
        printer()
        printer('run_main = apitools_base_cli.run_main')
        printer()
        printer("if __name__ == '__main__':")
        with printer.Indent():
            printer('appcommands.Run()')

    def __PrintCommands(self, printer):
        """Print all commands in this registry using printer."""
        for command_info in self.__command_list:
            arg_list = [arg_info.name for arg_info in command_info.args]
            printer(
                'class %s(apitools_base_cli.NewCmd):', command_info.class_name)
            with printer.Indent():
                printer('"""Command wrapping %s."""',
                        command_info.client_method_path)
                printer()
                printer('usage = """%s%s%s"""',
                        command_info.name,
                        ' ' if arg_list else '',
                        ' '.join('<%s>' % argname for argname in arg_list))
                printer()
                printer('def __init__(self, name, fv):')
                with printer.Indent():
                    printer('super(%s, self).__init__(name, fv)',
                            command_info.class_name)
                    for flag in command_info.flags:
                        self.__PrintFlag(printer, flag)
                printer()
                printer('def RunWithArgs(%s):', ', '.join(['self'] + arg_list))
                with printer.Indent():
                    self.__PrintCommandDocstring(printer, command_info)
                    printer('client = GetClientFromFlags()')
                    printer('global_params = GetGlobalParamsFromFlags()')
                    printer(
                        'request = messages.%s(', command_info.request_type)
                    with printer.Indent(indent=' '):
                        for arg in command_info.args:
                            rhs = arg.name
                            if arg.conversion:
                                rhs = arg.conversion % arg.name
                            printer('%s=%s,', arg.name, rhs)
                        printer(')')
                    # Copy only explicitly-set, non-special flags onto the
                    # request.
                    for flag_info in command_info.flags:
                        if flag_info.special:
                            continue
                        rhs = 'FLAGS.%s' % flag_info.name
                        if flag_info.conversion:
                            rhs = flag_info.conversion % rhs
                        printer('if FLAGS[%r].present:', flag_info.name)
                        with printer.Indent():
                            printer('request.%s = %s', flag_info.name, rhs)
                    call_args = ['request', 'global_params=global_params']
                    if command_info.has_upload:
                        call_args.append('upload=upload')
                        printer('upload = None')
                        printer('if FLAGS.upload_filename:')
                        with printer.Indent():
                            printer('upload = apitools_base.Upload.FromFile(')
                            printer(' FLAGS.upload_filename, '
                                    'FLAGS.upload_mime_type,')
                            printer(' progress_callback='
                                    'apitools_base.UploadProgressPrinter,')
                            printer(' finish_callback='
                                    'apitools_base.UploadCompletePrinter)')
                    if command_info.has_download:
                        call_args.append('download=download')
                        printer('download = None')
                        printer('if FLAGS.download_filename:')
                        with printer.Indent():
                            printer('download = apitools_base.Download.'
                                    'FromFile(FLAGS.download_filename, '
                                    'overwrite=FLAGS.overwrite,')
                            printer(' progress_callback='
                                    'apitools_base.DownloadProgressPrinter,')
                            printer(' finish_callback='
                                    'apitools_base.DownloadCompletePrinter)')
                    printer(
                        'result = client.%s(', command_info.client_method_path)
                    with printer.Indent(indent=' '):
                        printer('%s)', ', '.join(call_args))
                    printer('print apitools_base_cli.FormatOutput(result)')
            printer()
            printer()