get SSL support to work again

This commit is contained in:
Bill Janssen 2007-11-15 22:23:56 +00:00
parent f83088aefe
commit 6e027dba93
3 changed files with 536 additions and 570 deletions

View File

@ -1,8 +1,6 @@
# Wrapper module for _ssl, providing some additional facilities # Wrapper module for _ssl, providing some additional facilities
# implemented in Python. Written by Bill Janssen. # implemented in Python. Written by Bill Janssen.
raise ImportError("ssl.py is temporarily out of order")
"""\ """\
This module provides some more Pythonic support for SSL. This module provides some more Pythonic support for SSL.
@ -76,9 +74,11 @@
SSL_ERROR_EOF, \ SSL_ERROR_EOF, \
SSL_ERROR_INVALID_ERROR_CODE SSL_ERROR_INVALID_ERROR_CODE
from socket import socket from socket import socket, AF_INET, SOCK_STREAM, error
from socket import getnameinfo as _getnameinfo from socket import getnameinfo as _getnameinfo
from socket import error as socket_error
import base64 # for DER-to-PEM translation import base64 # for DER-to-PEM translation
_can_dup_socket = hasattr(socket, "dup")
class SSLSocket (socket): class SSLSocket (socket):
@ -86,10 +86,38 @@ class SSLSocket (socket):
the underlying OS socket in an SSL context when necessary, and the underlying OS socket in an SSL context when necessary, and
provides read and write methods over that channel.""" provides read and write methods over that channel."""
def __init__(self, sock, keyfile=None, certfile=None, def __init__(self, sock=None, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None): ssl_version=PROTOCOL_SSLv23, ca_certs=None,
socket.__init__(self, _sock=sock._sock) do_handshake_on_connect=True,
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True):
self._base = None
if sock is not None:
# copied this code from socket.accept()
fd = sock.fileno()
nfd = fd
if _can_dup_socket:
nfd = os.dup(fd)
try:
wrapper = socket.__init__(self, family=sock.family, type=sock.type, proto=sock.proto, fileno=nfd)
except:
if nfd != fd:
os.close(nfd)
else:
if fd != nfd:
sock.close()
sock = None
elif fileno is not None:
socket.__init__(self, fileno=fileno)
else:
socket.__init__(self, family=family, type=type, proto=proto)
self._closed = False
if certfile and not keyfile: if certfile and not keyfile:
keyfile = certfile keyfile = certfile
# see if it's connected # see if it's connected
@ -100,27 +128,52 @@ def __init__(self, sock, keyfile=None, certfile=None,
self._sslobj = None self._sslobj = None
else: else:
# yes, create the SSL object # yes, create the SSL object
self._sslobj = _ssl.sslwrap(self._sock, server_side, try:
keyfile, certfile, self._sslobj = _ssl.sslwrap(self, server_side,
cert_reqs, ssl_version, ca_certs) keyfile, certfile,
cert_reqs, ssl_version, ca_certs)
if do_handshake_on_connect:
self.do_handshake()
except socket_error as x:
self.close()
raise x
self._base = sock
self.keyfile = keyfile self.keyfile = keyfile
self.certfile = certfile self.certfile = certfile
self.cert_reqs = cert_reqs self.cert_reqs = cert_reqs
self.ssl_version = ssl_version self.ssl_version = ssl_version
self.ca_certs = ca_certs self.ca_certs = ca_certs
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
def read(self, len=1024): def _checkClosed(self, msg=None):
# raise an exception here if you wish to check for spurious closes
pass
def read(self, len=1024, buffer=None):
"""Read up to LEN bytes and return them. """Read up to LEN bytes and return them.
Return zero-length string on EOF.""" Return zero-length string on EOF."""
return self._sslobj.read(len) self._checkClosed()
try:
if buffer:
return self._sslobj.read(buffer, len)
else:
return self._sslobj.read(len)
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
return b''
else:
raise
def write(self, data): def write(self, data):
"""Write DATA to the underlying SSL channel. Returns """Write DATA to the underlying SSL channel. Returns
number of bytes of DATA actually transmitted.""" number of bytes of DATA actually transmitted."""
self._checkClosed()
return self._sslobj.write(data) return self._sslobj.write(data)
def getpeercert(self, binary_form=False): def getpeercert(self, binary_form=False):
@ -130,26 +183,42 @@ def getpeercert(self, binary_form=False):
Return None if no certificate was provided, {} if a Return None if no certificate was provided, {} if a
certificate was provided, but not validated.""" certificate was provided, but not validated."""
self._checkClosed()
return self._sslobj.peer_certificate(binary_form) return self._sslobj.peer_certificate(binary_form)
def cipher (self): def cipher (self):
self._checkClosed()
if not self._sslobj: if not self._sslobj:
return None return None
else: else:
return self._sslobj.cipher() return self._sslobj.cipher()
def send (self, data, flags=0): def send (self, data, flags=0):
self._checkClosed()
if self._sslobj: if self._sslobj:
if flags != 0: if flags != 0:
raise ValueError( raise ValueError(
"non-zero flags not allowed in calls to send() on %s" % "non-zero flags not allowed in calls to send() on %s" %
self.__class__) self.__class__)
return self._sslobj.write(data) while True:
try:
v = self._sslobj.write(data)
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
return 0
elif x.args[0] == SSL_ERROR_WANT_WRITE:
return 0
else:
raise
else:
return v
else: else:
return socket.send(self, data, flags) return socket.send(self, data, flags)
def send_to (self, data, addr, flags=0): def send_to (self, data, addr, flags=0):
self._checkClosed()
if self._sslobj: if self._sslobj:
raise ValueError("send_to not allowed on instances of %s" % raise ValueError("send_to not allowed on instances of %s" %
self.__class__) self.__class__)
@ -157,39 +226,95 @@ def send_to (self, data, addr, flags=0):
return socket.send_to(self, data, addr, flags) return socket.send_to(self, data, addr, flags)
def sendall (self, data, flags=0): def sendall (self, data, flags=0):
self._checkClosed()
if self._sslobj: if self._sslobj:
if flags != 0: amount = len(data)
raise ValueError( count = 0
"non-zero flags not allowed in calls to sendall() on %s" % while (count < amount):
self.__class__) v = self.send(data[count:])
return self._sslobj.write(data) count += v
return amount
else: else:
return socket.sendall(self, data, flags) return socket.sendall(self, data, flags)
def recv (self, buflen=1024, flags=0): def recv (self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj: if self._sslobj:
if flags != 0: if flags != 0:
raise ValueError( raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" % "non-zero flags not allowed in calls to recv_into() on %s" %
self.__class__) self.__class__)
return self._sslobj.read(data, buflen) while True:
try:
return self.read(buflen)
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
continue
else:
raise x
else: else:
return socket.recv(self, buflen, flags) return socket.recv(self, buflen, flags)
def recv_into (self, buffer, nbytes=None, flags=0):
self._checkClosed()
if buffer and (nbytes is None):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s" %
self.__class__)
while True:
try:
v = self.read(nbytes, buffer)
sys.stdout.flush()
return v
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
continue
else:
raise x
else:
return socket.recv_into(self, buffer, nbytes, flags)
def recv_from (self, addr, buflen=1024, flags=0): def recv_from (self, addr, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj: if self._sslobj:
raise ValueError("recv_from not allowed on instances of %s" % raise ValueError("recv_from not allowed on instances of %s" %
self.__class__) self.__class__)
else: else:
return socket.recv_from(self, addr, buflen, flags) return socket.recv_from(self, addr, buflen, flags)
def shutdown(self, how): def pending (self):
self._checkClosed()
if self._sslobj:
return self._sslobj.pending()
else:
return 0
def shutdown (self, how):
self._checkClosed()
self._sslobj = None self._sslobj = None
socket.shutdown(self, how) socket.shutdown(self, how)
def close(self): def _real_close (self):
self._sslobj = None self._sslobj = None
socket.close(self) # self._closed = True
if self._base:
self._base.close()
socket._real_close(self)
def do_handshake (self):
"""Perform a TLS/SSL handshake."""
try:
self._sslobj.do_handshake()
except:
self._sslobj = None
raise
def connect(self, addr): def connect(self, addr):
@ -201,9 +326,11 @@ def connect(self, addr):
if self._sslobj: if self._sslobj:
raise ValueError("attempt to connect already-connected SSLSocket!") raise ValueError("attempt to connect already-connected SSLSocket!")
socket.connect(self, addr) socket.connect(self, addr)
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs) self.ca_certs)
if self.do_handshake_on_connect:
self.do_handshake()
def accept(self): def accept(self):
@ -212,260 +339,24 @@ def accept(self):
SSL channel, and the address of the remote client.""" SSL channel, and the address of the remote client."""
newsock, addr = socket.accept(self) newsock, addr = socket.accept(self)
return (SSLSocket(newsock, True, self.keyfile, self.certfile, return (SSLSocket(sock=newsock,
self.cert_reqs, self.ssl_version, keyfile=self.keyfile, certfile=self.certfile,
self.ca_certs), addr) server_side=True,
cert_reqs=self.cert_reqs, ssl_version=self.ssl_version,
ca_certs=self.ca_certs,
def makefile(self, mode='r', bufsize=-1): do_handshake_on_connect=self.do_handshake_on_connect),
addr)
"""Ouch. Need to make and return a file-like object that
works with the SSL connection."""
if self._sslobj:
return SSLFileStream(self._sslobj, mode, bufsize)
else:
return socket.makefile(self, mode, bufsize)
class SSLFileStream:
"""A class to simulate a file stream on top of a socket.
Most of this is just lifted from the socket module, and
adjusted to work with an SSL stream instead of a socket."""
default_bufsize = 8192
name = "<SSL stream>"
__slots__ = ["mode", "bufsize", "softspace",
# "closed" is a property, see below
"_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
"_close", "_fileno"]
def __init__(self, sslobj, mode='rb', bufsize=-1, close=False):
self._sslobj = sslobj
self.mode = mode # Not actually used in this version
if bufsize < 0:
bufsize = self.default_bufsize
self.bufsize = bufsize
self.softspace = False
if bufsize == 0:
self._rbufsize = 1
elif bufsize == 1:
self._rbufsize = self.default_bufsize
else:
self._rbufsize = bufsize
self._wbufsize = bufsize
self._rbuf = "" # A string
self._wbuf = [] # A list of strings
self._close = close
self._fileno = -1
def _getclosed(self):
return self._sslobj is None
closed = property(_getclosed, doc="True if the file is closed")
def fileno(self):
return self._fileno
def close(self):
try:
if self._sslobj:
self.flush()
finally:
if self._close and self._sslobj:
self._sslobj.close()
self._sslobj = None
def __del__(self):
try:
self.close()
except:
# close() may fail if __init__ didn't complete
pass
def flush(self):
if self._wbuf:
buffer = "".join(self._wbuf)
self._wbuf = []
count = 0
while (count < len(buffer)):
written = self._sslobj.write(buffer)
count += written
buffer = buffer[written:]
def write(self, data):
data = str(data) # XXX Should really reject non-string non-buffers
if not data:
return
self._wbuf.append(data)
if (self._wbufsize == 0 or
self._wbufsize == 1 and '\n' in data or
self._get_wbuf_len() >= self._wbufsize):
self.flush()
def writelines(self, list):
# XXX We could do better here for very long lists
# XXX Should really reject non-string non-buffers
self._wbuf.extend(filter(None, map(str, list)))
if (self._wbufsize <= 1 or
self._get_wbuf_len() >= self._wbufsize):
self.flush()
def _get_wbuf_len(self):
buf_len = 0
for x in self._wbuf:
buf_len += len(x)
return buf_len
def read(self, size=-1):
data = self._rbuf
if size < 0:
# Read until EOF
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
if self._rbufsize <= 1:
recv_size = self.default_bufsize
else:
recv_size = self._rbufsize
while True:
data = self._sslobj.read(recv_size)
if not data:
break
buffers.append(data)
return "".join(buffers)
else:
# Read until size bytes or EOF seen, whichever comes first
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
left = size - buf_len
recv_size = max(self._rbufsize, left)
data = self._sslobj.read(recv_size)
if not data:
break
buffers.append(data)
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return "".join(buffers)
def readline(self, size=-1):
data = self._rbuf
if size < 0:
# Read until \n or EOF, whichever comes first
if self._rbufsize <= 1:
# Speed up unbuffered case
assert data == ""
buffers = []
while data != "\n":
data = self._sslobj.read(1)
if not data:
break
buffers.append(data)
return "".join(buffers)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
data = self._sslobj.read(self._rbufsize)
if not data:
break
buffers.append(data)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
return "".join(buffers)
else:
# Read until size bytes or \n or EOF seen, whichever comes first
nl = data.find('\n', 0, size)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
data = self._sslobj.read(self._rbufsize)
if not data:
break
buffers.append(data)
left = size - buf_len
nl = data.find('\n', 0, left)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return "".join(buffers)
def readlines(self, sizehint=0):
total = 0
list = []
while True:
line = self.readline()
if not line:
break
list.append(line)
total += len(line)
if sizehint and total >= sizehint:
break
return list
# Iterator protocols
def __iter__(self):
return self
def next(self):
line = self.readline()
if not line:
raise StopIteration
return line
def wrap_socket(sock, keyfile=None, certfile=None, def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None): ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True):
return SSLSocket(sock, keyfile=keyfile, certfile=certfile, return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs, server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs) ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect)
# some utility functions # some utility functions
@ -486,16 +377,10 @@ def DER_cert_to_PEM_cert(der_cert_bytes):
"""Takes a certificate in binary DER format and returns the """Takes a certificate in binary DER format and returns the
PEM version of it as a string.""" PEM version of it as a string."""
if hasattr(base64, 'standard_b64encode'): f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict')
# preferred because older API gets line-length wrong return (PEM_HEADER + '\n' +
f = base64.standard_b64encode(der_cert_bytes) textwrap.fill(f, 64) + '\n' +
return (PEM_HEADER + '\n' + PEM_FOOTER + '\n')
textwrap.fill(f, 64) +
PEM_FOOTER + '\n')
else:
return (PEM_HEADER + '\n' +
base64.encodestring(der_cert_bytes) +
PEM_FOOTER + '\n')
def PEM_cert_to_DER_cert(pem_cert_string): def PEM_cert_to_DER_cert(pem_cert_string):
@ -509,7 +394,7 @@ def PEM_cert_to_DER_cert(pem_cert_string):
raise ValueError("Invalid PEM encoding; must end with %s" raise ValueError("Invalid PEM encoding; must end with %s"
% PEM_FOOTER) % PEM_FOOTER)
d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
return base64.decodestring(d) return base64.decodestring(d.encode('ASCII', 'strict'))
def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
@ -541,15 +426,3 @@ def get_protocol_name (protocol_code):
return "SSLv3" return "SSLv3"
else: else:
return "<unknown>" return "<unknown>"
# a replacement for the old socket.ssl function
def sslwrap_simple (sock, keyfile=None, certfile=None):
"""A replacement for the old socket.ssl function. Designed
for compability with Python 2.5 and earlier. Will disappear in
Python 3.0."""
return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE,
PROTOCOL_SSLv23, None)

View File

@ -4,6 +4,7 @@
import unittest import unittest
from test import test_support from test import test_support
import socket import socket
import select
import errno import errno
import subprocess import subprocess
import time import time
@ -36,27 +37,6 @@ def handle_error(prefix):
class BasicTests(unittest.TestCase): class BasicTests(unittest.TestCase):
def testSSLconnect(self):
import os
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE)
s.connect(("svn.python.org", 443))
c = s.getpeercert()
if c:
raise test_support.TestFailed("Peer cert %s shouldn't be here!")
s.close()
# this should fail because we have no verification certs
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED)
try:
s.connect(("svn.python.org", 443))
except ssl.SSLError:
pass
finally:
s.close()
def testCrucialConstants(self): def testCrucialConstants(self):
ssl.PROTOCOL_SSLv2 ssl.PROTOCOL_SSLv2
ssl.PROTOCOL_SSLv23 ssl.PROTOCOL_SSLv23
@ -97,11 +77,31 @@ def testDERtoPEM(self):
if (d1 != d2): if (d1 != d2):
raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed") raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed")
class NetworkedTests(unittest.TestCase):
class NetworkTests(unittest.TestCase): def testFetchServerCert(self):
pem = ssl.get_server_certificate(("svn.python.org", 443))
if not pem:
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
try:
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
except ssl.SSLError as x:
#should fail
if test_support.verbose:
sys.stdout.write("%s\n" % x)
else:
raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
if not pem:
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
if test_support.verbose:
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
def testConnect(self): def testConnect(self):
import os
s = ssl.wrap_socket(socket.socket(socket.AF_INET), s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE) cert_reqs=ssl.CERT_NONE)
s.connect(("svn.python.org", 443)) s.connect(("svn.python.org", 443))
@ -131,25 +131,29 @@ def testConnect(self):
finally: finally:
s.close() s.close()
def testFetchServerCert(self): def testNonBlockingHandshake(self):
s = socket.socket(socket.AF_INET)
pem = ssl.get_server_certificate(("svn.python.org", 443)) s.connect(("svn.python.org", 443))
if not pem: s.setblocking(False)
raise test_support.TestFailed("No server certificate on svn.python.org:443!") s = ssl.wrap_socket(s,
cert_reqs=ssl.CERT_NONE,
try: do_handshake_on_connect=False)
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) count = 0
except ssl.SSLError: while True:
#should fail try:
pass count += 1
else: s.do_handshake()
raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem) break
except ssl.SSLError as err:
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) if err.args[0] == ssl.SSL_ERROR_WANT_READ:
if not pem: select.select([s], [], [])
raise test_support.TestFailed("No server certificate on svn.python.org:443!") elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
select.select([], [s], [])
else:
raise
s.close()
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
try: try:
@ -168,10 +172,11 @@ class ConnectionHandler(threading.Thread):
with and without the SSL wrapper around the socket connection, so with and without the SSL wrapper around the socket connection, so
that we can test the STARTTLS functionality.""" that we can test the STARTTLS functionality."""
def __init__(self, server, connsock): def __init__(self, server, connsock, addr):
self.server = server self.server = server
self.running = False self.running = False
self.sock = connsock self.sock = connsock
self.addr = addr
self.sock.setblocking(1) self.sock.setblocking(1)
self.sslconn = None self.sslconn = None
threading.Thread.__init__(self) threading.Thread.__init__(self)
@ -186,8 +191,7 @@ def wrap_conn (self):
cert_reqs=self.server.certreqs) cert_reqs=self.server.certreqs)
except: except:
if self.server.chatty: if self.server.chatty:
handle_error("\n server: bad connection attempt from " + handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n")
str(self.sock.getpeername()) + ":\n")
if not self.server.expect_bad_connects: if not self.server.expect_bad_connects:
# here, we want to stop the server, because this shouldn't # here, we want to stop the server, because this shouldn't
# happen in the context of our test case # happen in the context of our test case
@ -195,6 +199,7 @@ def wrap_conn (self):
# normally, we'd just stop here, but for the test # normally, we'd just stop here, but for the test
# harness, we want to stop the server # harness, we want to stop the server
self.server.stop() self.server.stop()
self.close()
return False return False
else: else:
@ -236,19 +241,21 @@ def run (self):
while self.running: while self.running:
try: try:
msg = self.read() msg = self.read()
amsg = (msg and str(msg, 'ASCII', 'strict')) or ''
if not msg: if not msg:
# eof, so quit this handler # eof, so quit this handler
self.running = False self.running = False
self.close() self.close()
elif msg.strip() == 'over': elif amsg.strip() == 'over':
if test_support.verbose and self.server.connectionchatty: if test_support.verbose and self.server.connectionchatty:
sys.stdout.write(" server: client closed connection\n") sys.stdout.write(" server: client closed connection\n")
self.close() self.close()
return return
elif self.server.starttls_server and msg.strip() == 'STARTTLS': elif (self.server.starttls_server and
amsg.strip() == 'STARTTLS'):
if test_support.verbose and self.server.connectionchatty: if test_support.verbose and self.server.connectionchatty:
sys.stdout.write(" server: read STARTTLS from client, sending OK...\n") sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
self.write("OK\n") self.write("OK\n".encode("ASCII", "strict"))
if not self.wrap_conn(): if not self.wrap_conn():
return return
else: else:
@ -257,8 +264,8 @@ def run (self):
ctype = (self.sslconn and "encrypted") or "unencrypted" ctype = (self.sslconn and "encrypted") or "unencrypted"
sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n" sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n"
% (repr(msg), ctype, repr(msg.lower()), ctype)) % (repr(msg), ctype, repr(msg.lower()), ctype))
self.write(msg.lower()) self.write(amsg.lower().encode('ASCII', 'strict'))
except ssl.SSLError: except socket.error:
if self.server.chatty: if self.server.chatty:
handle_error("Test server failure:\n") handle_error("Test server failure:\n")
self.close() self.close()
@ -311,8 +318,8 @@ def run (self):
newconn, connaddr = self.sock.accept() newconn, connaddr = self.sock.accept()
if test_support.verbose and self.chatty: if test_support.verbose and self.chatty:
sys.stdout.write(' server: new connection from ' sys.stdout.write(' server: new connection from '
+ str(connaddr) + '\n') + repr(connaddr) + '\n')
handler = self.ConnectionHandler(self, newconn) handler = self.ConnectionHandler(self, newconn, connaddr)
handler.start() handler.start()
except socket.timeout: except socket.timeout:
pass pass
@ -321,11 +328,10 @@ def run (self):
except: except:
if self.chatty: if self.chatty:
handle_error("Test server failure:\n") handle_error("Test server failure:\n")
self.sock.close()
def stop (self): def stop (self):
self.active = False self.active = False
self.sock.close()
class AsyncoreHTTPSServer(threading.Thread): class AsyncoreHTTPSServer(threading.Thread):
@ -339,6 +345,12 @@ def __init__(self, server_address, RequestHandlerClass, certfile):
self.active = False self.active = False
self.allow_reuse_address = True self.allow_reuse_address = True
def __str__(self):
return ('<%s %s:%s>' %
(self.__class__.__name__,
self.server_name,
self.server_port))
def get_request (self): def get_request (self):
# override this to wrap socket with SSL # override this to wrap socket with SSL
sock, addr = self.socket.accept() sock, addr = self.socket.accept()
@ -415,8 +427,8 @@ def log_message(self, format, *args):
# we override this to suppress logging unless "verbose" # we override this to suppress logging unless "verbose"
if test_support.verbose: if test_support.verbose:
sys.stdout.write(" server (%s, %d, %s):\n [%s] %s\n" % sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
(self.server.server_name, (self.server.server_address,
self.server.server_port, self.server.server_port,
self.request.cipher(), self.request.cipher(),
self.log_date_time_string(), self.log_date_time_string(),
@ -433,9 +445,7 @@ def __init__(self, port, certfile):
self.setDaemon(True) self.setDaemon(True)
def __str__(self): def __str__(self):
return '<%s %s:%d>' % (self.__class__.__name__, return "<%s %s>" % (self.__class__.__name__, self.server)
self.server.server_name,
self.server.server_port)
def start (self, flag=None): def start (self, flag=None):
self.flag = flag self.flag = flag
@ -456,7 +466,8 @@ def stop (self):
def badCertTest (certfile): def badCertTest (certfile):
server = ThreadedEchoServer(TESTPORT, CERTFILE, server = ThreadedEchoServer(TESTPORT, CERTFILE,
certreqs=ssl.CERT_REQUIRED, certreqs=ssl.CERT_REQUIRED,
cacerts=CERTFILE, chatty=False) cacerts=CERTFILE, chatty=False,
connectionchatty=False)
flag = threading.Event() flag = threading.Event()
server.start(flag) server.start(flag)
# wait for it to start # wait for it to start
@ -470,7 +481,7 @@ def badCertTest (certfile):
s.connect(('127.0.0.1', TESTPORT)) s.connect(('127.0.0.1', TESTPORT))
except ssl.SSLError as x: except ssl.SSLError as x:
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\nSSLError is %s\n" % x[1]) sys.stdout.write("\nSSLError is %s\n" % x)
else: else:
raise test_support.TestFailed( raise test_support.TestFailed(
"Use of invalid cert should have failed!") "Use of invalid cert should have failed!")
@ -479,15 +490,16 @@ def badCertTest (certfile):
server.join() server.join()
def serverParamsTest (certfile, protocol, certreqs, cacertsfile, def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None, indata="FOO\n", client_certfile, client_protocol=None,
chatty=True, connectionchatty=False): indata="FOO\n",
chatty=False, connectionchatty=False):
server = ThreadedEchoServer(TESTPORT, certfile, server = ThreadedEchoServer(TESTPORT, certfile,
certreqs=certreqs, certreqs=certreqs,
ssl_version=protocol, ssl_version=protocol,
cacerts=cacertsfile, cacerts=cacertsfile,
chatty=chatty, chatty=chatty,
connectionchatty=connectionchatty) connectionchatty=False)
flag = threading.Event() flag = threading.Event()
server.start(flag) server.start(flag)
# wait for it to start # wait for it to start
@ -496,37 +508,37 @@ def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
if client_protocol is None: if client_protocol is None:
client_protocol = protocol client_protocol = protocol
try: try:
try: s = ssl.wrap_socket(socket.socket(),
s = ssl.wrap_socket(socket.socket(), certfile=client_certfile,
certfile=client_certfile, ca_certs=cacertsfile,
ca_certs=cacertsfile, cert_reqs=certreqs,
cert_reqs=certreqs, ssl_version=client_protocol)
ssl_version=client_protocol) s.connect(('127.0.0.1', TESTPORT))
s.connect(('127.0.0.1', TESTPORT)) except ssl.SSLError as x:
except ssl.SSLError as x: raise test_support.TestFailed("Unexpected SSL error: " + str(x))
raise test_support.TestFailed("Unexpected SSL error: " + str(x)) except Exception as x:
except Exception as x: raise test_support.TestFailed("Unexpected exception: " + str(x))
raise test_support.TestFailed("Unexpected exception: " + str(x)) else:
else: if connectionchatty:
if connectionchatty: if test_support.verbose:
if test_support.verbose: sys.stdout.write(
sys.stdout.write( " client: sending %s...\n" % (repr(indata)))
" client: sending %s...\n" % (repr(indata))) s.write(indata.encode('ASCII', 'strict'))
s.write(indata) outdata = s.read()
outdata = s.read() if connectionchatty:
if connectionchatty: if test_support.verbose:
if test_support.verbose: sys.stdout.write(" client: read %s\n" % repr(outdata))
sys.stdout.write(" client: read %s\n" % repr(outdata)) outdata = str(outdata, 'ASCII', 'strict')
if outdata != indata.lower(): if outdata != indata.lower():
raise test_support.TestFailed( raise test_support.TestFailed(
"bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata),20)], len(outdata), % (repr(outdata[:min(len(outdata),20)]), len(outdata),
indata[:min(len(indata),20)].lower(), len(indata))) repr(indata[:min(len(indata),20)].lower()), len(indata)))
s.write("over\n") s.write("over\n".encode("ASCII", "strict"))
if connectionchatty: if connectionchatty:
if test_support.verbose: if test_support.verbose:
sys.stdout.write(" client: closing connection.\n") sys.stdout.write(" client: closing connection.\n")
s.close() s.close()
finally: finally:
server.stop() server.stop()
server.join() server.join()
@ -553,7 +565,8 @@ def tryProtocolCombo (server_protocol,
certtype)) certtype))
try: try:
serverParamsTest(CERTFILE, server_protocol, certsreqs, serverParamsTest(CERTFILE, server_protocol, certsreqs,
CERTFILE, CERTFILE, client_protocol, chatty=False) CERTFILE, CERTFILE, client_protocol,
chatty=False, connectionchatty=False)
except test_support.TestFailed: except test_support.TestFailed:
if expectedToWork: if expectedToWork:
raise raise
@ -565,47 +578,7 @@ def tryProtocolCombo (server_protocol,
ssl.get_protocol_name(server_protocol))) ssl.get_protocol_name(server_protocol)))
class ConnectedTests(unittest.TestCase): class ThreadedTests(unittest.TestCase):
def testRudeShutdown(self):
listener_ready = threading.Event()
listener_gone = threading.Event()
# `listener` runs in a thread. It opens a socket listening on
# PORT, and sits in an accept() until the main thread connects.
# Then it rudely closes the socket, and sets Event `listener_gone`
# to let the main thread know the socket is gone.
def listener():
s = socket.socket()
if hasattr(socket, 'SO_REUSEADDR'):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
s.bind(('127.0.0.1', TESTPORT))
s.listen(5)
listener_ready.set()
s.accept()
s = None # reclaim the socket object, which also closes it
listener_gone.set()
def connector():
listener_ready.wait()
s = socket.socket()
s.connect(('127.0.0.1', TESTPORT))
listener_gone.wait()
try:
ssl_sock = ssl.wrap_socket(s)
except socket.sslerror:
pass
else:
raise test_support.TestFailed(
'connecting to closed SSL socket should have failed')
t = threading.Thread(target=listener)
t.start()
connector()
t.join()
def testEcho (self): def testEcho (self):
@ -656,7 +629,7 @@ def testReadCert(self):
if test_support.verbose: if test_support.verbose:
sys.stdout.write(pprint.pformat(cert) + '\n') sys.stdout.write(pprint.pformat(cert) + '\n')
sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
if not cert.has_key('subject'): if 'subject' not in cert:
raise test_support.TestFailed( raise test_support.TestFailed(
"No subject field in certificate: %s." % "No subject field in certificate: %s." %
pprint.pformat(cert)) pprint.pformat(cert))
@ -680,6 +653,46 @@ def testMalformedKey(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
"badkey.pem")) "badkey.pem"))
def testRudeShutdown(self):
listener_ready = threading.Event()
listener_gone = threading.Event()
# `listener` runs in a thread. It opens a socket listening on
# PORT, and sits in an accept() until the main thread connects.
# Then it rudely closes the socket, and sets Event `listener_gone`
# to let the main thread know the socket is gone.
def listener():
s = socket.socket()
if hasattr(socket, 'SO_REUSEADDR'):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
s.bind(('127.0.0.1', TESTPORT))
s.listen(5)
listener_ready.set()
s.accept()
s = None # reclaim the socket object, which also closes it
listener_gone.set()
def connector():
listener_ready.wait()
s = socket.socket()
s.connect(('127.0.0.1', TESTPORT))
listener_gone.wait()
try:
ssl_sock = ssl.wrap_socket(s)
except IOError:
pass
else:
raise test_support.TestFailed(
'connecting to closed SSL socket should have failed')
t = threading.Thread(target=listener)
t.start()
connector()
t.join()
def testProtocolSSL2(self): def testProtocolSSL2(self):
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
@ -759,39 +772,47 @@ def testSTARTTLS (self):
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
for indata in msgs: for indata in msgs:
msg = indata.encode('ASCII', 'replace')
if test_support.verbose: if test_support.verbose:
sys.stdout.write( sys.stdout.write(
" client: sending %s...\n" % repr(indata)) " client: sending %s...\n" % repr(msg))
if wrapped: if wrapped:
conn.write(indata) conn.write(msg)
outdata = conn.read() outdata = conn.read()
else: else:
s.send(indata) s.send(msg)
outdata = s.recv(1024) outdata = s.recv(1024)
if (indata == "STARTTLS" and if (indata == "STARTTLS" and
outdata.strip().lower().startswith("ok")): str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")):
if test_support.verbose: if test_support.verbose:
msg = str(outdata, 'ASCII', 'replace')
sys.stdout.write( sys.stdout.write(
" client: read %s from server, starting TLS...\n" " client: read %s from server, starting TLS...\n"
% repr(outdata)) % repr(msg))
conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
wrapped = True wrapped = True
else: else:
if test_support.verbose: if test_support.verbose:
msg = str(outdata, 'ASCII', 'replace')
sys.stdout.write( sys.stdout.write(
" client: read %s from server\n" % repr(outdata)) " client: read %s from server\n" % repr(msg))
if test_support.verbose: if test_support.verbose:
sys.stdout.write(" client: closing connection.\n") sys.stdout.write(" client: closing connection.\n")
if wrapped: if wrapped:
conn.write("over\n") conn.write("over\n".encode("ASCII", "strict"))
else: else:
s.send("over\n") s.send("over\n")
if wrapped:
conn.close()
else:
s.close() s.close()
finally: finally:
server.stop() server.stop()
server.join() server.join()
class AsyncoreTests(unittest.TestCase):
def testAsyncore(self): def testAsyncore(self):
server = AsyncoreHTTPSServer(TESTPORT, CERTFILE) server = AsyncoreHTTPSServer(TESTPORT, CERTFILE)
@ -824,6 +845,8 @@ def testAsyncore(self):
raise test_support.TestFailed(msg) raise test_support.TestFailed(msg)
else: else:
if not (d1 == d2): if not (d1 == d2):
print("d1 is", len(d1), repr(d1))
print("d2 is", len(d2), repr(d2))
raise test_support.TestFailed( raise test_support.TestFailed(
"Couldn't fetch data from HTTPS server") "Couldn't fetch data from HTTPS server")
finally: finally:
@ -863,6 +886,7 @@ def test_main(verbose=False):
if (not os.path.exists(CERTFILE) or if (not os.path.exists(CERTFILE) or
not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)): not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)):
raise test_support.TestFailed("Can't read certificate files!") raise test_support.TestFailed("Can't read certificate files!")
TESTPORT = findtestsocket(10025, 12000) TESTPORT = findtestsocket(10025, 12000)
if not TESTPORT: if not TESTPORT:
raise test_support.TestFailed("Can't find open port to test servers on!") raise test_support.TestFailed("Can't find open port to test servers on!")
@ -870,12 +894,13 @@ def test_main(verbose=False):
tests = [BasicTests] tests = [BasicTests]
if test_support.is_resource_enabled('network'): if test_support.is_resource_enabled('network'):
tests.append(NetworkTests) tests.append(NetworkedTests)
if _have_threads: if _have_threads:
thread_info = test_support.threading_setup() thread_info = test_support.threading_setup()
if thread_info and test_support.is_resource_enabled('network'): if thread_info and test_support.is_resource_enabled('network'):
tests.append(ConnectedTests) tests.append(ThreadedTests)
tests.append(AsyncoreTests)
test_support.run_unittest(*tests) test_support.run_unittest(*tests)

View File

@ -2,14 +2,15 @@
SSL support based on patches by Brian E Gallew and Laszlo Kovacs. SSL support based on patches by Brian E Gallew and Laszlo Kovacs.
Re-worked a bit by Bill Janssen to add server-side support and Re-worked a bit by Bill Janssen to add server-side support and
certificate decoding. certificate decoding. Chris Stawarz contributed some non-blocking
patches.
This module is imported by ssl.py. It should *not* be used This module is imported by ssl.py. It should *not* be used
directly. directly.
XXX should partial writes be enabled, SSL_MODE_ENABLE_PARTIAL_WRITE? XXX should partial writes be enabled, SSL_MODE_ENABLE_PARTIAL_WRITE?
XXX what about SSL_MODE_AUTO_RETRY XXX what about SSL_MODE_AUTO_RETRY?
*/ */
#include "Python.h" #include "Python.h"
@ -17,7 +18,7 @@
#ifdef WITH_THREAD #ifdef WITH_THREAD
#include "pythread.h" #include "pythread.h"
#define PySSL_BEGIN_ALLOW_THREADS { \ #define PySSL_BEGIN_ALLOW_THREADS { \
PyThreadState *_save; \ PyThreadState *_save = NULL; \
if (_ssl_locks_count>0) {_save = PyEval_SaveThread();} if (_ssl_locks_count>0) {_save = PyEval_SaveThread();}
#define PySSL_BLOCK_THREADS if (_ssl_locks_count>0){PyEval_RestoreThread(_save)}; #define PySSL_BLOCK_THREADS if (_ssl_locks_count>0){PyEval_RestoreThread(_save)};
#define PySSL_UNBLOCK_THREADS if (_ssl_locks_count>0){_save = PyEval_SaveThread()}; #define PySSL_UNBLOCK_THREADS if (_ssl_locks_count>0){_save = PyEval_SaveThread()};
@ -114,8 +115,6 @@ typedef struct {
SSL_CTX* ctx; SSL_CTX* ctx;
SSL* ssl; SSL* ssl;
X509* peer_cert; X509* peer_cert;
char server[X509_NAME_MAXLEN];
char issuer[X509_NAME_MAXLEN];
} PySSLObject; } PySSLObject;
@ -265,15 +264,11 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
PySSLObject *self; PySSLObject *self;
char *errstr = NULL; char *errstr = NULL;
int ret; int ret;
int err;
int sockstate;
int verification_mode; int verification_mode;
self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */ self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */
if (self == NULL) if (self == NULL)
return NULL; return NULL;
memset(self->server, '\0', sizeof(char) * X509_NAME_MAXLEN);
memset(self->issuer, '\0', sizeof(char) * X509_NAME_MAXLEN);
self->peer_cert = NULL; self->peer_cert = NULL;
self->ssl = NULL; self->ssl = NULL;
self->ctx = NULL; self->ctx = NULL;
@ -388,57 +383,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
SSL_set_accept_state(self->ssl); SSL_set_accept_state(self->ssl);
PySSL_END_ALLOW_THREADS PySSL_END_ALLOW_THREADS
/* Actually negotiate SSL connection */
/* XXX If SSL_connect() returns 0, it's also a failure. */
sockstate = 0;
do {
PySSL_BEGIN_ALLOW_THREADS
if (socket_type == PY_SSL_CLIENT)
ret = SSL_connect(self->ssl);
else
ret = SSL_accept(self->ssl);
err = SSL_get_error(self->ssl, ret);
PySSL_END_ALLOW_THREADS
if(PyErr_CheckSignals()) {
goto fail;
}
if (err == SSL_ERROR_WANT_READ) {
sockstate = check_socket_and_wait_for_timeout(Sock, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate = check_socket_and_wait_for_timeout(Sock, 1);
} else {
sockstate = SOCKET_OPERATION_OK;
}
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("The connect operation timed out"));
goto fail;
} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("Underlying socket has been closed."));
goto fail;
} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("Underlying socket too large for select()."));
goto fail;
} else if (sockstate == SOCKET_IS_NONBLOCKING) {
break;
}
} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
if (ret < 1) {
PySSL_SetError(self, ret, __FILE__, __LINE__);
goto fail;
}
self->ssl->debug = 1;
PySSL_BEGIN_ALLOW_THREADS
if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) {
X509_NAME_oneline(X509_get_subject_name(self->peer_cert),
self->server, X509_NAME_MAXLEN);
X509_NAME_oneline(X509_get_issuer_name(self->peer_cert),
self->issuer, X509_NAME_MAXLEN);
}
PySSL_END_ALLOW_THREADS
self->Socket = Sock; self->Socket = Sock;
Py_INCREF(self->Socket); Py_INCREF(self->Socket);
return self; return self;
@ -488,16 +432,58 @@ PyDoc_STRVAR(ssl_doc,
/* SSL object methods */ /* SSL object methods */
static PyObject * static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
PySSL_server(PySSLObject *self)
{ {
return PyUnicode_FromString(self->server); int ret;
} int err;
int sockstate;
static PyObject * /* Actually negotiate SSL connection */
PySSL_issuer(PySSLObject *self) /* XXX If SSL_do_handshake() returns 0, it's also a failure. */
{ sockstate = 0;
return PyUnicode_FromString(self->issuer); do {
PySSL_BEGIN_ALLOW_THREADS
ret = SSL_do_handshake(self->ssl);
err = SSL_get_error(self->ssl, ret);
PySSL_END_ALLOW_THREADS
if(PyErr_CheckSignals()) {
return NULL;
}
if (err == SSL_ERROR_WANT_READ) {
sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
} else {
sockstate = SOCKET_OPERATION_OK;
}
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("The handshake operation timed out"));
return NULL;
} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("Underlying socket has been closed."));
return NULL;
} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("Underlying socket too large for select()."));
return NULL;
} else if (sockstate == SOCKET_IS_NONBLOCKING) {
break;
}
} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
if (ret < 1)
return PySSL_SetError(self, ret, __FILE__, __LINE__);
self->ssl->debug = 1;
if (self->peer_cert)
X509_free (self->peer_cert);
PySSL_BEGIN_ALLOW_THREADS
self->peer_cert = SSL_get_peer_certificate(self->ssl);
PySSL_END_ALLOW_THREADS
Py_INCREF(Py_None);
return Py_None;
} }
static PyObject * static PyObject *
@ -515,7 +501,7 @@ _create_tuple_for_attribute (ASN1_OBJECT *name, ASN1_STRING *value) {
_setSSLError(NULL, 0, __FILE__, __LINE__); _setSSLError(NULL, 0, __FILE__, __LINE__);
goto fail; goto fail;
} }
name_obj = PyString_FromStringAndSize(namebuf, buflen); name_obj = PyUnicode_FromStringAndSize(namebuf, buflen);
if (name_obj == NULL) if (name_obj == NULL)
goto fail; goto fail;
@ -681,21 +667,24 @@ _get_peer_alt_names (X509 *certificate) {
/* now decode the altName */ /* now decode the altName */
ext = X509_get_ext(certificate, i); ext = X509_get_ext(certificate, i);
if(!(method = X509V3_EXT_get(ext))) { if(!(method = X509V3_EXT_get(ext))) {
PyErr_SetString(PySSLErrorObject, PyErr_SetString
ERRSTR("No method for internalizing subjectAltName!")); (PySSLErrorObject,
ERRSTR("No method for internalizing subjectAltName!"));
goto fail; goto fail;
} }
p = ext->value->data; p = ext->value->data;
if(method->it) if(method->it)
names = (GENERAL_NAMES*) (ASN1_item_d2i(NULL, names = (GENERAL_NAMES*)
&p, (ASN1_item_d2i(NULL,
ext->value->length, &p,
ASN1_ITEM_ptr(method->it))); ext->value->length,
ASN1_ITEM_ptr(method->it)));
else else
names = (GENERAL_NAMES*) (method->d2i(NULL, names = (GENERAL_NAMES*)
&p, (method->d2i(NULL,
ext->value->length)); &p,
ext->value->length));
for(j = 0; j < sk_GENERAL_NAME_num(names); j++) { for(j = 0; j < sk_GENERAL_NAME_num(names); j++) {
@ -704,14 +693,15 @@ _get_peer_alt_names (X509 *certificate) {
name = sk_GENERAL_NAME_value(names, j); name = sk_GENERAL_NAME_value(names, j);
if (name->type == GEN_DIRNAME) { if (name->type == GEN_DIRNAME) {
/* we special-case DirName as a tuple of tuples of attributes */ /* we special-case DirName as a tuple of
tuples of attributes */
t = PyTuple_New(2); t = PyTuple_New(2);
if (t == NULL) { if (t == NULL) {
goto fail; goto fail;
} }
v = PyString_FromString("DirName"); v = PyUnicode_FromString("DirName");
if (v == NULL) { if (v == NULL) {
Py_DECREF(t); Py_DECREF(t);
goto fail; goto fail;
@ -742,13 +732,14 @@ _get_peer_alt_names (X509 *certificate) {
t = PyTuple_New(2); t = PyTuple_New(2);
if (t == NULL) if (t == NULL)
goto fail; goto fail;
v = PyString_FromStringAndSize(buf, (vptr - buf)); v = PyUnicode_FromStringAndSize(buf, (vptr - buf));
if (v == NULL) { if (v == NULL) {
Py_DECREF(t); Py_DECREF(t);
goto fail; goto fail;
} }
PyTuple_SET_ITEM(t, 0, v); PyTuple_SET_ITEM(t, 0, v);
v = PyString_FromStringAndSize((vptr + 1), (len - (vptr - buf + 1))); v = PyUnicode_FromStringAndSize((vptr + 1),
(len - (vptr - buf + 1)));
if (v == NULL) { if (v == NULL) {
Py_DECREF(t); Py_DECREF(t);
goto fail; goto fail;
@ -849,7 +840,7 @@ _decode_certificate (X509 *certificate, int verbose) {
_setSSLError(NULL, 0, __FILE__, __LINE__); _setSSLError(NULL, 0, __FILE__, __LINE__);
goto fail1; goto fail1;
} }
sn_obj = PyString_FromStringAndSize(buf, len); sn_obj = PyUnicode_FromStringAndSize(buf, len);
if (sn_obj == NULL) if (sn_obj == NULL)
goto fail1; goto fail1;
if (PyDict_SetItemString(retval, "serialNumber", sn_obj) < 0) { if (PyDict_SetItemString(retval, "serialNumber", sn_obj) < 0) {
@ -866,7 +857,7 @@ _decode_certificate (X509 *certificate, int verbose) {
_setSSLError(NULL, 0, __FILE__, __LINE__); _setSSLError(NULL, 0, __FILE__, __LINE__);
goto fail1; goto fail1;
} }
pnotBefore = PyString_FromStringAndSize(buf, len); pnotBefore = PyUnicode_FromStringAndSize(buf, len);
if (pnotBefore == NULL) if (pnotBefore == NULL)
goto fail1; goto fail1;
if (PyDict_SetItemString(retval, "notBefore", pnotBefore) < 0) { if (PyDict_SetItemString(retval, "notBefore", pnotBefore) < 0) {
@ -884,7 +875,7 @@ _decode_certificate (X509 *certificate, int verbose) {
_setSSLError(NULL, 0, __FILE__, __LINE__); _setSSLError(NULL, 0, __FILE__, __LINE__);
goto fail1; goto fail1;
} }
pnotAfter = PyString_FromStringAndSize(buf, len); pnotAfter = PyUnicode_FromStringAndSize(buf, len);
if (pnotAfter == NULL) if (pnotAfter == NULL)
goto fail1; goto fail1;
if (PyDict_SetItemString(retval, "notAfter", pnotAfter) < 0) { if (PyDict_SetItemString(retval, "notAfter", pnotAfter) < 0) {
@ -928,22 +919,26 @@ PySSL_test_decode_certificate (PyObject *mod, PyObject *args) {
BIO *cert; BIO *cert;
int verbose = 1; int verbose = 1;
if (!PyArg_ParseTuple(args, "s|i:test_decode_certificate", &filename, &verbose)) if (!PyArg_ParseTuple(args, "s|i:test_decode_certificate",
&filename, &verbose))
return NULL; return NULL;
if ((cert=BIO_new(BIO_s_file())) == NULL) { if ((cert=BIO_new(BIO_s_file())) == NULL) {
PyErr_SetString(PySSLErrorObject, "Can't malloc memory to read file"); PyErr_SetString(PySSLErrorObject,
"Can't malloc memory to read file");
goto fail0; goto fail0;
} }
if (BIO_read_filename(cert,filename) <= 0) { if (BIO_read_filename(cert,filename) <= 0) {
PyErr_SetString(PySSLErrorObject, "Can't open file"); PyErr_SetString(PySSLErrorObject,
"Can't open file");
goto fail0; goto fail0;
} }
x = PEM_read_bio_X509_AUX(cert,NULL, NULL, NULL); x = PEM_read_bio_X509_AUX(cert,NULL, NULL, NULL);
if (x == NULL) { if (x == NULL) {
PyErr_SetString(PySSLErrorObject, "Error decoding PEM-encoded file"); PyErr_SetString(PySSLErrorObject,
"Error decoding PEM-encoded file");
goto fail0; goto fail0;
} }
@ -981,7 +976,9 @@ PySSL_peercert(PySSLObject *self, PyObject *args)
PySSL_SetError(self, len, __FILE__, __LINE__); PySSL_SetError(self, len, __FILE__, __LINE__);
return NULL; return NULL;
} }
retval = PyString_FromStringAndSize((const char *) bytes_buf, len); /* this is actually an immutable bytes sequence */
retval = PyBytes_FromStringAndSize
((const char *) bytes_buf, len);
OPENSSL_free(bytes_buf); OPENSSL_free(bytes_buf);
return retval; return retval;
@ -1028,7 +1025,7 @@ static PyObject *PySSL_cipher (PySSLObject *self) {
if (cipher_name == NULL) { if (cipher_name == NULL) {
PyTuple_SET_ITEM(retval, 0, Py_None); PyTuple_SET_ITEM(retval, 0, Py_None);
} else { } else {
v = PyString_FromString(cipher_name); v = PyUnicode_FromString(cipher_name);
if (v == NULL) if (v == NULL)
goto fail0; goto fail0;
PyTuple_SET_ITEM(retval, 0, v); PyTuple_SET_ITEM(retval, 0, v);
@ -1037,7 +1034,7 @@ static PyObject *PySSL_cipher (PySSLObject *self) {
if (cipher_protocol == NULL) { if (cipher_protocol == NULL) {
PyTuple_SET_ITEM(retval, 1, Py_None); PyTuple_SET_ITEM(retval, 1, Py_None);
} else { } else {
v = PyString_FromString(cipher_protocol); v = PyUnicode_FromString(cipher_protocol);
if (v == NULL) if (v == NULL)
goto fail0; goto fail0;
PyTuple_SET_ITEM(retval, 1, v); PyTuple_SET_ITEM(retval, 1, v);
@ -1127,7 +1124,9 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing)
rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv); rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv);
PySSL_END_ALLOW_THREADS PySSL_END_ALLOW_THREADS
#ifdef HAVE_POLL
normal_return: normal_return:
#endif
/* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise /* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise
(when we are able to write or when there's something to read) */ (when we are able to write or when there's something to read) */
return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK; return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK;
@ -1140,10 +1139,16 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
int count; int count;
int sockstate; int sockstate;
int err; int err;
int nonblocking;
if (!PyArg_ParseTuple(args, "s#:write", &data, &count)) if (!PyArg_ParseTuple(args, "y#:write", &data, &count))
return NULL; return NULL;
/* just in case the blocking state of the socket has been changed */
nonblocking = (self->Socket->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
if (sockstate == SOCKET_HAS_TIMED_OUT) { if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject, PyErr_SetString(PySSLErrorObject,
@ -1200,19 +1205,58 @@ PyDoc_STRVAR(PySSL_SSLwrite_doc,
Writes the string s into the SSL object. Returns the number\n\ Writes the string s into the SSL object. Returns the number\n\
of bytes written."); of bytes written.");
static PyObject *PySSL_SSLpending(PySSLObject *self)
{
int count = 0;
PySSL_BEGIN_ALLOW_THREADS
count = SSL_pending(self->ssl);
PySSL_END_ALLOW_THREADS
if (count < 0)
return PySSL_SetError(self, count, __FILE__, __LINE__);
else
return PyInt_FromLong(count);
}
PyDoc_STRVAR(PySSL_SSLpending_doc,
"pending() -> count\n\
\n\
Returns the number of already decrypted bytes available for read,\n\
pending on the connection.\n");
static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
{ {
PyObject *buf; PyObject *buf = NULL;
int count = 0; int buf_passed = 0;
int count = -1;
int len = 1024; int len = 1024;
int sockstate; int sockstate;
int err; int err;
int nonblocking;
if (!PyArg_ParseTuple(args, "|i:read", &len)) if (!PyArg_ParseTuple(args, "|Oi:read", &buf, &count))
return NULL; return NULL;
if (!(buf = PyBytes_FromStringAndSize((char *) 0, len))) if ((buf == NULL) || (buf == Py_None)) {
return NULL; if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
return NULL;
} else if (PyInt_Check(buf)) {
len = PyInt_AS_LONG(buf);
if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
return NULL;
} else {
if (!PyBytes_Check(buf))
return NULL;
len = PyBytes_Size(buf);
if ((count > 0) && (count <= len))
len = count;
buf_passed = 1;
}
/* just in case the blocking state of the socket has been changed */
nonblocking = (self->Socket->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
/* first check if there are bytes ready to be read */ /* first check if there are bytes ready to be read */
PySSL_BEGIN_ALLOW_THREADS PySSL_BEGIN_ALLOW_THREADS
@ -1224,27 +1268,38 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
if (sockstate == SOCKET_HAS_TIMED_OUT) { if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject, PyErr_SetString(PySSLErrorObject,
"The read operation timed out"); "The read operation timed out");
Py_DECREF(buf); if (!buf_passed) {
Py_DECREF(buf);
}
return NULL; return NULL;
} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { } else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
PyErr_SetString(PySSLErrorObject, PyErr_SetString(PySSLErrorObject,
"Underlying socket too large for select()."); "Underlying socket too large for select().");
if (!buf_passed) {
Py_DECREF(buf);
}
Py_DECREF(buf); Py_DECREF(buf);
return NULL; return NULL;
} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
/* should contain a zero-length string */ /* should contain a zero-length string */
_PyString_Resize(&buf, 0); if (!buf_passed) {
return buf; PyBytes_Resize(buf, 0);
return buf;
} else {
return PyInt_FromLong(0);
}
} }
} }
do { do {
err = 0; err = 0;
PySSL_BEGIN_ALLOW_THREADS PySSL_BEGIN_ALLOW_THREADS
count = SSL_read(self->ssl, PyBytes_AS_STRING(buf), len); count = SSL_read(self->ssl, PyBytes_AsString(buf), len);
err = SSL_get_error(self->ssl, count); err = SSL_get_error(self->ssl, count);
PySSL_END_ALLOW_THREADS PySSL_END_ALLOW_THREADS
if(PyErr_CheckSignals()) { if(PyErr_CheckSignals()) {
Py_DECREF(buf); if (!buf_passed) {
Py_DECREF(buf);
}
return NULL; return NULL;
} }
if (err == SSL_ERROR_WANT_READ) { if (err == SSL_ERROR_WANT_READ) {
@ -1257,44 +1312,55 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
(SSL_get_shutdown(self->ssl) == (SSL_get_shutdown(self->ssl) ==
SSL_RECEIVED_SHUTDOWN)) SSL_RECEIVED_SHUTDOWN))
{ {
_PyString_Resize(&buf, 0); if (!buf_passed) {
return buf; PyBytes_Resize(buf, 0);
return buf;
} else {
return PyInt_FromLong(0);
}
} else { } else {
sockstate = SOCKET_OPERATION_OK; sockstate = SOCKET_OPERATION_OK;
} }
if (sockstate == SOCKET_HAS_TIMED_OUT) { if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject, PyErr_SetString(PySSLErrorObject,
"The read operation timed out"); "The read operation timed out");
Py_DECREF(buf); if (!buf_passed) {
Py_DECREF(buf);
}
return NULL; return NULL;
} else if (sockstate == SOCKET_IS_NONBLOCKING) { } else if (sockstate == SOCKET_IS_NONBLOCKING) {
break; break;
} }
} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
if (count <= 0) { if (count <= 0) {
Py_DECREF(buf); if (!buf_passed) {
Py_DECREF(buf);
}
return PySSL_SetError(self, count, __FILE__, __LINE__); return PySSL_SetError(self, count, __FILE__, __LINE__);
} }
if (count != len) if (!buf_passed) {
if (PyBytes_Resize(buf, count) < 0) { if (count != len) {
Py_DECREF(buf); PyBytes_Resize(buf, count);
return NULL; }
} return buf;
return buf; } else {
return PyInt_FromLong(count);
}
} }
PyDoc_STRVAR(PySSL_SSLread_doc, PyDoc_STRVAR(PySSL_SSLread_doc,
"read([len]) -> bytes\n\ "read([len]) -> string\n\
\n\ \n\
Read up to len bytes from the SSL socket."); Read up to len bytes from the SSL socket.");
static PyMethodDef PySSLMethods[] = { static PyMethodDef PySSLMethods[] = {
{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS, {"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
PySSL_SSLwrite_doc}, PySSL_SSLwrite_doc},
{"read", (PyCFunction)PySSL_SSLread, METH_VARARGS, {"read", (PyCFunction)PySSL_SSLread, METH_VARARGS,
PySSL_SSLread_doc}, PySSL_SSLread_doc},
{"server", (PyCFunction)PySSL_server, METH_NOARGS}, {"pending", (PyCFunction)PySSL_SSLpending, METH_NOARGS,
{"issuer", (PyCFunction)PySSL_issuer, METH_NOARGS}, PySSL_SSLpending_doc},
{"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS, {"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS,
PySSL_peercert_doc}, PySSL_peercert_doc},
{"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS}, {"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS},
@ -1350,26 +1416,26 @@ bound on the entropy contained in string. See RFC 1750.");
static PyObject * static PyObject *
PySSL_RAND_status(PyObject *self) PySSL_RAND_status(PyObject *self)
{ {
return PyBool_FromLong(RAND_status()); return PyInt_FromLong(RAND_status());
} }
PyDoc_STRVAR(PySSL_RAND_status_doc, PyDoc_STRVAR(PySSL_RAND_status_doc,
"RAND_status() -> 0 or 1\n\ "RAND_status() -> 0 or 1\n\
\n\ \n\
Returns True if the OpenSSL PRNG has been seeded with enough data and\n\ Returns 1 if the OpenSSL PRNG has been seeded with enough data and 0 if not.\n\
False if not. It is necessary to seed the PRNG with RAND_add()\n\ It is necessary to seed the PRNG with RAND_add() on some platforms before\n\
on some platforms before using the ssl() function."); using the ssl() function.");
static PyObject * static PyObject *
PySSL_RAND_egd(PyObject *self, PyObject *arg) PySSL_RAND_egd(PyObject *self, PyObject *arg)
{ {
int bytes; int bytes;
if (!PyString_Check(arg)) if (!PyUnicode_Check(arg))
return PyErr_Format(PyExc_TypeError, return PyErr_Format(PyExc_TypeError,
"RAND_egd() expected string, found %s", "RAND_egd() expected string, found %s",
Py_Type(arg)->tp_name); Py_Type(arg)->tp_name);
bytes = RAND_egd(PyString_AS_STRING(arg)); bytes = RAND_egd(PyUnicode_AsString(arg));
if (bytes == -1) { if (bytes == -1) {
PyErr_SetString(PySSLErrorObject, PyErr_SetString(PySSLErrorObject,
"EGD connection failed or EGD did not return " "EGD connection failed or EGD did not return "
@ -1418,16 +1484,17 @@ static unsigned long _ssl_thread_id_function (void) {
return PyThread_get_thread_ident(); return PyThread_get_thread_ident();
} }
static void _ssl_thread_locking_function (int mode, int n, const char *file, int line) { static void _ssl_thread_locking_function
(int mode, int n, const char *file, int line) {
/* this function is needed to perform locking on shared data /* this function is needed to perform locking on shared data
structures. (Note that OpenSSL uses a number of global data structures. (Note that OpenSSL uses a number of global data
structures that will be implicitly shared whenever multiple threads structures that will be implicitly shared whenever multiple
use OpenSSL.) Multi-threaded applications will crash at random if threads use OpenSSL.) Multi-threaded applications will
it is not set. crash at random if it is not set.
locking_function() must be able to handle up to CRYPTO_num_locks() locking_function() must be able to handle up to
different mutex locks. It sets the n-th lock if mode & CRYPTO_LOCK, and CRYPTO_num_locks() different mutex locks. It sets the n-th
releases it otherwise. lock if mode & CRYPTO_LOCK, and releases it otherwise.
file and line are the file number of the function setting the file and line are the file number of the function setting the
lock. They can be useful for debugging. lock. They can be useful for debugging.
@ -1454,7 +1521,8 @@ static int _setup_ssl_threads(void) {
malloc(sizeof(PyThread_type_lock) * _ssl_locks_count); malloc(sizeof(PyThread_type_lock) * _ssl_locks_count);
if (_ssl_locks == NULL) if (_ssl_locks == NULL)
return 0; return 0;
memset(_ssl_locks, 0, sizeof(PyThread_type_lock) * _ssl_locks_count); memset(_ssl_locks, 0,
sizeof(PyThread_type_lock) * _ssl_locks_count);
for (i = 0; i < _ssl_locks_count; i++) { for (i = 0; i < _ssl_locks_count; i++) {
_ssl_locks[i] = PyThread_allocate_lock(); _ssl_locks[i] = PyThread_allocate_lock();
if (_ssl_locks[i] == NULL) { if (_ssl_locks[i] == NULL) {