1"""Shim module exporting the same ElementTree API for lxml and
2xml.etree backends.
3
4When lxml is installed, it is automatically preferred over the built-in
5xml.etree module.
6On Python 2.7, the cElementTree module is preferred over the pure-python
7ElementTree module.
8
Besides exporting a unified interface, this also defines extra functions
or subclasses built-in ElementTree classes to add features that are
only available in lxml, like OrderedDict for attributes, pretty_print and
iterwalk.
13"""
14from __future__ import absolute_import, unicode_literals
15from fontTools.misc.py23 import basestring, unicode, tounicode, open
16
17
# Template for the XML declaration written by ElementTree.write; the '%s'
# placeholder is filled with the (upper-cased) encoding name.
XML_DECLARATION = """<?xml version='1.0' encoding='%s'?>"""

# Names exported by a star-import of this module. They come from lxml.etree
# when it is installed; otherwise from xml.etree, with Element, SubElement,
# ElementTree and tostring replaced by the versions defined further down.
# NOTE(review): 'iterwalk' is defined in both backends but is not listed
# here, so a star-import does not expose it.
__all__ = [
    # public symbols
    "Comment",
    "dump",
    "Element",
    "ElementTree",
    "fromstring",
    "fromstringlist",
    "iselement",
    "iterparse",
    "parse",
    "ParseError",
    "PI",
    "ProcessingInstruction",
    "QName",
    "SubElement",
    "tostring",
    "tostringlist",
    "TreeBuilder",
    "XML",
    "XMLParser",
    "register_namespace",
]
43
try:
    # Prefer lxml when it is installed (see the module docstring).
    from lxml.etree import *

    # Records which backend was loaded.
    _have_lxml = True
except ImportError:
    # Fallback: the standard library ElementTree, extended by the
    # definitions that follow in this except branch.
    try:
        from xml.etree.cElementTree import *

        # the cElementTree version of XML function doesn't support
        # the optional 'parser' keyword argument
        from xml.etree.ElementTree import XML
    except ImportError:  # pragma: no cover
        # xml.etree.cElementTree is not importable (it was removed in
        # Python 3.9), so use the pure-python module directly.
        from xml.etree.ElementTree import *
    _have_lxml = False

    import sys

    # dict is always ordered in python >= 3.6 and on pypy
    # NOTE(review): insertion order is a language guarantee only from 3.7;
    # on CPython 3.6 it is an implementation detail.
    PY36 = sys.version_info >= (3, 6)
    try:
        import __pypy__
    except ImportError:
        __pypy__ = None
    _dict_is_ordered = bool(PY36 or __pypy__)
    del PY36, __pypy__

    # Mapping type used for element attributes, so attributes are
    # serialized in insertion order.
    if _dict_is_ordered:
        _Attrib = dict
    else:
        from collections import OrderedDict as _Attrib

    # Base class for the Element subclass defined below.
    if isinstance(Element, type):
        _Element = Element
    else:
        # in py27, cElementTree.Element cannot be subclassed, so
        # we need to import the pure-python class
        from xml.etree.ElementTree import Element as _Element
81
82    class Element(_Element):
83        """Element subclass that keeps the order of attributes."""
84
85        def __init__(self, tag, attrib=_Attrib(), **extra):
86            super(Element, self).__init__(tag)
87            self.attrib = _Attrib()
88            if attrib:
89                self.attrib.update(attrib)
90            if extra:
91                self.attrib.update(extra)
92
93    def SubElement(parent, tag, attrib=_Attrib(), **extra):
94        """Must override SubElement as well otherwise _elementtree.SubElement
95        fails if 'parent' is a subclass of Element object.
96        """
97        element = parent.__class__(tag, attrib, **extra)
98        parent.append(element)
99        return element
100
101    def _iterwalk(element, events, tag):
102        include = tag is None or element.tag == tag
103        if include and "start" in events:
104            yield ("start", element)
105        for e in element:
106            for item in _iterwalk(e, events, tag):
107                yield item
108        if include:
109            yield ("end", element)
110
111    def iterwalk(element_or_tree, events=("end",), tag=None):
112        """A tree walker that generates events from an existing tree as
113        if it was parsing XML data with iterparse().
114        Drop-in replacement for lxml.etree.iterwalk.
115        """
116        if iselement(element_or_tree):
117            element = element_or_tree
118        else:
119            element = element_or_tree.getroot()
120        if tag == "*":
121            tag = None
122        for item in _iterwalk(element, events, tag):
123            yield item
124
125    _ElementTree = ElementTree
126
127    class ElementTree(_ElementTree):
128        """ElementTree subclass that adds 'pretty_print' and 'doctype'
129        arguments to the 'write' method.
130        Currently these are only supported for the default XML serialization
131        'method', and not also for "html" or "text", for these are delegated
132        to the base class.
133        """
134
135        def write(
136            self,
137            file_or_filename,
138            encoding=None,
139            xml_declaration=False,
140            method=None,
141            doctype=None,
142            pretty_print=False,
143        ):
144            if method and method != "xml":
145                # delegate to super-class
146                super(ElementTree, self).write(
147                    file_or_filename,
148                    encoding=encoding,
149                    xml_declaration=xml_declaration,
150                    method=method,
151                )
152                return
153
154            if encoding is unicode or (
155                encoding is not None and encoding.lower() == "unicode"
156            ):
157                if xml_declaration:
158                    raise ValueError(
159                        "Serialisation to unicode must not request an XML declaration"
160                    )
161                write_declaration = False
162                encoding = "unicode"
163            elif xml_declaration is None:
164                # by default, write an XML declaration only for non-standard encodings
165                write_declaration = encoding is not None and encoding.upper() not in (
166                    "ASCII",
167                    "UTF-8",
168                    "UTF8",
169                    "US-ASCII",
170                )
171            else:
172                write_declaration = xml_declaration
173
174            if encoding is None:
175                encoding = "ASCII"
176
177            if pretty_print:
178                # NOTE this will modify the tree in-place
179                _indent(self._root)
180
181            with _get_writer(file_or_filename, encoding) as write:
182                if write_declaration:
183                    write(XML_DECLARATION % encoding.upper())
184                    if pretty_print:
185                        write("\n")
186                if doctype:
187                    write(_tounicode(doctype))
188                    if pretty_print:
189                        write("\n")
190
191                qnames, namespaces = _namespaces(self._root)
192                _serialize_xml(write, self._root, qnames, namespaces)
193
194    import io
195
196    def tostring(
197        element,
198        encoding=None,
199        xml_declaration=None,
200        method=None,
201        doctype=None,
202        pretty_print=False,
203    ):
204        """Custom 'tostring' function that uses our ElementTree subclass, with
205        pretty_print support.
206        """
207        stream = io.StringIO() if encoding == "unicode" else io.BytesIO()
208        ElementTree(element).write(
209            stream,
210            encoding=encoding,
211            xml_declaration=xml_declaration,
212            method=method,
213            doctype=doctype,
214            pretty_print=pretty_print,
215        )
216        return stream.getvalue()
217
218    # serialization support
219
220    import re
221
    # Valid XML strings can include any Unicode character, excluding the C0
    # control characters other than TAB, LF and CR, the surrogate blocks,
    # FFFE, and FFFF:
    #   Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
    # Here we reversed the pattern to match only the invalid characters.
    # For the 'narrow' python builds supporting only UCS-2, which represent
    # characters beyond BMP as UTF-16 surrogate pairs, we need to pass through
    # the surrogate block. I haven't found a more elegant solution...
    # (Consequence: on narrow builds, unpaired surrogates are not rejected.)
    # The patterns are ordinary (non-raw) literals, so Python decodes the
    # \uXXXX escapes itself and the regex receives the literal characters.
    # _tounicode uses _invalid_xml_string to reject such text.
    UCS2 = sys.maxunicode < 0x10FFFF
    if UCS2:
        _invalid_xml_string = re.compile(
            "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uFFFE-\uFFFF]"
        )
    else:
        _invalid_xml_string = re.compile(
            "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uD800-\uDFFF\uFFFE-\uFFFF]"
        )
238
239    def _tounicode(s):
240        """Test if a string is valid user input and decode it to unicode string
241        using ASCII encoding if it's a bytes string.
242        Reject all bytes/unicode input that contains non-XML characters.
243        Reject all bytes input that contains non-ASCII characters.
244        """
245        try:
246            s = tounicode(s, encoding="ascii", errors="strict")
247        except UnicodeDecodeError:
248            raise ValueError(
249                "Bytes strings can only contain ASCII characters. "
250                "Use unicode strings for non-ASCII characters.")
251        except AttributeError:
252            _raise_serialization_error(s)
253        if s and _invalid_xml_string.search(s):
254            raise ValueError(
255                "All strings must be XML compatible: Unicode or ASCII, "
256                "no NULL bytes or control characters"
257            )
258        return s
259
260    import contextlib
261
262    @contextlib.contextmanager
263    def _get_writer(file_or_filename, encoding):
264        # returns text write method and release all resources after using
265        try:
266            write = file_or_filename.write
267        except AttributeError:
268            # file_or_filename is a file name
269            f = open(
270                file_or_filename,
271                "w",
272                encoding="utf-8" if encoding == "unicode" else encoding,
273                errors="xmlcharrefreplace",
274            )
275            with f:
276                yield f.write
277        else:
278            # file_or_filename is a file-like object
279            # encoding determines if it is a text or binary writer
280            if encoding == "unicode":
281                # use a text writer as is
282                yield write
283            else:
284                # wrap a binary writer with TextIOWrapper
285                detach_buffer = False
286                if isinstance(file_or_filename, io.BufferedIOBase):
287                    buf = file_or_filename
288                elif isinstance(file_or_filename, io.RawIOBase):
289                    buf = io.BufferedWriter(file_or_filename)
290                    detach_buffer = True
291                else:
292                    # This is to handle passed objects that aren't in the
293                    # IOBase hierarchy, but just have a write method
294                    buf = io.BufferedIOBase()
295                    buf.writable = lambda: True
296                    buf.write = write
297                    try:
298                        # TextIOWrapper uses this methods to determine
299                        # if BOM (for UTF-16, etc) should be added
300                        buf.seekable = file_or_filename.seekable
301                        buf.tell = file_or_filename.tell
302                    except AttributeError:
303                        pass
304                wrapper = io.TextIOWrapper(
305                    buf,
306                    encoding=encoding,
307                    errors="xmlcharrefreplace",
308                    newline="\n",
309                )
310                try:
311                    yield wrapper.write
312                finally:
313                    # Keep the original file open when the TextIOWrapper and
314                    # the BufferedWriter are destroyed
315                    wrapper.detach()
316                    if detach_buffer:
317                        buf.detach()
318
319    from xml.etree.ElementTree import _namespace_map
320
321    def _namespaces(elem):
322        # identify namespaces used in this tree
323
324        # maps qnames to *encoded* prefix:local names
325        qnames = {None: None}
326
327        # maps uri:s to prefixes
328        namespaces = {}
329
330        def add_qname(qname):
331            # calculate serialized qname representation
332            try:
333                qname = _tounicode(qname)
334                if qname[:1] == "{":
335                    uri, tag = qname[1:].rsplit("}", 1)
336                    prefix = namespaces.get(uri)
337                    if prefix is None:
338                        prefix = _namespace_map.get(uri)
339                        if prefix is None:
340                            prefix = "ns%d" % len(namespaces)
341                        else:
342                            prefix = _tounicode(prefix)
343                        if prefix != "xml":
344                            namespaces[uri] = prefix
345                    if prefix:
346                        qnames[qname] = "%s:%s" % (prefix, tag)
347                    else:
348                        qnames[qname] = tag  # default element
349                else:
350                    qnames[qname] = qname
351            except TypeError:
352                _raise_serialization_error(qname)
353
354        # populate qname and namespaces table
355        for elem in elem.iter():
356            tag = elem.tag
357            if isinstance(tag, QName):
358                if tag.text not in qnames:
359                    add_qname(tag.text)
360            elif isinstance(tag, basestring):
361                if tag not in qnames:
362                    add_qname(tag)
363            elif tag is not None and tag is not Comment and tag is not PI:
364                _raise_serialization_error(tag)
365            for key, value in elem.items():
366                if isinstance(key, QName):
367                    key = key.text
368                if key not in qnames:
369                    add_qname(key)
370                if isinstance(value, QName) and value.text not in qnames:
371                    add_qname(value.text)
372            text = elem.text
373            if isinstance(text, QName) and text.text not in qnames:
374                add_qname(text.text)
375        return qnames, namespaces
376
377    def _serialize_xml(write, elem, qnames, namespaces, **kwargs):
378        tag = elem.tag
379        text = elem.text
380        if tag is Comment:
381            write("<!--%s-->" % _tounicode(text))
382        elif tag is ProcessingInstruction:
383            write("<?%s?>" % _tounicode(text))
384        else:
385            tag = qnames[_tounicode(tag) if tag is not None else None]
386            if tag is None:
387                if text:
388                    write(_escape_cdata(text))
389                for e in elem:
390                    _serialize_xml(write, e, qnames, None)
391            else:
392                write("<" + tag)
393                if namespaces:
394                    for uri, prefix in sorted(
395                        namespaces.items(), key=lambda x: x[1]
396                    ):  # sort on prefix
397                        if prefix:
398                            prefix = ":" + prefix
399                        write(' xmlns%s="%s"' % (prefix, _escape_attrib(uri)))
400                attrs = elem.attrib
401                if attrs:
402                    # try to keep existing attrib order
403                    if len(attrs) <= 1 or type(attrs) is _Attrib:
404                        items = attrs.items()
405                    else:
406                        # if plain dict, use lexical order
407                        items = sorted(attrs.items())
408                    for k, v in items:
409                        if isinstance(k, QName):
410                            k = _tounicode(k.text)
411                        else:
412                            k = _tounicode(k)
413                        if isinstance(v, QName):
414                            v = qnames[_tounicode(v.text)]
415                        else:
416                            v = _escape_attrib(v)
417                        write(' %s="%s"' % (qnames[k], v))
418                if text is not None or len(elem):
419                    write(">")
420                    if text:
421                        write(_escape_cdata(text))
422                    for e in elem:
423                        _serialize_xml(write, e, qnames, None)
424                    write("</" + tag + ">")
425                else:
426                    write("/>")
427        if elem.tail:
428            write(_escape_cdata(elem.tail))
429
430    def _raise_serialization_error(text):
431        raise TypeError(
432            "cannot serialize %r (type %s)" % (text, type(text).__name__)
433        )
434
435    def _escape_cdata(text):
436        # escape character data
437        try:
438            text = _tounicode(text)
439            # it's worth avoiding do-nothing calls for short strings
440            if "&" in text:
441                text = text.replace("&", "&amp;")
442            if "<" in text:
443                text = text.replace("<", "&lt;")
444            if ">" in text:
445                text = text.replace(">", "&gt;")
446            return text
447        except (TypeError, AttributeError):
448            _raise_serialization_error(text)
449
450    def _escape_attrib(text):
451        # escape attribute value
452        try:
453            text = _tounicode(text)
454            if "&" in text:
455                text = text.replace("&", "&amp;")
456            if "<" in text:
457                text = text.replace("<", "&lt;")
458            if ">" in text:
459                text = text.replace(">", "&gt;")
460            if '"' in text:
461                text = text.replace('"', "&quot;")
462            if "\n" in text:
463                text = text.replace("\n", "&#10;")
464            return text
465        except (TypeError, AttributeError):
466            _raise_serialization_error(text)
467
468    def _indent(elem, level=0):
469        # From http://effbot.org/zone/element-lib.htm#prettyprint
470        i = "\n" + level * "  "
471        if len(elem):
472            if not elem.text or not elem.text.strip():
473                elem.text = i + "  "
474            if not elem.tail or not elem.tail.strip():
475                elem.tail = i
476            for elem in elem:
477                _indent(elem, level + 1)
478            if not elem.tail or not elem.tail.strip():
479                elem.tail = i
480        else:
481            if level and (not elem.tail or not elem.tail.strip()):
482                elem.tail = i
483