1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Serializes an Environment into a batch file."""
15
16# Disable super() warnings since this file must be Python 2 compatible.
17# pylint: disable=super-with-arguments
18
# goto label written at the end of every generated batch file. Commands that
# must abort the script on failure jump here, skipping everything after them.
_SCRIPT_END_LABEL = '_pw_end'


class BatchVisitor(object):  # pylint: disable=useless-object-inheritance
    """Serializes an Environment into a Windows batch file.

    serialize() calls env.accept(self); the environment is expected to call
    back into the visit_* methods, each of which appends batch commands to the
    output stream.
    """
    def __init__(self, *args, **kwargs):
        """Creates a visitor.

        Args:
            pathsep: (keyword) separator used to join entries for prepend and
                append actions. Defaults to ':'. NOTE(review): Windows uses
                ';' -- presumably callers pass it explicitly; confirm.
            *args, **kwargs: forwarded to the base class initializer.
        """
        pathsep = kwargs.pop('pathsep', ':')
        super(BatchVisitor, self).__init__(*args, **kwargs)
        # (variable name, literal value) pairs; occurrences of the value are
        # rewritten as %name% references. Only populated during serialize().
        self._replacements = ()
        # Destination stream; only set while serialize() is running.
        self._outs = None
        self._pathsep = pathsep

    def serialize(self, env, outs):
        """Writes env to outs as a batch script.

        Args:
            env: environment to serialize. Must provide `replacements`
                (iterable of (name, value-or-None) pairs), `get(name)`, and
                `accept(visitor)`.
            outs: writable text stream that receives the script.
        """
        try:
            # A replacement value of None means "use whatever value the
            # environment currently holds for that variable".
            self._replacements = tuple(
                (key, env.get(key) if value is None else value)
                for key, value in env.replacements)
            self._outs = outs
            self._outs.write('@echo off\n')

            env.accept(self)

            self._outs.write(':{}\n'.format(_SCRIPT_END_LABEL))

        finally:
            # Drop per-call state so the visitor can be reused.
            self._replacements = ()
            self._outs = None

    def _apply_replacements(self, action):
        """Returns action.value with known values turned into %VAR% refs.

        A variable's own value is never substituted into its own definition.
        Empty or None replacement values are skipped: str.replace('') would
        insert the reference between every character of the value, and None
        is not a valid argument to str.replace.
        """
        value = action.value
        for var, replacement in self._replacements:
            if var == action.name or not replacement:
                continue
            value = value.replace(replacement, '%{}%'.format(var))
        return value

    def visit_set(self, set):  # pylint: disable=redefined-builtin
        """Emits `set NAME=VALUE`.

        NOTE(review): the value is written unquoted and unescaped, so batch
        metacharacters (&, |, <, >) in it would be interpreted by cmd.exe.
        """
        value = self._apply_replacements(set)
        self._outs.write('set {name}={value}\n'.format(name=set.name,
                                                       value=value))

    def visit_clear(self, clear):
        """Emits `set NAME=`, which removes NAME from the environment."""
        self._outs.write('set {name}=\n'.format(name=clear.name))

    def visit_remove(self, remove):
        """Ignored: removing an entry from a path list is not supported."""
        del remove  # Not supported on Windows.

    def _join(self, *args):
        """Joins args (or a single list/tuple argument) with the pathsep."""
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            args = args[0]
        return self._pathsep.join(args)

    def visit_prepend(self, prepend):
        """Emits `set NAME=VALUE<sep>%NAME%`."""
        value = self._apply_replacements(prepend)
        value = self._join(value, '%{}%'.format(prepend.name))
        self._outs.write('set {name}={value}\n'.format(name=prepend.name,
                                                       value=value))

    def visit_append(self, append):
        """Emits `set NAME=%NAME%<sep>VALUE`."""
        value = self._apply_replacements(append)
        value = self._join('%{}%'.format(append.name), value)
        self._outs.write('set {name}={value}\n'.format(name=append.name,
                                                       value=value))

    def visit_echo(self, echo):
        """Prints echo.value, with or without a trailing newline."""
        if echo.newline:
            if not echo.value:
                # A bare `echo` prints "ECHO is off."; `echo.` prints an
                # empty line.
                self._outs.write('echo.\n')
            else:
                self._outs.write('echo {}\n'.format(echo.value))
        else:
            # `set /p` prints its prompt without a newline; reading from nul
            # keeps it from waiting for input.
            self._outs.write('<nul set /p="{}"\n'.format(echo.value))

    def visit_comment(self, comment):
        """Emits each line of comment.value as a `::` comment."""
        for line in comment.value.splitlines():
            self._outs.write(':: {}\n'.format(line))

    def visit_command(self, command):
        """Emits command and, if requested, an exit-on-failure check."""
        self._write_command(command, in_block=False)

    def _write_command(self, command, in_block):
        """Writes command, followed by a failure check if exit_on_error.

        Args:
            command: action with `command` (list of argument strings) and
                `exit_on_error` attributes.
            in_block: True when the output lands inside a parenthesized
                ( ... ) block. cmd.exe expands %VAR% for a whole block when
                the block is parsed, before any command in it runs, so
                %ERRORLEVEL% there would report the status from before the
                block. `if errorlevel N` is evaluated at run time instead.
        """
        # TODO(mohrr) use shlex.quote here?
        self._outs.write('{}\n'.format(' '.join(command.command)))
        if not command.exit_on_error:
            return

        # Assume failing command produced relevant output.
        if in_block:
            # `if errorlevel N` is true when ERRORLEVEL >= N, so two checks
            # are needed to catch both positive and negative exit codes.
            self._outs.write(
                'if errorlevel 1 goto {}\n'.format(_SCRIPT_END_LABEL))
            self._outs.write(
                'if not errorlevel 0 goto {}\n'.format(_SCRIPT_END_LABEL))
        else:
            self._outs.write(
                'if %ERRORLEVEL% neq 0 goto {}\n'.format(_SCRIPT_END_LABEL))

    def visit_doctor(self, doctor):
        """Runs the doctor command unless PW_ACTIVATE_SKIP_CHECKS is set."""
        self._outs.write('if "%PW_ACTIVATE_SKIP_CHECKS%"=="" (\n')
        # The command is inside the ( ... ) block, so its failure check must
        # not rely on parse-time %ERRORLEVEL% expansion.
        self._write_command(doctor, in_block=True)
        self._outs.write(') else (\n')
        self._outs.write('echo Skipping environment check because '
                         'PW_ACTIVATE_SKIP_CHECKS is set\n')
        self._outs.write(')\n')

    def visit_blank_line(self, blank_line):
        """Emits an empty line."""
        del blank_line
        self._outs.write('\n')

    def visit_function(self, function):
        """Ignored: shell functions are not supported in batch output."""
        del function  # Not supported on Windows.

    def visit_hash(self, hash):  # pylint: disable=redefined-builtin
        """Ignored: command-hash resets have no batch equivalent."""
        del hash  # Not relevant on Windows.
123