# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
14"""Stores the environment changes necessary for Pigweed."""
15
16import contextlib
17import os
18import re
19
# The order here is important. On Python 2 we want StringIO.StringIO and not
# io.StringIO. On Python 3 there is no StringIO module so we want io.StringIO.
# Not using six because six is not a standard package we can expect to have
# installed in the system Python.
24try:
25    from StringIO import StringIO  # type: ignore
26except ImportError:
27    from io import StringIO
28
29from . import apply_visitor
30from . import batch_visitor
31from . import json_visitor
32from . import shell_visitor
33
# Disable super() warnings since this file must be Python 2 compatible.
# pylint: disable=super-with-arguments
36
37
class BadNameType(TypeError):
    """Raised when a variable name is not a string."""
40
41
class BadValueType(TypeError):
    """Raised when a variable value is not a string."""
44
45
class EmptyValue(ValueError):
    """Raised when a variable value is the empty string and that is not allowed."""
48
49
class NewlineInValue(TypeError):
    """Raised when a variable value contains a newline character."""
52
53
class BadVariableName(ValueError):
    """Raised when a variable name is not a valid identifier-like string."""
56
57
class UnexpectedAction(ValueError):
    """Error type for an unexpected action.

    Not raised anywhere in this module; presumably used by the visitor
    modules -- confirm against callers.
    """
60
61
class AcceptNotOverridden(TypeError):
    """Raised by the base action's accept() when a subclass did not override it."""
64
65
66class _Action(object):  # pylint: disable=useless-object-inheritance
67    def unapply(self, env, orig_env):
68        pass
69
70    def accept(self, visitor):
71        del visitor
72        raise AcceptNotOverridden('accept() not overridden for {}'.format(
73            self.__class__.__name__))
74
75    def write_deactivate(self,
76                         outs,
77                         windows=(os.name == 'nt'),
78                         replacements=()):
79        pass
80
81
class _VariableAction(_Action):
    """Base class for actions that target a single named variable.

    The name and value are validated on construction (see _check()).
    """

    # pylint: disable=keyword-arg-before-vararg
    def __init__(self, name, value, allow_empty_values=False, *args, **kwargs):
        """Store and validate the variable name and value.

        Args:
          name: variable name; a string of letters, digits and underscores
            that does not start with a digit.
          value: variable value; a string that contains no newline.
          allow_empty_values: if true, the empty string is accepted as value.
        """
        super(_VariableAction, self).__init__(*args, **kwargs)
        self.name = name
        self.value = value
        self.allow_empty_values = allow_empty_values

        self._check()

    def _check(self):
        """Validate self.name and self.value.

        Raises:
          BadNameType: the name is not a string.
          BadValueType: the value is not a string.
          EmptyValue: the value is '' and allow_empty_values is false.
          NewlineInValue: the value contains a newline.
          BadVariableName: the name is not a valid variable name.
        """
        try:
            # In python2, unicode is a distinct type.
            valid_types = (str, unicode)
        except NameError:
            valid_types = (str, )

        if not isinstance(self.name, valid_types):
            raise BadNameType('variable name {!r} not of type str'.format(
                self.name))
        if not isinstance(self.value, valid_types):
            raise BadValueType('{!r} value {!r} not of type str'.format(
                self.name, self.value))

        # Empty strings as environment variable values have different behavior
        # on different operating systems. Just don't allow them.
        if not self.allow_empty_values and self.value == '':
            raise EmptyValue('{!r} value {!r} is the empty string'.format(
                self.name, self.value))

        # Many tools have issues with newlines in environment variable values.
        # Just don't allow them.
        if '\n' in self.value:
            raise NewlineInValue('{!r} value {!r} contains a newline'.format(
                self.name, self.value))

        # Anchor with \Z rather than $: $ also matches just before a trailing
        # newline, which would let a name such as 'FOO\n' through.
        if not re.match(r'^[A-Z_][A-Z0-9_]*\Z', self.name, re.IGNORECASE):
            raise BadVariableName('bad variable name {!r}'.format(self.name))

    def unapply(self, env, orig_env):
        """Restore this variable in env to its state in orig_env.

        If orig_env has the variable its value is copied back; otherwise the
        variable is removed from env (if present).
        """
        if self.name in orig_env:
            env[self.name] = orig_env[self.name]
        else:
            env.pop(self.name, None)
126
127
class Set(_VariableAction):
    """Set a variable."""

    def accept(self, visitor):
        """Hand this action to visitor.visit_set()."""
        visitor.visit_set(self)
132
133
class Clear(_VariableAction):
    """Remove a variable from the environment."""

    def __init__(self, *args, **kwargs):
        """Build the action with an empty value, which is allowed here."""
        # A cleared variable carries no value, so force value='' and permit it.
        kwargs.update(value='', allow_empty_values=True)
        super(Clear, self).__init__(*args, **kwargs)

    def accept(self, visitor):
        """Hand this action to visitor.visit_clear()."""
        visitor.visit_clear(self)
143
144
class Remove(_VariableAction):
    """Remove a value from a PATH-like variable."""

    def accept(self, visitor):
        """Hand this action to visitor.visit_remove()."""
        visitor.visit_remove(self)
149
150
class BadVariableValue(ValueError):
    """Raised when a value to append or prepend contains '='."""
153
154
155def _append_prepend_check(action):
156    if '=' in action.value:
157        raise BadVariableValue('"{}" contains "="'.format(action.value))
158
159
class Prepend(_VariableAction):
    """Prepend a value to a PATH-like variable."""

    def __init__(self, name, value, join, *args, **kwargs):
        """Validate name/value via the base class, then keep the join object."""
        super(Prepend, self).__init__(name, value, *args, **kwargs)
        self._join = join

    def _check(self):
        """Run the base validation plus the append/prepend '=' check."""
        super(Prepend, self)._check()
        _append_prepend_check(self)

    def accept(self, visitor):
        """Hand this action to visitor.visit_prepend()."""
        visitor.visit_prepend(self)
172
173
class Append(_VariableAction):
    """Append a value to a PATH-like variable. (Uncommon, see Prepend.)"""

    def __init__(self, name, value, join, *args, **kwargs):
        """Validate name/value via the base class, then keep the join object."""
        super(Append, self).__init__(name, value, *args, **kwargs)
        self._join = join

    def _check(self):
        """Run the base validation plus the append/prepend '=' check."""
        super(Append, self)._check()
        _append_prepend_check(self)

    def accept(self, visitor):
        """Hand this action to visitor.visit_append()."""
        visitor.visit_append(self)
186
187
class BadEchoValue(ValueError):
    """Raised when an echo value is 'on' or 'off' (in any letter case)."""
190
191
class Echo(_Action):
    """Echo a value to the terminal."""

    def __init__(self, value, newline, *args, **kwargs):
        """Record the text to echo and whether a newline follows it.

        Raises:
          BadEchoValue: if value is 'on' or 'off', ignoring case.
        """
        # These values act funny on Windows.
        lowered = value.lower()
        if lowered == 'off' or lowered == 'on':
            raise BadEchoValue(value)
        super(Echo, self).__init__(*args, **kwargs)
        self.value = value
        self.newline = newline

    def accept(self, visitor):
        """Hand this action to visitor.visit_echo()."""
        visitor.visit_echo(self)
204
205
class Comment(_Action):
    """Add a comment to the init script."""

    def __init__(self, value, *args, **kwargs):
        """Record the comment text."""
        super(Comment, self).__init__(*args, **kwargs)
        self.value = value

    def accept(self, visitor):
        """Hand this action to visitor.visit_comment()."""
        visitor.visit_comment(self)
214
215
class Command(_Action):
    """Run a command."""

    def __init__(self, command, *args, **kwargs):
        """Record the command and its error policy.

        Args:
          command: the command as a list or tuple.
          exit_on_error: keyword-only option, defaults to True.
        """
        # Pull exit_on_error out of kwargs before forwarding the rest.
        stop_on_failure = kwargs.pop('exit_on_error', True)
        super(Command, self).__init__(*args, **kwargs)
        assert isinstance(command, (list, tuple))
        self.command = command
        self.exit_on_error = stop_on_failure

    def accept(self, visitor):
        """Hand this action to visitor.visit_command()."""
        visitor.visit_command(self)
227
228
class Doctor(Command):
    """Command action that runs 'pw doctor'."""

    def __init__(self, *args, **kwargs):
        """Build the 'pw doctor' command line.

        The log level is 'warn' when PW_ENVSETUP_QUIET is present in
        os.environ and 'info' otherwise.
        """
        if 'PW_ENVSETUP_QUIET' in os.environ:
            log_level = 'warn'
        else:
            log_level = 'info'
        pw_doctor = ['pw', '--no-banner', '--loglevel', log_level, 'doctor']
        super(Doctor, self).__init__(command=pw_doctor, *args, **kwargs)

    def accept(self, visitor):
        """Hand this action to visitor.visit_doctor()."""
        visitor.visit_doctor(self)
239
240
class BlankLine(_Action):
    """Write a blank line to the init script."""

    def accept(self, visitor):
        """Hand this action to visitor.visit_blank_line()."""
        visitor.visit_blank_line(self)
245
246
class Function(_Action):
    """Action holding a function name and its body."""

    def __init__(self, name, body, *args, **kwargs):
        """Record the function's name and body."""
        super(Function, self).__init__(*args, **kwargs)
        self.name = name
        self.body = body

    def accept(self, visitor):
        """Hand this action to visitor.visit_function()."""
        visitor.visit_function(self)
255
256
class Hash(_Action):
    """Action dispatched to visitor.visit_hash()."""

    def accept(self, visitor):
        """Hand this action to visitor.visit_hash()."""
        visitor.visit_hash(self)
260
261
class Join(object):  # pylint: disable=useless-object-inheritance
    """Holds the path separator used when joining PATH-like values."""

    def __init__(self, pathsep=os.pathsep):
        """Store pathsep; defaults to os.pathsep."""
        self.pathsep = pathsep
265
266
# TODO(mohrr) remove disable=useless-object-inheritance once in Python 3.
# pylint: disable=useless-object-inheritance
class Environment(object):
    """Stores the environment changes necessary for Pigweed.

    These changes can be accessed by writing them to a file for bash-like
    shells to source or by using this as a context manager.
    """
    def __init__(self, *args, **kwargs):
        """Create an empty environment.

        Keyword Args:
          pathsep: separator for PATH-like variables (default os.pathsep).
          windows: whether to target Windows (default os.name == 'nt').
          allcaps: whether variable names are upper-cased (default: the value
            of windows).
        """
        pathsep = kwargs.pop('pathsep', os.pathsep)
        windows = kwargs.pop('windows', os.name == 'nt')
        allcaps = kwargs.pop('allcaps', windows)
        super(Environment, self).__init__(*args, **kwargs)
        self._actions = []
        self._pathsep = pathsep
        self._windows = windows
        self._allcaps = allcaps
        self.replacements = []
        self._join = Join(pathsep)
        self._finalized = False

    def add_replacement(self, variable, value=None):
        """Record a (variable, value) pair in self.replacements."""
        self.replacements.append((variable, value))

    def normalize_key(self, name):
        """Return name upper-cased when allcaps is set, else unchanged.

        Non-string names are returned as-is so the action classes can report
        the type error themselves.
        """
        if self._allcaps:
            try:
                return name.upper()
            except AttributeError:
                # The _Action class has code to handle incorrect types, so
                # we just ignore this error here.
                pass
        return name

    # A newline is printed after each high-level operation. Top-level
    # operations should not invoke each other (this is why _remove() exists).

    def set(self, name, value):
        """Set a variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        self._actions.append(Set(name, value))
        self._blankline()

    def clear(self, name):
        """Remove a variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        self._actions.append(Clear(name))
        self._blankline()

    def _remove(self, name, value):
        """Remove a value from a variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            # NOTE(review): Remove defines no __init__ of its own, so the
            # third positional argument lands in _VariableAction's
            # allow_empty_values parameter. Left as-is to preserve behavior;
            # confirm whether passing the path separator here is intended.
            self._actions.append(Remove(name, value, self._pathsep))

    def remove(self, name, value):
        """Remove a value from a PATH-like variable."""
        assert not self._finalized
        self._remove(name, value)
        self._blankline()

    def append(self, name, value):
        """Add a value to a PATH-like variable. Rarely used, see prepend()."""
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            # Drop any existing occurrence first, then append.
            self._remove(name, value)
            self._actions.append(Append(name, value, self._join))
        else:
            # Variable is unset or empty: a plain Set is enough.
            self._actions.append(Set(name, value))
        self._blankline()

    def prepend(self, name, value):
        """Add a value to the beginning of a PATH-like variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            # Drop any existing occurrence first, then prepend.
            self._remove(name, value)
            self._actions.append(Prepend(name, value, self._join))
        else:
            # Variable is unset or empty: a plain Set is enough.
            self._actions.append(Set(name, value))
        self._blankline()

    def echo(self, value='', newline=True):
        """Echo a value to the terminal."""
        # echo() deliberately ignores self._finalized.
        self._actions.append(Echo(value, newline))
        if value:
            self._blankline()

    def comment(self, comment):
        """Add a comment to the init script."""
        # comment() deliberately ignores self._finalized.
        self._actions.append(Comment(comment))
        self._blankline()

    def command(self, command, exit_on_error=True):
        """Run a command."""
        # command() deliberately ignores self._finalized.
        self._actions.append(Command(command, exit_on_error=exit_on_error))
        self._blankline()

    def doctor(self):
        """Run 'pw doctor'."""
        self._actions.append(Doctor())

    def function(self, name, body):
        """Define a function."""
        assert not self._finalized
        # Record a Function action; a Command takes a command list, not a
        # (name, body) pair, and cannot be constructed from these arguments.
        self._actions.append(Function(name, body))
        self._blankline()

    def _blankline(self):
        """Append a BlankLine action."""
        self._actions.append(BlankLine())

    def finalize(self):
        """Run cleanup at the end of environment setup."""
        assert not self._finalized
        self._finalized = True
        self._actions.append(Hash())
        self._blankline()

        if not self._windows:
            # Capture the deactivate script text and record it as the body
            # of a '_pw_deactivate' function.
            buf = StringIO()
            self.write_deactivate(buf)
            self._actions.append(Function('_pw_deactivate', buf.getvalue()))
            self._blankline()

    def accept(self, visitor):
        """Pass every recorded action, in order, to the visitor."""
        for action in self._actions:
            action.accept(visitor)

    def json(self, outs):
        """Serialize this environment to outs with a JSONVisitor."""
        json_visitor.JSONVisitor().serialize(self, outs)

    def write(self, outs):
        """Serialize this environment to outs as a script.

        Uses a BatchVisitor when targeting Windows and a ShellVisitor
        otherwise.
        """
        if self._windows:
            visitor = batch_visitor.BatchVisitor(pathsep=self._pathsep)
        else:
            visitor = shell_visitor.ShellVisitor(pathsep=self._pathsep)
        visitor.serialize(self, outs)

    def write_deactivate(self, outs):
        """Serialize the deactivate script to outs; does nothing on Windows."""
        if self._windows:
            return
        visitor = shell_visitor.DeactivateShellVisitor(pathsep=self._pathsep)
        visitor.serialize(self, outs)

    @contextlib.contextmanager
    def __call__(self, export=True):
        """Set environment as if this was written to a file and sourced.

        Within this context os.environ is updated with the environment
        defined by this object. If export is False, os.environ is not updated,
        but in both cases the updated environment is yielded.

        On exit, previous environment is restored. See contextlib documentation
        for details on how this function is structured.

        Args:
          export(bool): modify the environment of the running process (and
            thus, its subprocesses)

        Yields the new environment object.
        """
        try:
            if export:
                orig_env = os.environ.copy()
                env = os.environ
            else:
                env = os.environ.copy()

            apply = apply_visitor.ApplyVisitor(pathsep=self._pathsep)
            apply.apply(self, env)

            yield env

        finally:
            if export:
                # Reset or delete every variable currently set...
                for key in set(os.environ):
                    try:
                        os.environ[key] = orig_env[key]
                    except KeyError:
                        del os.environ[key]
                # ...then bring back variables that were removed meanwhile.
                for key in set(orig_env) - set(os.environ):
                    os.environ[key] = orig_env[key]

    def get(self, key, default=None):
        """Get the value of a variable within context of this object."""
        key = self.normalize_key(key)
        with self(export=False) as env:
            return env.get(key, default)

    def __getitem__(self, key):
        """Get the value of a variable within context of this object."""
        key = self.normalize_key(key)
        with self(export=False) as env:
            return env[key]
468