1#!/usr/bin/env python
2#
3# Copyright 2015 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Appcommands-compatible command class with extra fixins."""
18from __future__ import absolute_import
19from __future__ import print_function
20
21import cmd
22import inspect
23import pdb
24import shlex
25import sys
26import traceback
27import types
28
29import gflags as flags
30from google.apputils import app
31from google.apputils import appcommands
32import six
33
34
# Public names exported by this module.
__all__ = [
    'NewCmd',
    'Repl',
]

# Global flags consulted by NewCmd:
#   --debug_mode: new-style commands run through NewCmd.RunDebug, which
#       prints a traceback and starts pdb post-mortem on unexpected
#       exceptions (instead of just printing the error).
#   --headless: in debug mode, print the traceback but do not start pdb.
flags.DEFINE_boolean(
    'debug_mode', False,
    'Show tracebacks on Python exceptions.')
flags.DEFINE_boolean(
    'headless', False,
    'Assume no user is at the controlling console.')
FLAGS = flags.FLAGS
47
48
def _SafeMakeAscii(s):
    """Return an ASCII-only native-string rendering of ``s``.

    Used to format exception messages for printing. Characters outside
    ASCII are backslash-escaped instead of raising UnicodeError. The result
    is a native ``str`` on both Python 2 (bytes) and Python 3 (text), so
    interpolating it with ``'%s'`` never produces a ``b'...'`` wrapper.

    Args:
      s: text, bytes, or any object (e.g. an exception) that
        ``six.text_type`` can convert.

    Returns:
      A native ``str`` containing only ASCII characters.
    """
    if isinstance(s, six.binary_type):
        # Bytes that are not valid ASCII decode to U+FFFD, which is then
        # escaped below rather than raising UnicodeDecodeError.
        text = s.decode('ascii', 'replace')
    else:
        text = six.text_type(s)
    ascii_bytes = text.encode('ascii', 'backslashreplace')
    if six.PY2:
        return ascii_bytes
    return ascii_bytes.decode('ascii')
55
56
class NewCmd(appcommands.Cmd):

    """Featureful extension of appcommands.Cmd.

    A subclass that defines a ``RunWithArgs`` method is a "new-style"
    command: Run() checks the number of positional arguments against
    RunWithArgs' signature and then calls it via RunSafely (exceptions are
    printed) or, when --debug_mode is set, via RunDebug (unexpected
    exceptions start pdb unless --headless). A subclass without
    RunWithArgs defers to appcommands.Cmd.Run.
    """

    # Read by CommandLoop to decide whether to expose a command in the
    # interactive shell. New-style commands set it to True in __init__;
    # this class-level default gives old-style commands a value too, so
    # reading it never raises AttributeError.
    surface_in_shell = False

    def __init__(self, name, flag_values):
        super(NewCmd, self).__init__(name, flag_values)
        run_with_args = getattr(self, 'RunWithArgs', None)
        self._new_style = isinstance(run_with_args, types.MethodType)
        if self._new_style:
            func = run_with_args.__func__

            # inspect.getargspec was removed in Python 3.11. Use
            # getfullargspec where it exists (Python 3) and getargspec only
            # on Python 2; they name the **kwargs slot differently.
            if hasattr(inspect, 'getfullargspec'):
                argspec = inspect.getfullargspec(func)
                varkw = argspec.varkw
            else:
                argspec = inspect.getargspec(func)
                varkw = argspec.keywords
            # The spec comes from the unbound function, so drop 'self'
            # before counting positional parameters.
            if argspec.args and argspec.args[0] == 'self':
                argspec = argspec._replace(  # pylint: disable=protected-access
                    args=argspec.args[1:])
            self._argspec = argspec
            # TODO(craigcitro): Do we really want to support all this
            # nonsense?
            self._star_args = argspec.varargs is not None
            self._star_kwds = varkw is not None
            # Bounds enforced by Run(): parameters with defaults are
            # optional, and *args removes the upper bound.
            self._max_args = len(argspec.args or ())
            self._min_args = self._max_args - len(argspec.defaults or ())
            if self._star_args:
                self._max_args = sys.maxsize

            self._debug_mode = FLAGS.debug_mode
            self.surface_in_shell = True
            # Use RunWithArgs' docstring as this command's docstring.
            self.__doc__ = self.RunWithArgs.__doc__

    def __getattr__(self, name):
        """Fall back to this command's flag values for unknown attributes.

        Python only calls this when normal attribute lookup fails, so
        ``self.<flagname>`` yields the flag's current value unless a real
        attribute of the same name shadows it.

        Raises:
          AttributeError: if ``name`` is neither an attribute nor a flag.
        """
        # Guard against infinite recursion: if _command_flags itself is
        # missing (e.g. before appcommands.Cmd.__init__ has run, or on an
        # instance created without calling __init__), looking it up would
        # re-enter this method forever.
        if name != '_command_flags' and name in self._command_flags:
            return self._command_flags[name].value
        return super(NewCmd, self).__getattribute__(name)

    def _GetFlag(self, flagname):
        """Return this command's flag object named flagname, or None."""
        if flagname in self._command_flags:
            return self._command_flags[flagname]
        return None

    def Run(self, argv):
        """Run this command.

        If self is a new-style command, we check the positional argument
        count and call self.RunWithArgs, gracefully handling exceptions.
        If not, we defer to appcommands.Cmd.Run(argv).

        Args:
          argv: List of arguments as strings; argv[0] is the command name.

        Returns:
          0 on success, nonzero on failure.
        """
        if not self._new_style:
            return super(NewCmd, self).Run(argv)

        # TODO(craigcitro): We need to save and restore flags each time so
        # that we can support per-command flags in the REPL.
        args = argv[1:]
        fail = None
        fail_template = '%s positional args, found %d, expected at %s %d'
        if len(args) < self._min_args:
            fail = fail_template % ('Not enough', len(args),
                                    'least', self._min_args)
        if len(args) > self._max_args:
            fail = fail_template % ('Too many', len(args),
                                    'most', self._max_args)
        if fail:
            print(fail)
            if self.usage:
                print('Usage: %s' % (self.usage,))
            return 1

        if self._debug_mode:
            return self.RunDebug(args, {})
        return self.RunSafely(args, {})

    def RunCmdLoop(self, argv):
        """Hook for use in cmd.Cmd-based command shells.

        Args:
          argv: The argument portion of the shell input line, as a single
            string; it is split with shell-like quoting rules.

        Returns:
          The result of self.Run.

        Raises:
          SyntaxError: if argv cannot be split (e.g. unbalanced quotes).
        """
        try:
            args = shlex.split(argv)
        except ValueError as e:
            raise SyntaxError(self.EncodeForPrinting(e))
        return self.Run([self._command_name] + args)

    @staticmethod
    def EncodeForPrinting(s):
        """Safely encode a string as the encoding for sys.stdout.

        Characters that sys.stdout's encoding cannot represent are replaced
        by backslash escapes. The result is a native ``str``: bytes on
        Python 2 and text on Python 3, so it prints without a ``b'...'``
        wrapper.
        """
        # sys.stdout may be replaced by an object without a usable
        # encoding (or be None); fall back to ASCII in that case.
        encoding = getattr(sys.stdout, 'encoding', None) or 'ascii'
        encoded = six.text_type(s).encode(encoding, 'backslashreplace')
        if six.PY2:
            return encoded
        return encoded.decode(encoding)

    def _FormatError(self, e):
        """Hook for subclasses to modify how error messages are printed."""
        return _SafeMakeAscii(e)

    def _HandleError(self, e):
        """Print the exception e for this command and return exit code 1."""
        message = self._FormatError(e)
        print('Exception raised in %s operation: %s' % (
            self._command_name, message))
        return 1

    def _IsDebuggableException(self, e):
        """Hook for subclasses to skip debugging on certain exceptions."""
        return not isinstance(e, app.UsageError)

    def RunDebug(self, args, kwds):
        """Run this command in debug mode.

        Exceptions that _IsDebuggableException rejects are handled like in
        RunSafely. Any other exception prints a banner and traceback and,
        unless --headless is set, starts a pdb post-mortem session.

        Returns:
          RunWithArgs' return value, or 1 if it raised.
        """
        try:
            return_value = self.RunWithArgs(*args, **kwds)
        except BaseException as e:
            # Don't break into the debugger for expected exceptions.
            if not self._IsDebuggableException(e):
                return self._HandleError(e)
            print()
            print('****************************************************')
            print('**   Unexpected Exception raised in execution!    **')
            if FLAGS.headless:
                print('**  --headless mode enabled, exiting.             **')
                print('**  See STDERR for traceback.                     **')
            else:
                print('**  --debug_mode enabled, starting pdb.           **')
            print('****************************************************')
            print()
            traceback.print_exc()
            print()
            if not FLAGS.headless:
                pdb.post_mortem()
            return 1
        return return_value

    def RunSafely(self, args, kwds):
        """Run this command, turning exceptions into print statements.

        Returns:
          RunWithArgs' return value, or 1 if it raised.
        """
        try:
            return_value = self.RunWithArgs(*args, **kwds)
        except BaseException as e:
            return self._HandleError(e)
        return return_value
193
194
class CommandLoop(cmd.Cmd):

    """Instance of cmd.Cmd built to work with NewCmd.

    Every NewCmd in the supplied command table whose ``surface_in_shell``
    attribute is true becomes a ``do_<name>`` method that forwards the rest
    of the input line to the command's ``RunCmdLoop``.
    """

    class TerminateSignal(Exception):

        """Exception type used for signaling loop completion."""

    def __init__(self, commands, prompt):
        """Initialize the loop.

        Args:
          commands: Mapping of command name to command object. Must contain
            a 'help' entry.
          prompt: String displayed before each line of input.
        """
        cmd.Cmd.__init__(self)
        self._commands = {'help': commands['help']}
        # Names that never get a forwarding do_<name> method installed here.
        self._special_command_names = ['help', 'repl', 'EOF']
        for name, command in commands.items():
            # getattr with a default: a command that never set
            # surface_in_shell is simply not surfaced, instead of aborting
            # construction of the whole shell with AttributeError.
            if (name not in self._special_command_names and
                    isinstance(command, NewCmd) and
                    getattr(command, 'surface_in_shell', False)):
                self._commands[name] = command
                setattr(self, 'do_%s' % (name,), command.RunCmdLoop)
        self._default_prompt = prompt
        self._set_prompt()
        self._last_return_code = 0

    @property
    def last_return_code(self):
        """Return code recorded by the most recent onecmd() call."""
        return self._last_return_code

    def _set_prompt(self):  # pylint: disable=invalid-name
        self.prompt = self._default_prompt

    def do_EOF(self, *unused_args):  # pylint: disable=invalid-name
        """Terminate the running command loop.

        This function raises an exception to avoid the need to do
        potentially-error-prone string parsing inside onecmd.

        Args:
          *unused_args: unused.

        Returns:
          Never returns.

        Raises:
          CommandLoop.TerminateSignal: always.
        """
        raise CommandLoop.TerminateSignal()

    def postloop(self):
        """Print a farewell message when the loop exits."""
        print('Goodbye.')

    # pylint: disable=arguments-differ
    def completedefault(self, unused_text, line, unused_begidx, unused_endidx):
        """Show the current command's usage instead of completing.

        Prints the command's usage (when it has one), re-prints the prompt
        with the partial line, and never offers completions.
        """
        if not line:
            return []
        command_name = line.partition(' ')[0].lower()
        usage = ''
        if command_name in self._commands:
            # Not every command object defines a usage string.
            usage = getattr(self._commands[command_name], 'usage', '')
        if usage:
            print()
            print(usage)
            print('%s%s' % (self.prompt, line), end=' ')
        return []
    # pylint: enable=arguments-differ

    def emptyline(self):
        """List the available commands.

        Overrides cmd.Cmd's default of repeating the last command.
        """
        print('Available commands:', end=' ')
        print(' '.join(list(self._commands)))

    def precmd(self, line):
        """Preprocess the shell input.

        Lines starting with 'exit' or 'quit' become 'EOF'. A single word
        other than help/ls/version becomes 'help <word>', so typing just a
        command name shows its help.
        """
        if line == 'EOF':
            return line
        if line.startswith('exit') or line.startswith('quit'):
            return 'EOF'
        words = line.strip().split()
        if len(words) == 1 and words[0] not in ['help', 'ls', 'version']:
            return 'help %s' % (line.strip(),)
        return line

    def onecmd(self, line):
        """Process a single command.

        Runs a single command, and stores the return code in
        self._last_return_code. Always returns False unless the command
        was EOF. Exceptions raised by the command are printed and recorded
        as return code 1.

        Args:
          line: (str) Command line to process.

        Returns:
          A bool signaling whether or not the command loop should terminate.
        """
        try:
            self._last_return_code = cmd.Cmd.onecmd(self, line)
        except CommandLoop.TerminateSignal:
            return True
        except BaseException as e:
            name = line.split(' ')[0]
            print('Error running %s:' % name)
            print(e)
            self._last_return_code = 1
        return False

    def get_names(self):
        """Return the attribute names cmd.Cmd scans for commands.

        cmd.Cmd uses this list (e.g. for command-name completion). The
        result is sorted, has no duplicates, and excludes ``do_EOF``.
        """
        # __init__ installs each surfaced command as an instance attribute
        # do_<name>, so dir(self) usually lists it already; a set keeps it
        # from appearing twice.
        names = set(dir(self))
        names.update('do_%s' % (name,) for name in self._commands
                     if name not in self._special_command_names)
        names.discard('do_EOF')
        return sorted(names)

    def do_help(self, arg):
        """Print the help for command_name (if present) or general help."""

        command_name = arg

        # TODO(craigcitro): Add command-specific flags.
        def FormatOneCmd(name, command, command_names):
            """Format the help text for one command.

            With several command names, the help is rendered as a
            'name: text' entry aligned to the longest command name;
            otherwise it is rendered as an indented standalone block.
            """
            indent_size = appcommands.GetMaxCommandLength() + 3
            if len(command_names) > 1:
                indent = ' ' * indent_size
                command_help = flags.TextWrap(
                    command.CommandGetHelp('', cmd_names=command_names),
                    indent=indent,
                    firstline_indent='')
                first_help_line, _, rest = command_help.partition('\n')
                first_line = '%-*s%s' % (indent_size,
                                         name + ':', first_help_line)
                return '\n'.join((first_line, rest))
            default_indent = '  '
            return '\n' + flags.TextWrap(
                command.CommandGetHelp('', cmd_names=command_names),
                indent=default_indent,
                firstline_indent=default_indent) + '\n'

        if not command_name:
            print('\nHelp for commands:\n')
            command_names = list(self._commands)
            print('\n\n'.join(
                FormatOneCmd(name, command, command_names)
                for name, command in self._commands.items()
                if name not in self._special_command_names))
            print()
        elif command_name in self._commands:
            print(FormatOneCmd(command_name, self._commands[command_name],
                               command_names=[command_name]))
        return 0

    def postcmd(self, stop, line):
        """Stop the loop when onecmd asked to, or the line was 'EOF'."""
        return bool(stop) or line == 'EOF'
347
348
class Repl(NewCmd):

    """Start an interactive session."""
    # Prompt used when --prompt is empty.
    PROMPT = '> '

    def __init__(self, name, fv):
        super(Repl, self).__init__(name, fv)
        # The interactive shell never lists itself as a shell command.
        self.surface_in_shell = False
        # --prompt is registered on this command's own flag values.
        flags.DEFINE_string(
            'prompt', '',
            'Prompt to use for interactive shell.',
            flag_values=fv)

    def RunWithArgs(self):
        """Start an interactive session."""
        # An empty --prompt falls back to the class-level default.
        chosen_prompt = FLAGS.prompt or self.PROMPT
        shell = CommandLoop(appcommands.GetCommandList(),
                            prompt=chosen_prompt)
        print('Welcome! (Type help for more information.)')
        done = False
        while not done:
            try:
                shell.cmdloop()
            except KeyboardInterrupt:
                # Ctrl-C: move to a fresh line and re-enter the loop.
                print()
            else:
                done = True
        return shell.last_return_code
374