Issue #1470548: XMLGenerator now works with binary output streams.

This commit is contained in:
Serhiy Storchaka 2013-02-10 14:34:53 +02:00
commit 02c2076bd5
3 changed files with 192 additions and 92 deletions

View File

@ -13,7 +13,7 @@
from xml.sax.expatreader import create_parser from xml.sax.expatreader import create_parser
from xml.sax.handler import feature_namespaces from xml.sax.handler import feature_namespaces
from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
from io import StringIO from io import BytesIO, StringIO
import os.path import os.path
import shutil import shutil
from test import support from test import support
@ -173,31 +173,29 @@ def test_make_parser(self):
# ===== XMLGenerator # ===== XMLGenerator
start = '<?xml version="1.0" encoding="iso-8859-1"?>\n' class XmlgenTest:
class XmlgenTest(unittest.TestCase):
def test_xmlgen_basic(self): def test_xmlgen_basic(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
gen.startElement("doc", {}) gen.startElement("doc", {})
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc></doc>") self.assertEqual(result.getvalue(), self.xml("<doc></doc>"))
def test_xmlgen_basic_empty(self): def test_xmlgen_basic_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
gen.startElement("doc", {}) gen.startElement("doc", {})
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc/>") self.assertEqual(result.getvalue(), self.xml("<doc/>"))
def test_xmlgen_content(self): def test_xmlgen_content(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -206,10 +204,10 @@ def test_xmlgen_content(self):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>") self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))
def test_xmlgen_content_empty(self): def test_xmlgen_content_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
@ -218,10 +216,10 @@ def test_xmlgen_content_empty(self):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>") self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))
def test_xmlgen_pi(self): def test_xmlgen_pi(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -230,10 +228,11 @@ def test_xmlgen_pi(self):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>") self.assertEqual(result.getvalue(),
self.xml("<?test data?><doc></doc>"))
def test_xmlgen_content_escape(self): def test_xmlgen_content_escape(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -243,10 +242,10 @@ def test_xmlgen_content_escape(self):
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start + "<doc>&lt;huhei&amp;</doc>") self.xml("<doc>&lt;huhei&amp;</doc>"))
def test_xmlgen_attr_escape(self): def test_xmlgen_attr_escape(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -260,13 +259,43 @@ def test_xmlgen_attr_escape(self):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + self.assertEqual(result.getvalue(), self.xml(
("<doc a='\"'><e a=\"'\"></e>" "<doc a='\"'><e a=\"'\"></e>"
"<e a=\"'&quot;\"></e>" "<e a=\"'&quot;\"></e>"
"<e a=\"&#10;&#13;&#9;\"></e></doc>")) "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
def test_xmlgen_encoding(self):
encodings = ('iso-8859-15', 'utf-8', 'utf-8-sig',
'utf-16', 'utf-16be', 'utf-16le',
'utf-32', 'utf-32be', 'utf-32le')
for encoding in encodings:
result = self.ioclass()
gen = XMLGenerator(result, encoding=encoding)
gen.startDocument()
gen.startElement("doc", {"a": '\u20ac'})
gen.characters("\u20ac")
gen.endElement("doc")
gen.endDocument()
self.assertEqual(result.getvalue(),
self.xml('<doc a="\u20ac">\u20ac</doc>', encoding=encoding))
def test_xmlgen_unencodable(self):
result = self.ioclass()
gen = XMLGenerator(result, encoding='ascii')
gen.startDocument()
gen.startElement("doc", {"a": '\u20ac'})
gen.characters("\u20ac")
gen.endElement("doc")
gen.endDocument()
self.assertEqual(result.getvalue(),
self.xml('<doc a="&#8364;">&#8364;</doc>', encoding='ascii'))
def test_xmlgen_ignorable(self): def test_xmlgen_ignorable(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -275,10 +304,10 @@ def test_xmlgen_ignorable(self):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc> </doc>") self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))
def test_xmlgen_ignorable_empty(self): def test_xmlgen_ignorable_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
@ -287,10 +316,10 @@ def test_xmlgen_ignorable_empty(self):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc> </doc>") self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))
def test_xmlgen_ns(self): def test_xmlgen_ns(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -303,12 +332,12 @@ def test_xmlgen_ns(self):
gen.endPrefixMapping("ns1") gen.endPrefixMapping("ns1")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + \ self.assertEqual(result.getvalue(), self.xml(
('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' % '<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
ns_uri)) ns_uri))
def test_xmlgen_ns_empty(self): def test_xmlgen_ns_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
@ -321,12 +350,12 @@ def test_xmlgen_ns_empty(self):
gen.endPrefixMapping("ns1") gen.endPrefixMapping("ns1")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + \ self.assertEqual(result.getvalue(), self.xml(
('<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' % '<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' %
ns_uri)) ns_uri))
def test_1463026_1(self): def test_1463026_1(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -334,10 +363,10 @@ def test_1463026_1(self):
gen.endElementNS((None, 'a'), 'a') gen.endElementNS((None, 'a'), 'a')
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start+'<a b="c"></a>') self.assertEqual(result.getvalue(), self.xml('<a b="c"></a>'))
def test_1463026_1_empty(self): def test_1463026_1_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
@ -345,10 +374,10 @@ def test_1463026_1_empty(self):
gen.endElementNS((None, 'a'), 'a') gen.endElementNS((None, 'a'), 'a')
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start+'<a b="c"/>') self.assertEqual(result.getvalue(), self.xml('<a b="c"/>'))
def test_1463026_2(self): def test_1463026_2(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -358,10 +387,10 @@ def test_1463026_2(self):
gen.endPrefixMapping(None) gen.endPrefixMapping(None)
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>') self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"></a>'))
def test_1463026_2_empty(self): def test_1463026_2_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
@ -371,10 +400,10 @@ def test_1463026_2_empty(self):
gen.endPrefixMapping(None) gen.endPrefixMapping(None)
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start+'<a xmlns="qux"/>') self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"/>'))
def test_1463026_3(self): def test_1463026_3(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -385,10 +414,10 @@ def test_1463026_3(self):
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start+'<my:a xmlns:my="qux" b="c"></my:a>') self.xml('<my:a xmlns:my="qux" b="c"></my:a>'))
def test_1463026_3_empty(self): def test_1463026_3_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
@ -399,7 +428,7 @@ def test_1463026_3_empty(self):
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start+'<my:a xmlns:my="qux" b="c"/>') self.xml('<my:a xmlns:my="qux" b="c"/>'))
def test_5027_1(self): def test_5027_1(self):
# The xml prefix (as in xml:lang below) is reserved and bound by # The xml prefix (as in xml:lang below) is reserved and bound by
@ -416,13 +445,13 @@ def test_5027_1(self):
parser = make_parser() parser = make_parser()
parser.setFeature(feature_namespaces, True) parser.setFeature(feature_namespaces, True)
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
parser.setContentHandler(gen) parser.setContentHandler(gen)
parser.parse(test_xml) parser.parse(test_xml)
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start + ( self.xml(
'<a:g1 xmlns:a="http://example.com/ns">' '<a:g1 xmlns:a="http://example.com/ns">'
'<a:g2 xml:lang="en">Hello</a:g2>' '<a:g2 xml:lang="en">Hello</a:g2>'
'</a:g1>')) '</a:g1>'))
@ -435,7 +464,7 @@ def test_5027_2(self):
# #
# This test demonstrates the bug by direct manipulation of the # This test demonstrates the bug by direct manipulation of the
# XMLGenerator. # XMLGenerator.
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
@ -450,15 +479,57 @@ def test_5027_2(self):
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start + ( self.xml(
'<a:g1 xmlns:a="http://example.com/ns">' '<a:g1 xmlns:a="http://example.com/ns">'
'<a:g2 xml:lang="en">Hello</a:g2>' '<a:g2 xml:lang="en">Hello</a:g2>'
'</a:g1>')) '</a:g1>'))
def test_no_close_file(self):
result = self.ioclass()
def func(out):
gen = XMLGenerator(out)
gen.startDocument()
gen.startElement("doc", {})
func(result)
self.assertFalse(result.closed)
class StringXmlgenTest(XmlgenTest, unittest.TestCase):
ioclass = StringIO
def xml(self, doc, encoding='iso-8859-1'):
return '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, doc)
test_xmlgen_unencodable = None
class BytesXmlgenTest(XmlgenTest, unittest.TestCase):
ioclass = BytesIO
def xml(self, doc, encoding='iso-8859-1'):
return ('<?xml version="1.0" encoding="%s"?>\n%s' %
(encoding, doc)).encode(encoding, 'xmlcharrefreplace')
class WriterXmlgenTest(BytesXmlgenTest):
class ioclass(list):
write = list.append
closed = False
def seekable(self):
return True
def tell(self):
# return 0 at start and not 0 after start
return len(self)
def getvalue(self):
return b''.join(self)
start = b'<?xml version="1.0" encoding="iso-8859-1"?>\n'
class XMLFilterBaseTest(unittest.TestCase): class XMLFilterBaseTest(unittest.TestCase):
def test_filter_basic(self): def test_filter_basic(self):
result = StringIO() result = BytesIO()
gen = XMLGenerator(result) gen = XMLGenerator(result)
filter = XMLFilterBase() filter = XMLFilterBase()
filter.setContentHandler(gen) filter.setContentHandler(gen)
@ -470,7 +541,7 @@ def test_filter_basic(self):
filter.endElement("doc") filter.endElement("doc")
filter.endDocument() filter.endDocument()
self.assertEqual(result.getvalue(), start + "<doc>content </doc>") self.assertEqual(result.getvalue(), start + b"<doc>content </doc>")
# =========================================================================== # ===========================================================================
# #
@ -478,7 +549,7 @@ def test_filter_basic(self):
# #
# =========================================================================== # ===========================================================================
with open(TEST_XMLFILE_OUT) as f: with open(TEST_XMLFILE_OUT, 'rb') as f:
xml_test_out = f.read() xml_test_out = f.read()
class ExpatReaderTest(XmlTestBase): class ExpatReaderTest(XmlTestBase):
@ -487,11 +558,11 @@ class ExpatReaderTest(XmlTestBase):
def test_expat_file(self): def test_expat_file(self):
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
with open(TEST_XMLFILE) as f: with open(TEST_XMLFILE, 'rb') as f:
parser.parse(f) parser.parse(f)
self.assertEqual(result.getvalue(), xml_test_out) self.assertEqual(result.getvalue(), xml_test_out)
@ -503,7 +574,7 @@ def test_expat_file_nonascii(self):
self.addCleanup(support.unlink, fname) self.addCleanup(support.unlink, fname)
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -547,13 +618,13 @@ class TestEntityResolver:
def resolveEntity(self, publicId, systemId): def resolveEntity(self, publicId, systemId):
inpsrc = InputSource() inpsrc = InputSource()
inpsrc.setByteStream(StringIO("<entity/>")) inpsrc.setByteStream(BytesIO(b"<entity/>"))
return inpsrc return inpsrc
def test_expat_entityresolver(self): def test_expat_entityresolver(self):
parser = create_parser() parser = create_parser()
parser.setEntityResolver(self.TestEntityResolver()) parser.setEntityResolver(self.TestEntityResolver())
result = StringIO() result = BytesIO()
parser.setContentHandler(XMLGenerator(result)) parser.setContentHandler(XMLGenerator(result))
parser.feed('<!DOCTYPE doc [\n') parser.feed('<!DOCTYPE doc [\n')
@ -563,7 +634,7 @@ def test_expat_entityresolver(self):
parser.close() parser.close()
self.assertEqual(result.getvalue(), start + self.assertEqual(result.getvalue(), start +
"<doc><entity></entity></doc>") b"<doc><entity></entity></doc>")
# ===== Attributes support # ===== Attributes support
@ -632,7 +703,7 @@ def test_expat_nsattrs_wattr(self):
def test_expat_inpsource_filename(self): def test_expat_inpsource_filename(self):
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -642,7 +713,7 @@ def test_expat_inpsource_filename(self):
def test_expat_inpsource_sysid(self): def test_expat_inpsource_sysid(self):
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -657,7 +728,7 @@ def test_expat_inpsource_sysid_nonascii(self):
self.addCleanup(support.unlink, fname) self.addCleanup(support.unlink, fname)
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -667,12 +738,12 @@ def test_expat_inpsource_sysid_nonascii(self):
def test_expat_inpsource_stream(self): def test_expat_inpsource_stream(self):
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
inpsrc = InputSource() inpsrc = InputSource()
with open(TEST_XMLFILE) as f: with open(TEST_XMLFILE, 'rb') as f:
inpsrc.setByteStream(f) inpsrc.setByteStream(f)
parser.parse(inpsrc) parser.parse(inpsrc)
@ -681,7 +752,7 @@ def test_expat_inpsource_stream(self):
# ===== IncrementalParser support # ===== IncrementalParser support
def test_expat_incremental(self): def test_expat_incremental(self):
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -690,10 +761,10 @@ def test_expat_incremental(self):
parser.feed("</doc>") parser.feed("</doc>")
parser.close() parser.close()
self.assertEqual(result.getvalue(), start + "<doc></doc>") self.assertEqual(result.getvalue(), start + b"<doc></doc>")
def test_expat_incremental_reset(self): def test_expat_incremental_reset(self):
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -701,7 +772,7 @@ def test_expat_incremental_reset(self):
parser.feed("<doc>") parser.feed("<doc>")
parser.feed("text") parser.feed("text")
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
parser.reset() parser.reset()
@ -711,12 +782,12 @@ def test_expat_incremental_reset(self):
parser.feed("</doc>") parser.feed("</doc>")
parser.close() parser.close()
self.assertEqual(result.getvalue(), start + "<doc>text</doc>") self.assertEqual(result.getvalue(), start + b"<doc>text</doc>")
# ===== Locator support # ===== Locator support
def test_expat_locator_noinfo(self): def test_expat_locator_noinfo(self):
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -730,7 +801,7 @@ def test_expat_locator_noinfo(self):
self.assertEqual(parser.getLineNumber(), 1) self.assertEqual(parser.getLineNumber(), 1)
def test_expat_locator_withinfo(self): def test_expat_locator_withinfo(self):
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -745,7 +816,7 @@ def test_expat_locator_withinfo_nonascii(self):
shutil.copyfile(TEST_XMLFILE, fname) shutil.copyfile(TEST_XMLFILE, fname)
self.addCleanup(support.unlink, fname) self.addCleanup(support.unlink, fname)
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
@ -766,7 +837,7 @@ def test_expat_inpsource_location(self):
parser = create_parser() parser = create_parser()
parser.setContentHandler(ContentHandler()) # do nothing parser.setContentHandler(ContentHandler()) # do nothing
source = InputSource() source = InputSource()
source.setByteStream(StringIO("<foo bar foobar>")) #ill-formed source.setByteStream(BytesIO(b"<foo bar foobar>")) #ill-formed
name = "a file name" name = "a file name"
source.setSystemId(name) source.setSystemId(name)
try: try:
@ -857,7 +928,9 @@ def test_nsattrs_wattr(self):
def test_main(): def test_main():
run_unittest(MakeParserTest, run_unittest(MakeParserTest,
SaxutilsTest, SaxutilsTest,
XmlgenTest, StringXmlgenTest,
BytesXmlgenTest,
WriterXmlgenTest,
ExpatReaderTest, ExpatReaderTest,
ErrorReportingTest, ErrorReportingTest,
XmlReaderTest) XmlReaderTest)

View File

@ -4,18 +4,10 @@
""" """
import os, urllib.parse, urllib.request import os, urllib.parse, urllib.request
import io
from . import handler from . import handler
from . import xmlreader from . import xmlreader
# See whether the xmlcharrefreplace error handler is
# supported
try:
from codecs import xmlcharrefreplace_errors
_error_handling = "xmlcharrefreplace"
del xmlcharrefreplace_errors
except ImportError:
_error_handling = "strict"
def __dict_replace(s, d): def __dict_replace(s, d):
"""Replace substrings of a string using a dictionary.""" """Replace substrings of a string using a dictionary."""
for key, value in d.items(): for key, value in d.items():
@ -76,14 +68,50 @@ def quoteattr(data, entities={}):
return data return data
def _gettextwriter(out, encoding):
if out is None:
import sys
return sys.stdout
if isinstance(out, io.TextIOBase):
# use a text writer as is
return out
# wrap a binary writer with TextIOWrapper
if isinstance(out, io.RawIOBase):
# Keep the original file open when the TextIOWrapper is
# destroyed
class _wrapper:
__class__ = out.__class__
def __getattr__(self, name):
return getattr(out, name)
buffer = _wrapper()
buffer.close = lambda: None
else:
# This is to handle passed objects that aren't in the
# IOBase hierarchy, but just have a write method
buffer = io.BufferedIOBase()
buffer.writable = lambda: True
buffer.write = out.write
try:
# TextIOWrapper uses this methods to determine
# if BOM (for UTF-16, etc) should be added
buffer.seekable = out.seekable
buffer.tell = out.tell
except AttributeError:
pass
return io.TextIOWrapper(buffer, encoding=encoding,
errors='xmlcharrefreplace',
newline='\n',
write_through=True)
class XMLGenerator(handler.ContentHandler): class XMLGenerator(handler.ContentHandler):
def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False): def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
if out is None:
import sys
out = sys.stdout
handler.ContentHandler.__init__(self) handler.ContentHandler.__init__(self)
self._out = out out = _gettextwriter(out, encoding)
self._write = out.write
self._flush = out.flush
self._ns_contexts = [{}] # contains uri -> prefix dicts self._ns_contexts = [{}] # contains uri -> prefix dicts
self._current_context = self._ns_contexts[-1] self._current_context = self._ns_contexts[-1]
self._undeclared_ns_maps = [] self._undeclared_ns_maps = []
@ -91,12 +119,6 @@ def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
self._short_empty_elements = short_empty_elements self._short_empty_elements = short_empty_elements
self._pending_start_element = False self._pending_start_element = False
def _write(self, text):
if isinstance(text, str):
self._out.write(text)
else:
self._out.write(text.encode(self._encoding, _error_handling))
def _qname(self, name): def _qname(self, name):
"""Builds a qualified name from a (ns_url, localname) pair""" """Builds a qualified name from a (ns_url, localname) pair"""
if name[0]: if name[0]:
@ -125,6 +147,9 @@ def startDocument(self):
self._write('<?xml version="1.0" encoding="%s"?>\n' % self._write('<?xml version="1.0" encoding="%s"?>\n' %
self._encoding) self._encoding)
def endDocument(self):
self._flush()
def startPrefixMapping(self, prefix, uri): def startPrefixMapping(self, prefix, uri):
self._ns_contexts.append(self._current_context.copy()) self._ns_contexts.append(self._current_context.copy())
self._current_context[uri] = prefix self._current_context[uri] = prefix
@ -157,9 +182,9 @@ def startElementNS(self, name, qname, attrs):
for prefix, uri in self._undeclared_ns_maps: for prefix, uri in self._undeclared_ns_maps:
if prefix: if prefix:
self._out.write(' xmlns:%s="%s"' % (prefix, uri)) self._write(' xmlns:%s="%s"' % (prefix, uri))
else: else:
self._out.write(' xmlns="%s"' % uri) self._write(' xmlns="%s"' % uri)
self._undeclared_ns_maps = [] self._undeclared_ns_maps = []
for (name, value) in attrs.items(): for (name, value) in attrs.items():

View File

@ -244,6 +244,8 @@ Core and Builtins
Library Library
------- -------
- Issue #1470548: XMLGenerator now works with binary output streams.
- Issue #6975: os.path.realpath() now correctly resolves multiple nested - Issue #6975: os.path.realpath() now correctly resolves multiple nested
symlinks on POSIX platforms. symlinks on POSIX platforms.