1"""Shim module exporting the same ElementTree API for lxml and
2xml.etree backends.
3
4When lxml is installed, it is automatically preferred over the built-in
5xml.etree module.
6On Python 2.7, the cElementTree module is preferred over the pure-python
7ElementTree module.
8
9Besides exporting a unified interface, this also defines extra functions
10or subclasses built-in ElementTree classes to add features that are
only available in lxml, like OrderedDict for attributes, pretty_print and
12iterwalk.
13"""
14from fontTools.misc.py23 import unicode, tostr
15
16
# Template for the XML declaration written by the fallback ElementTree.write;
# the "%s" placeholder receives the upper-cased encoding name.
XML_DECLARATION = """<?xml version='1.0' encoding='%s'?>"""

# Public names re-exported from whichever backend is selected below.
# NOTE(review): iterwalk is not listed, so `from ... import *` does not
# export it even though the fallback defines it.
__all__ = [
    # public symbols
    "Comment",
    "dump",
    "Element",
    "ElementTree",
    "fromstring",
    "fromstringlist",
    "iselement",
    "iterparse",
    "parse",
    "ParseError",
    "PI",
    "ProcessingInstruction",
    "QName",
    "SubElement",
    "tostring",
    "tostringlist",
    "TreeBuilder",
    "XML",
    "XMLParser",
    "register_namespace",
]
42
try:
    # Prefer lxml whenever it is importable.
    from lxml.etree import *

    _have_lxml = True
except ImportError:
    # Fallback: use the standard library's ElementTree. The rest of this
    # except-block defines replacements for the lxml-only features this
    # module offers (ordered attributes, pretty_print, doctype, iterwalk).
    try:
        from xml.etree.cElementTree import *

        # the cElementTree version of XML function doesn't support
        # the optional 'parser' keyword argument
        from xml.etree.ElementTree import XML
    except ImportError:  # pragma: no cover
        # cElementTree unavailable: use the pure-python module directly.
        from xml.etree.ElementTree import *
    _have_lxml = False
57
58    import sys
59
60    # dict is always ordered in python >= 3.6 and on pypy
61    PY36 = sys.version_info >= (3, 6)
62    try:
63        import __pypy__
64    except ImportError:
65        __pypy__ = None
66    _dict_is_ordered = bool(PY36 or __pypy__)
67    del PY36, __pypy__
68
69    if _dict_is_ordered:
70        _Attrib = dict
71    else:
72        from collections import OrderedDict as _Attrib
73
74    if isinstance(Element, type):
75        _Element = Element
76    else:
77        # in py27, cElementTree.Element cannot be subclassed, so
78        # we need to import the pure-python class
79        from xml.etree.ElementTree import Element as _Element
80
81    class Element(_Element):
82        """Element subclass that keeps the order of attributes."""
83
84        def __init__(self, tag, attrib=_Attrib(), **extra):
85            super(Element, self).__init__(tag)
86            self.attrib = _Attrib()
87            if attrib:
88                self.attrib.update(attrib)
89            if extra:
90                self.attrib.update(extra)
91
92    def SubElement(parent, tag, attrib=_Attrib(), **extra):
93        """Must override SubElement as well otherwise _elementtree.SubElement
94        fails if 'parent' is a subclass of Element object.
95        """
96        element = parent.__class__(tag, attrib, **extra)
97        parent.append(element)
98        return element
99
100    def _iterwalk(element, events, tag):
101        include = tag is None or element.tag == tag
102        if include and "start" in events:
103            yield ("start", element)
104        for e in element:
105            for item in _iterwalk(e, events, tag):
106                yield item
107        if include:
108            yield ("end", element)
109
110    def iterwalk(element_or_tree, events=("end",), tag=None):
111        """A tree walker that generates events from an existing tree as
112        if it was parsing XML data with iterparse().
113        Drop-in replacement for lxml.etree.iterwalk.
114        """
115        if iselement(element_or_tree):
116            element = element_or_tree
117        else:
118            element = element_or_tree.getroot()
119        if tag == "*":
120            tag = None
121        for item in _iterwalk(element, events, tag):
122            yield item
123
124    _ElementTree = ElementTree
125
126    class ElementTree(_ElementTree):
127        """ElementTree subclass that adds 'pretty_print' and 'doctype'
128        arguments to the 'write' method.
129        Currently these are only supported for the default XML serialization
130        'method', and not also for "html" or "text", for these are delegated
131        to the base class.
132        """
133
134        def write(
135            self,
136            file_or_filename,
137            encoding=None,
138            xml_declaration=False,
139            method=None,
140            doctype=None,
141            pretty_print=False,
142        ):
143            if method and method != "xml":
144                # delegate to super-class
145                super(ElementTree, self).write(
146                    file_or_filename,
147                    encoding=encoding,
148                    xml_declaration=xml_declaration,
149                    method=method,
150                )
151                return
152
153            if encoding is unicode or (
154                encoding is not None and encoding.lower() == "unicode"
155            ):
156                if xml_declaration:
157                    raise ValueError(
158                        "Serialisation to unicode must not request an XML declaration"
159                    )
160                write_declaration = False
161                encoding = "unicode"
162            elif xml_declaration is None:
163                # by default, write an XML declaration only for non-standard encodings
164                write_declaration = encoding is not None and encoding.upper() not in (
165                    "ASCII",
166                    "UTF-8",
167                    "UTF8",
168                    "US-ASCII",
169                )
170            else:
171                write_declaration = xml_declaration
172
173            if encoding is None:
174                encoding = "ASCII"
175
176            if pretty_print:
177                # NOTE this will modify the tree in-place
178                _indent(self._root)
179
180            with _get_writer(file_or_filename, encoding) as write:
181                if write_declaration:
182                    write(XML_DECLARATION % encoding.upper())
183                    if pretty_print:
184                        write("\n")
185                if doctype:
186                    write(_tounicode(doctype))
187                    if pretty_print:
188                        write("\n")
189
190                qnames, namespaces = _namespaces(self._root)
191                _serialize_xml(write, self._root, qnames, namespaces)
192
193    import io
194
195    def tostring(
196        element,
197        encoding=None,
198        xml_declaration=None,
199        method=None,
200        doctype=None,
201        pretty_print=False,
202    ):
203        """Custom 'tostring' function that uses our ElementTree subclass, with
204        pretty_print support.
205        """
206        stream = io.StringIO() if encoding == "unicode" else io.BytesIO()
207        ElementTree(element).write(
208            stream,
209            encoding=encoding,
210            xml_declaration=xml_declaration,
211            method=method,
212            doctype=doctype,
213            pretty_print=pretty_print,
214        )
215        return stream.getvalue()
216
217    # serialization support
218
219    import re
220
221    # Valid XML strings can include any Unicode character, excluding control
222    # characters, the surrogate blocks, FFFE, and FFFF:
223    #   Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
224    # Here we reversed the pattern to match only the invalid characters.
225    # For the 'narrow' python builds supporting only UCS-2, which represent
226    # characters beyond BMP as UTF-16 surrogate pairs, we need to pass through
227    # the surrogate block. I haven't found a more elegant solution...
228    UCS2 = sys.maxunicode < 0x10FFFF
229    if UCS2:
230        _invalid_xml_string = re.compile(
231            "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uFFFE-\uFFFF]"
232        )
233    else:
234        _invalid_xml_string = re.compile(
235            "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uD800-\uDFFF\uFFFE-\uFFFF]"
236        )
237
238    def _tounicode(s):
239        """Test if a string is valid user input and decode it to unicode string
240        using ASCII encoding if it's a bytes string.
241        Reject all bytes/unicode input that contains non-XML characters.
242        Reject all bytes input that contains non-ASCII characters.
243        """
244        try:
245            s = tostr(s, encoding="ascii", errors="strict")
246        except UnicodeDecodeError:
247            raise ValueError(
248                "Bytes strings can only contain ASCII characters. "
249                "Use unicode strings for non-ASCII characters.")
250        except AttributeError:
251            _raise_serialization_error(s)
252        if s and _invalid_xml_string.search(s):
253            raise ValueError(
254                "All strings must be XML compatible: Unicode or ASCII, "
255                "no NULL bytes or control characters"
256            )
257        return s
258
259    import contextlib
260
    @contextlib.contextmanager
    def _get_writer(file_or_filename, encoding):
        """Context manager yielding a callable that writes text to the output.

        *file_or_filename* is either a path or an object with a ``write``
        method:

        - path: opened in text mode with *encoding* (UTF-8 when *encoding*
          is "unicode") and closed on exit;
        - object and *encoding* == "unicode": its own ``write`` is yielded
          unchanged, so it must accept str;
        - object and any other encoding: it is wrapped in an
          ``io.TextIOWrapper`` encoding with *encoding*, which is detached on
          exit so the caller's stream stays open.

        Where encoding happens here, unencodable characters are written as
        XML character references ("xmlcharrefreplace").
        """
        # returns text write method and release all resources after using
        try:
            write = file_or_filename.write
        except AttributeError:
            # file_or_filename is a file name
            f = open(
                file_or_filename,
                "w",
                encoding="utf-8" if encoding == "unicode" else encoding,
                errors="xmlcharrefreplace",
            )
            # the file is closed when the caller's with-block ends
            with f:
                yield f.write
        else:
            # file_or_filename is a file-like object
            # encoding determines if it is a text or binary writer
            if encoding == "unicode":
                # use a text writer as is
                yield write
            else:
                # wrap a binary writer with TextIOWrapper
                detach_buffer = False
                if isinstance(file_or_filename, io.BufferedIOBase):
                    buf = file_or_filename
                elif isinstance(file_or_filename, io.RawIOBase):
                    # raw streams get a temporary buffered layer, detached
                    # again below so it does not close the raw stream
                    buf = io.BufferedWriter(file_or_filename)
                    detach_buffer = True
                else:
                    # This is to handle passed objects that aren't in the
                    # IOBase hierarchy, but just have a write method
                    buf = io.BufferedIOBase()
                    buf.writable = lambda: True
                    buf.write = write
                    try:
                        # TextIOWrapper uses these methods to determine
                        # if BOM (for UTF-16, etc) should be added
                        buf.seekable = file_or_filename.seekable
                        buf.tell = file_or_filename.tell
                    except AttributeError:
                        pass
                # newline="\n": never translate "\n" to the platform separator
                wrapper = io.TextIOWrapper(
                    buf,
                    encoding=encoding,
                    errors="xmlcharrefreplace",
                    newline="\n",
                )
                try:
                    yield wrapper.write
                finally:
                    # Keep the original file open when the TextIOWrapper and
                    # the BufferedWriter are destroyed
                    wrapper.detach()
                    if detach_buffer:
                        buf.detach()
317
318    from xml.etree.ElementTree import _namespace_map
319
320    def _namespaces(elem):
321        # identify namespaces used in this tree
322
323        # maps qnames to *encoded* prefix:local names
324        qnames = {None: None}
325
326        # maps uri:s to prefixes
327        namespaces = {}
328
329        def add_qname(qname):
330            # calculate serialized qname representation
331            try:
332                qname = _tounicode(qname)
333                if qname[:1] == "{":
334                    uri, tag = qname[1:].rsplit("}", 1)
335                    prefix = namespaces.get(uri)
336                    if prefix is None:
337                        prefix = _namespace_map.get(uri)
338                        if prefix is None:
339                            prefix = "ns%d" % len(namespaces)
340                        else:
341                            prefix = _tounicode(prefix)
342                        if prefix != "xml":
343                            namespaces[uri] = prefix
344                    if prefix:
345                        qnames[qname] = "%s:%s" % (prefix, tag)
346                    else:
347                        qnames[qname] = tag  # default element
348                else:
349                    qnames[qname] = qname
350            except TypeError:
351                _raise_serialization_error(qname)
352
353        # populate qname and namespaces table
354        for elem in elem.iter():
355            tag = elem.tag
356            if isinstance(tag, QName):
357                if tag.text not in qnames:
358                    add_qname(tag.text)
359            elif isinstance(tag, str):
360                if tag not in qnames:
361                    add_qname(tag)
362            elif tag is not None and tag is not Comment and tag is not PI:
363                _raise_serialization_error(tag)
364            for key, value in elem.items():
365                if isinstance(key, QName):
366                    key = key.text
367                if key not in qnames:
368                    add_qname(key)
369                if isinstance(value, QName) and value.text not in qnames:
370                    add_qname(value.text)
371            text = elem.text
372            if isinstance(text, QName) and text.text not in qnames:
373                add_qname(text.text)
374        return qnames, namespaces
375
376    def _serialize_xml(write, elem, qnames, namespaces, **kwargs):
377        tag = elem.tag
378        text = elem.text
379        if tag is Comment:
380            write("<!--%s-->" % _tounicode(text))
381        elif tag is ProcessingInstruction:
382            write("<?%s?>" % _tounicode(text))
383        else:
384            tag = qnames[_tounicode(tag) if tag is not None else None]
385            if tag is None:
386                if text:
387                    write(_escape_cdata(text))
388                for e in elem:
389                    _serialize_xml(write, e, qnames, None)
390            else:
391                write("<" + tag)
392                if namespaces:
393                    for uri, prefix in sorted(
394                        namespaces.items(), key=lambda x: x[1]
395                    ):  # sort on prefix
396                        if prefix:
397                            prefix = ":" + prefix
398                        write(' xmlns%s="%s"' % (prefix, _escape_attrib(uri)))
399                attrs = elem.attrib
400                if attrs:
401                    # try to keep existing attrib order
402                    if len(attrs) <= 1 or type(attrs) is _Attrib:
403                        items = attrs.items()
404                    else:
405                        # if plain dict, use lexical order
406                        items = sorted(attrs.items())
407                    for k, v in items:
408                        if isinstance(k, QName):
409                            k = _tounicode(k.text)
410                        else:
411                            k = _tounicode(k)
412                        if isinstance(v, QName):
413                            v = qnames[_tounicode(v.text)]
414                        else:
415                            v = _escape_attrib(v)
416                        write(' %s="%s"' % (qnames[k], v))
417                if text is not None or len(elem):
418                    write(">")
419                    if text:
420                        write(_escape_cdata(text))
421                    for e in elem:
422                        _serialize_xml(write, e, qnames, None)
423                    write("</" + tag + ">")
424                else:
425                    write("/>")
426        if elem.tail:
427            write(_escape_cdata(elem.tail))
428
429    def _raise_serialization_error(text):
430        raise TypeError(
431            "cannot serialize %r (type %s)" % (text, type(text).__name__)
432        )
433
434    def _escape_cdata(text):
435        # escape character data
436        try:
437            text = _tounicode(text)
438            # it's worth avoiding do-nothing calls for short strings
439            if "&" in text:
440                text = text.replace("&", "&amp;")
441            if "<" in text:
442                text = text.replace("<", "&lt;")
443            if ">" in text:
444                text = text.replace(">", "&gt;")
445            return text
446        except (TypeError, AttributeError):
447            _raise_serialization_error(text)
448
449    def _escape_attrib(text):
450        # escape attribute value
451        try:
452            text = _tounicode(text)
453            if "&" in text:
454                text = text.replace("&", "&amp;")
455            if "<" in text:
456                text = text.replace("<", "&lt;")
457            if ">" in text:
458                text = text.replace(">", "&gt;")
459            if '"' in text:
460                text = text.replace('"', "&quot;")
461            if "\n" in text:
462                text = text.replace("\n", "&#10;")
463            return text
464        except (TypeError, AttributeError):
465            _raise_serialization_error(text)
466
467    def _indent(elem, level=0):
468        # From http://effbot.org/zone/element-lib.htm#prettyprint
469        i = "\n" + level * "  "
470        if len(elem):
471            if not elem.text or not elem.text.strip():
472                elem.text = i + "  "
473            if not elem.tail or not elem.tail.strip():
474                elem.tail = i
475            for elem in elem:
476                _indent(elem, level + 1)
477            if not elem.tail or not elem.tail.strip():
478                elem.tail = i
479        else:
480            if level and (not elem.tail or not elem.tail.strip()):
481                elem.tail = i
482