1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Serializes an Environment into a shell file."""
15
16import inspect
17
18# Disable super() warnings since this file must be Python 2 compatible.
19# pylint: disable=super-with-arguments
20
21
class _BaseShellVisitor(object):  # pylint: disable=useless-object-inheritance
    """Shared pieces for visitors that write shell code to a stream.

    ``visit_hash`` writes to ``self._outs``, which starts out as ``None``;
    subclasses point it at a writable text stream while serializing.
    """
    def __init__(self, *args, **kwargs):
        """Initializes the visitor.

        Args:
          *args: Forwarded unchanged to the next initializer in the MRO.
          **kwargs: Forwarded likewise, except for ``pathsep``, the separator
            between entries of path-like variables (defaults to ':').
        """
        pathsep = kwargs.pop('pathsep', ':')
        super(_BaseShellVisitor, self).__init__(*args, **kwargs)
        self._pathsep = pathsep
        self._outs = None

    def _remove_value_from_path(self, variable, value):
        """Returns shell code that strips ``value`` out of ``$variable``.

        The generated code pipes the variable through a chain of sed
        substitutions, one for each position the entry can occupy (middle,
        start, end, or the only entry), and then re-exports the variable.

        ``value`` is spliced into the sed expressions unescaped, so sed
        interprets it as a pattern and it must not contain '|', which is
        used as the expression delimiter.

        Args:
          variable: Name of the shell variable to edit.
          value: Entry to remove from the variable.

        Returns:
          A newline-terminated string of shell code.
        """
        return ('{variable}="$(echo "${variable}"'
                ' | sed "s|{pathsep}{value}{pathsep}|{pathsep}|g;"'
                ' | sed "s|^{value}{pathsep}||g;"'
                ' | sed "s|{pathsep}{value}$||g;"'
                # Without this stage a variable holding nothing but the value
                # is left untouched, because no separator surrounds the entry.
                ' | sed "s|^{value}$||g;"'
                ')"\nexport {variable}\n'.format(variable=variable,
                                                 value=value,
                                                 pathsep=self._pathsep))

    def visit_hash(self, hash):  # pylint: disable=redefined-builtin
        """Writes a snippet that runs ``hash -r`` under bash or zsh.

        Args:
          hash: The action being visited; it carries nothing that is used.
        """
        del hash
        # cleandoc() strips the trailing newline, so add one back: otherwise
        # whatever is written next would be glued onto the closing 'fi'.
        self._outs.write(
            inspect.cleandoc('''
        # This should detect bash and zsh, which have a hash command that must
        # be called to get it to forget past commands. Without forgetting past
        # commands the $PATH changes we made may not be respected.
        if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
            hash -r
        fi
        ''') + '\n')
49
50
class ShellVisitor(_BaseShellVisitor):
    """Serializes an Environment into a shell file.

    Each ``visit_*`` method writes the shell code for one kind of action to
    the stream that was handed to ``serialize()``.
    """
    def __init__(self, *args, **kwargs):
        super(ShellVisitor, self).__init__(*args, **kwargs)
        # (variable, value) pairs; only populated while serialize() runs.
        self._replacements = ()

    def serialize(self, env, outs):
        """Writes ``env`` to ``outs`` as shell code.

        Args:
          env: Object providing ``replacements`` (an iterable of
            (name, value) pairs), ``get(name)`` and ``accept(visitor)``.
          outs: Writable text stream that receives the generated code.
        """
        try:
            # A replacement registered without a value falls back to whatever
            # env reports for that variable.
            self._replacements = tuple(
                (key, env.get(key) if value is None else value)
                for key, value in env.replacements)
            self._outs = outs

            env.accept(self)

        finally:
            # Drop per-call state, including the stream, even on failure.
            self._replacements = ()
            self._outs = None

    def _apply_replacements(self, action):
        """Returns ``action.value`` with known values swapped for ``$VAR``.

        Replacements belonging to the variable the action itself targets are
        skipped, so the result does not refer to the variable being written.

        Args:
          action: Object with ``name`` and ``value`` attributes.

        Returns:
          The action's value with each replacement value substituted by a
          reference to its variable.
        """
        value = action.value
        for var, replacement in self._replacements:
            if var == action.name:
                continue
            # An empty or missing replacement has nothing to match:
            # str.replace('', ...) would splice the reference between every
            # character, and str.replace(None, ...) raises TypeError.
            if not replacement:
                continue
            value = value.replace(replacement, '${}'.format(var))
        return value

    def visit_set(self, set):  # pylint: disable=redefined-builtin
        """Writes an assignment plus export for the action's variable."""
        value = self._apply_replacements(set)
        self._outs.write('{name}="{value}"\nexport {name}\n'.format(
            name=set.name, value=value))

    def visit_clear(self, clear):
        """Writes an ``unset`` for the action's variable."""
        self._outs.write('unset {name}\n'.format(name=clear.name))

    def visit_remove(self, remove):
        """Writes code removing the action's value from its variable."""
        value = self._apply_replacements(remove)
        # The template has to be formatted, otherwise the literal
        # placeholders end up in the generated file.
        self._outs.write('# Remove \n#   {value}\n# from\n#   {name}\n# '
                         'before adding it back.\n'.format(
                             value=value, name=remove.name))
        self._outs.write(self._remove_value_from_path(remove.name, value))

    def _join(self, *args):
        """Joins path entries with the configured separator.

        Accepts either the entries as separate arguments or a single list or
        tuple containing them.
        """
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            args = args[0]
        return self._pathsep.join(args)

    def visit_prepend(self, prepend):
        """Writes code putting the action's value in front of its variable."""
        value = self._apply_replacements(prepend)
        value = self._join(value, '${}'.format(prepend.name))
        self._outs.write('{name}="{value}"\nexport {name}\n'.format(
            name=prepend.name, value=value))

    def visit_append(self, append):
        """Writes code adding the action's value to the end of its variable."""
        value = self._apply_replacements(append)
        value = self._join('${}'.format(append.name), value)
        self._outs.write('{name}="{value}"\nexport {name}\n'.format(
            name=append.name, value=value))

    def visit_echo(self, echo):
        """Writes an ``echo`` that is skipped when PW_ENVSETUP_QUIET is set.

        The message is placed inside double quotes without any escaping.
        """
        # TODO(mohrr) use shlex.quote().
        self._outs.write('if [ -z "${PW_ENVSETUP_QUIET:-}" ]; then\n')
        if echo.newline:
            self._outs.write('  echo "{}"\n'.format(echo.value))
        else:
            self._outs.write('  echo -n "{}"\n'.format(echo.value))
        self._outs.write('fi\n')

    def visit_comment(self, comment):
        """Writes every line of the comment prefixed with '# '."""
        for line in comment.value.splitlines():
            self._outs.write('# {}\n'.format(line))

    def visit_command(self, command):
        """Writes the command and, if requested, a failure check after it.

        The command's words are joined with spaces and are not quoted.
        """
        # TODO(mohrr) use shlex.quote here?
        self._outs.write('{}\n'.format(' '.join(command.command)))
        if not command.exit_on_error:
            return

        # Assume failing command produced relevant output.
        self._outs.write('if [ "$?" -ne 0 ]; then\n  return 1\nfi\n')

    def visit_doctor(self, doctor):
        """Writes the doctor command guarded by PW_ACTIVATE_SKIP_CHECKS."""
        # Use the ':-' default like the other generated tests so the check
        # also works when the variable is unset and the shell runs with
        # 'set -u'.
        self._outs.write('if [ -z "${PW_ACTIVATE_SKIP_CHECKS:-}" ]; then\n')
        self.visit_command(doctor)
        self._outs.write('else\n')
        self._outs.write('echo Skipping environment check because '
                         'PW_ACTIVATE_SKIP_CHECKS is set\n')
        self._outs.write('fi\n')

    def visit_blank_line(self, blank_line):
        """Writes an empty line."""
        del blank_line
        self._outs.write('\n')

    def visit_function(self, function):
        """Writes a shell function definition with the given name and body."""
        self._outs.write('{name}() {{\n{body}\n}}\n'.format(
            name=function.name, body=function.body))
145
146
class DeactivateShellVisitor(_BaseShellVisitor):
    """Writes shell code that takes an Environment's values back out.

    Variables that were set are unset, and entries that were prepended or
    appended are stripped from their variables. The remaining visit methods
    defined here produce no output.
    """
    def __init__(self, *args, **kwargs):
        # The base initializer consumes 'pathsep' (default ':') and stores it
        # as self._pathsep, so everything is simply passed along.
        super(DeactivateShellVisitor, self).__init__(*args, **kwargs)

    def serialize(self, env, outs):
        """Walks ``env`` and writes the deactivation code to ``outs``."""
        self._outs = outs
        try:
            env.accept(self)
        finally:
            # Let go of the stream whether or not the walk succeeded.
            self._outs = None

    def visit_set(self, set):  # pylint: disable=redefined-builtin
        """Unsets the variable the action assigned."""
        self._outs.write('unset {}\n'.format(set.name))

    def visit_clear(self, clear):
        """Not relevant; writes nothing."""

    def visit_remove(self, remove):
        """Not relevant; writes nothing."""

    def visit_prepend(self, prepend):
        """Strips the prepended value back out of its variable."""
        undo = self._remove_value_from_path(prepend.name, prepend.value)
        self._outs.write(undo)

    def visit_append(self, append):
        """Strips the appended value back out of its variable."""
        undo = self._remove_value_from_path(append.name, append.value)
        self._outs.write(undo)

    def visit_echo(self, echo):
        """Not relevant; writes nothing."""

    def visit_comment(self, comment):
        """Not relevant; writes nothing."""

    def visit_command(self, command):
        """Not relevant; writes nothing."""

    def visit_doctor(self, doctor):
        """Not relevant; writes nothing."""

    def visit_blank_line(self, blank_line):
        """Not relevant; writes nothing."""

    def visit_function(self, function):
        """Not relevant; writes nothing."""
197