1#!/usr/bin/env python
2"""Appcommands-compatible command class with extra fixins."""
3from __future__ import print_function
4
5import cmd
6import inspect
7import pdb
8import shlex
9import sys
10import traceback
11import types
12
13import six
14
15from google.apputils import app
16from google.apputils import appcommands
17import gflags as flags
18
# Public API of this module (CommandLoop is an implementation detail used by
# Repl).
__all__ = [
    'NewCmd',
    'Repl',
]

# Program-wide flags, read below through FLAGS.
# --debug_mode: NewCmd.__init__ copies it into self._debug_mode; when set,
#   NewCmd.Run dispatches to RunDebug (traceback + pdb) instead of RunSafely.
flags.DEFINE_boolean(
    'debug_mode', False,
    'Show tracebacks on Python exceptions.')
# --headless: checked by NewCmd.RunDebug to skip pdb.post_mortem() when no
#   one can interact with the debugger.
flags.DEFINE_boolean(
    'headless', False,
    'Assume no user is at the controlling console.')
FLAGS = flags.FLAGS
31
32
33def _SafeMakeAscii(s):
34    if isinstance(s, six.text_type):
35        return s.encode('ascii')
36    elif isinstance(s, str):
37        return s.decode('ascii')
38    else:
39        return six.text_type(s).encode('ascii', 'backslashreplace')
40
41
class NewCmd(appcommands.Cmd):

    """Featureful extension of appcommands.Cmd."""

    def __init__(self, name, flag_values):
        """Initialize the command and introspect RunWithArgs, if defined.

        A command is "new style" when it defines a RunWithArgs method; its
        positional parameters become the command's positional arguments and
        its docstring becomes the command's help text.

        Args:
          name: command name, passed through to appcommands.Cmd.
          flag_values: the command-specific flag values, passed through to
            appcommands.Cmd.
        """
        super(NewCmd, self).__init__(name, flag_values)
        run_with_args = getattr(self, 'RunWithArgs', None)
        self._new_style = isinstance(run_with_args, types.MethodType)
        # Set for every command, not only new-style ones, because
        # CommandLoop reads it on every NewCmd. Subclasses (e.g. Repl) may
        # override it after calling this constructor.
        self.surface_in_shell = True
        if self._new_style:
            func = run_with_args.__func__

            argspec, varkw = self._GetArgSpec(func)
            if argspec.args and argspec.args[0] == 'self':
                argspec = argspec._replace(  # pylint: disable=protected-access
                    args=argspec.args[1:])
            self._argspec = argspec
            # TODO(craigcitro): Do we really want to support all this
            # nonsense?
            self._star_args = argspec.varargs is not None
            self._star_kwds = varkw is not None
            self._max_args = len(argspec.args or ())
            # Parameters with defaults are optional positional arguments.
            self._min_args = self._max_args - len(argspec.defaults or ())
            if self._star_args:
                self._max_args = sys.maxsize

            self._debug_mode = FLAGS.debug_mode
            self.__doc__ = self.RunWithArgs.__doc__

    @staticmethod
    def _GetArgSpec(func):
        """Return (argspec, varkw_name) for func on Python 2 and 3.

        inspect.getargspec is deprecated and was removed in Python 3.11, so
        inspect.getfullargspec is preferred whenever it exists. The two
        results name the ``**kwargs`` parameter differently (``varkw``
        vs. ``keywords``), so that name is returned separately.
        """
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is not None:
            spec = getfullargspec(func)
            return spec, spec.varkw
        spec = inspect.getargspec(func)  # Python 2 only.
        return spec, spec.keywords

    def __getattr__(self, name):
        """Resolve unknown attributes as command-specific flag values."""
        if name == '_command_flags':
            # Without this guard, a missing _command_flags (e.g. on an
            # instance created by copy/pickle without running __init__)
            # would re-enter __getattr__ forever.
            raise AttributeError(name)
        if name in self._command_flags:
            return self._command_flags[name].value
        return super(NewCmd, self).__getattribute__(name)

    def _GetFlag(self, flagname):
        """Return the command-specific flag named flagname, or None."""
        if flagname in self._command_flags:
            return self._command_flags[flagname]
        else:
            return None

    def Run(self, argv):
        """Run this command.

        If self is a new-style command, we set up arguments and call
        self.RunWithArgs, gracefully handling exceptions. If not, we
        simply defer to appcommands.Cmd.Run(argv).

        Args:
          argv: List of arguments as strings; argv[0] is the command name.

        Returns:
          0 on success, nonzero on failure.
        """
        if not self._new_style:
            return super(NewCmd, self).Run(argv)

        # TODO(craigcitro): We need to save and restore flags each time so
        # that we can per-command flags in the REPL.
        args = argv[1:]
        fail = None
        fail_template = '%s positional args, found %d, expected at %s %d'
        if len(args) < self._min_args:
            fail = fail_template % ('Not enough', len(args),
                                    'least', self._min_args)
        if len(args) > self._max_args:
            fail = fail_template % ('Too many', len(args),
                                    'most', self._max_args)
        if fail:
            print(fail)
            # NewCmd itself defines no ``usage``; only print it when a
            # subclass (or a flag of that name) provides one.
            usage = getattr(self, 'usage', None)
            if usage:
                print('Usage: %s' % (usage,))
            return 1

        if self._debug_mode:
            return self.RunDebug(args, {})
        else:
            return self.RunSafely(args, {})

    def RunCmdLoop(self, argv):
        """Hook for use in cmd.Cmd-based command shells.

        Args:
          argv: the raw argument string typed after the command name; it is
            tokenized with shell-like quoting rules.

        Raises:
          SyntaxError: if argv cannot be tokenized (e.g. unbalanced quotes).
        """
        try:
            args = shlex.split(argv)
        except ValueError as e:
            raise SyntaxError(self.EncodeForPrinting(e))
        return self.Run([self._command_name] + args)

    @staticmethod
    def EncodeForPrinting(s):
        """Safely encode a string as the encoding for sys.stdout.

        Characters that sys.stdout's encoding (ASCII if unknown) cannot
        represent are replaced by backslash escapes. The result is the
        native ``str`` type: bytes on Python 2 (as before) and text on
        Python 3, so it prints cleanly rather than as a ``b'...'`` literal.
        """
        # sys.stdout may be replaced by an object with no (or a None)
        # ``encoding`` attribute, e.g. in tests.
        encoding = getattr(sys.stdout, 'encoding', None) or 'ascii'
        encoded = six.text_type(s).encode(encoding, 'backslashreplace')
        if isinstance(encoded, str):
            return encoded  # Python 2: bytes is the native str type.
        return encoded.decode(encoding)

    def _FormatError(self, e):
        """Hook for subclasses to modify how error messages are printed."""
        return _SafeMakeAscii(e)

    def _HandleError(self, e):
        """Print a one-line description of e and return exit code 1."""
        message = self._FormatError(e)
        print('Exception raised in %s operation: %s' % (
            self._command_name, message))
        return 1

    def _IsDebuggableException(self, e):
        """Hook for subclasses to skip debugging on certain exceptions."""
        return not isinstance(e, app.UsageError)

    def RunDebug(self, args, kwds):
        """Run this command in debug mode.

        Unexpected exceptions print a banner and traceback, then start a
        post-mortem pdb session unless --headless is set.
        """
        try:
            return_value = self.RunWithArgs(*args, **kwds)
        except BaseException as e:
            # Don't break into the debugger for expected exceptions.
            if not self._IsDebuggableException(e):
                return self._HandleError(e)
            print()
            print('****************************************************')
            print('**   Unexpected Exception raised in execution!    **')
            if FLAGS.headless:
                print('**  --headless mode enabled, exiting.             **')
                print('**  See STDERR for traceback.                     **')
            else:
                print('**  --debug_mode enabled, starting pdb.           **')
            print('****************************************************')
            print()
            traceback.print_exc()
            print()
            if not FLAGS.headless:
                pdb.post_mortem()
            return 1
        return return_value

    def RunSafely(self, args, kwds):
        """Run this command, turning exceptions into print statements."""
        try:
            return_value = self.RunWithArgs(*args, **kwds)
        except BaseException as e:
            return self._HandleError(e)
        return return_value
180
181
class CommandLoop(cmd.Cmd):

    """Instance of cmd.Cmd built to work with NewCmd."""

    class TerminateSignal(Exception):

        """Exception type used for signaling loop completion."""

    def __init__(self, commands, prompt):
        """Build the loop from a mapping of command names to commands.

        Args:
          commands: dict of command name -> command object. It must contain
            'help'. Other entries are exposed as do_<name> handlers only if
            they are NewCmd instances with a truthy surface_in_shell and
            their name is not one of the special names.
          prompt: the prompt string to display.
        """
        cmd.Cmd.__init__(self)
        self._commands = {'help': commands['help']}
        self._special_command_names = ['help', 'repl', 'EOF']
        for name, command in commands.items():
            if (name not in self._special_command_names and
                    isinstance(command, NewCmd) and
                    command.surface_in_shell):
                self._commands[name] = command
                setattr(self, 'do_%s' % (name,), command.RunCmdLoop)
        self._default_prompt = prompt
        self._set_prompt()
        self._last_return_code = 0

    @property
    def last_return_code(self):
        """Return code of the most recently executed command."""
        return self._last_return_code

    def _set_prompt(self):
        self.prompt = self._default_prompt

    def do_EOF(self, *unused_args):  # pylint: disable=invalid-name
        """Terminate the running command loop.

        This function raises an exception to avoid the need to do
        potentially-error-prone string parsing inside onecmd.

        Args:
          *unused_args: unused.

        Returns:
          Never returns.

        Raises:
          CommandLoop.TerminateSignal: always.
        """
        raise CommandLoop.TerminateSignal()

    def postloop(self):
        print('Goodbye.')

    # pylint: disable=arguments-differ
    def completedefault(self, unused_text, line, unused_begidx, unused_endidx):
        """Show the usage of the command being completed; never complete.

        When the first word of line names a known command with a non-empty
        ``usage`` attribute, print that usage and re-echo the prompt and the
        current line.

        Returns:
          Always an empty list of completions.
        """
        if not line:
            return []
        command_name = line.partition(' ')[0].lower()
        # Not every command defines ``usage`` (NewCmd itself does not), and
        # unknown names yield None here; both fall back to no usage.
        usage = getattr(self._commands.get(command_name), 'usage', '')
        if usage:
            print()
            print(usage)
            print('%s%s' % (self.prompt, line), end=' ')
        return []
    # pylint: enable=arguments-differ

    def emptyline(self):
        """List the available commands instead of repeating the last one."""
        print('Available commands:', end=' ')
        print(' '.join(list(self._commands)))

    def precmd(self, line):
        """Preprocess the shell input.

        - 'exit' or 'quit' as the first word becomes 'EOF', ending the loop.
        - A single word other than help/ls/version becomes a help request
          for that word.
        - Everything else passes through unchanged.
        """
        if line == 'EOF':
            return line
        words = line.strip().split()
        # Compare whole words so commands such as 'exitcode' are not
        # mistaken for exit, and leading whitespace is tolerated.
        if words and words[0] in ('exit', 'quit'):
            return 'EOF'
        if len(words) == 1 and words[0] not in ('help', 'ls', 'version'):
            return 'help %s' % (words[0],)
        return line

    def onecmd(self, line):
        """Process a single command.

        Runs a single command, and stores the return code in
        self._last_return_code. Always returns False unless the command
        was EOF.

        Args:
          line: (str) Command line to process.

        Returns:
          A bool signaling whether or not the command loop should terminate.
        """
        try:
            self._last_return_code = cmd.Cmd.onecmd(self, line)
        except CommandLoop.TerminateSignal:
            return True
        except BaseException as e:
            name = line.split(' ')[0]
            print('Error running %s:' % name)
            print(e)
            self._last_return_code = 1
        return False

    def get_names(self):
        """Return the attribute names cmd.Cmd scans for handlers.

        Includes one do_<name> entry per registered non-special command.
        do_EOF is left out so that it is not offered by cmd.Cmd's command-name
        completion. The list contains no duplicates; the per-command do_*
        attributes already appear in dir(self) because __init__ sets them
        on the instance.
        """
        names = set(dir(self))
        names.update('do_%s' % (name,) for name in self._commands
                     if name not in self._special_command_names)
        names.discard('do_EOF')
        return sorted(names)

    def do_help(self, command_name):
        """Print the help for command_name (if present) or general help."""

        # TODO(craigcitro): Add command-specific flags.
        def FormatOneCmd(name, command, command_names):
            """Format the help text of one command for display."""
            indent_size = appcommands.GetMaxCommandLength() + 3
            if len(command_names) > 1:
                indent = ' ' * indent_size
                command_help = flags.TextWrap(
                    command.CommandGetHelp('', cmd_names=command_names),
                    indent=indent,
                    firstline_indent='')
                first_help_line, _, rest = command_help.partition('\n')
                first_line = '%-*s%s' % (indent_size,
                                         name + ':', first_help_line)
                return '\n'.join((first_line, rest))
            else:
                default_indent = '  '
                return '\n' + flags.TextWrap(
                    command.CommandGetHelp('', cmd_names=command_names),
                    indent=default_indent,
                    firstline_indent=default_indent) + '\n'

        if not command_name:
            print('\nHelp for commands:\n')
            command_names = list(self._commands)
            print('\n\n'.join(
                FormatOneCmd(name, command, command_names)
                for name, command in self._commands.items()
                if name not in self._special_command_names))
            print()
        elif command_name in self._commands:
            print(FormatOneCmd(command_name, self._commands[command_name],
                               command_names=[command_name]))
        return 0

    def postcmd(self, stop, line):
        """Stop the loop when onecmd signalled it or the line was EOF."""
        return bool(stop) or line == 'EOF'
332
333
class Repl(NewCmd):

    """Start an interactive session."""
    PROMPT = '> '

    def __init__(self, name, fv):
        """Register the --prompt flag on this command's flag values.

        Args:
          name: command name, passed through to NewCmd.
          fv: this command's flag values; --prompt is defined on them.
        """
        super(Repl, self).__init__(name, fv)
        # Keep the REPL command itself out of the interactive shell.
        self.surface_in_shell = False
        flags.DEFINE_string('prompt', '',
                            'Prompt to use for interactive shell.',
                            flag_values=fv)

    def RunWithArgs(self):
        """Start an interactive session."""
        # An empty --prompt falls back to the class-level default.
        requested_prompt = FLAGS.prompt
        if requested_prompt:
            prompt = requested_prompt
        else:
            prompt = self.PROMPT
        shell = CommandLoop(appcommands.GetCommandList(), prompt=prompt)
        print('Welcome! (Type help for more information.)')
        finished = False
        while not finished:
            try:
                shell.cmdloop()
            except KeyboardInterrupt:
                # Ctrl-C abandons the current input, not the whole session.
                print()
            else:
                finished = True
        return shell.last_return_code
359