This is roughly socket2.diff from issue 1378, with a few changes applied

to ssl.py (no need to test whether we can dup any more).
Regular sockets no longer have a _base, but we still have explicit
reference counting of socket objects for the benefit of makefile();
using duplicate sockets won't work for SSLSocket.
This commit is contained in:
Guido van Rossum 2007-11-16 01:24:05 +00:00
parent dd9e3b8736
commit 39eb8fa0db
5 changed files with 126 additions and 118 deletions

View File

@ -26,6 +26,15 @@ PyAPI_FUNC(size_t) PyLong_AsSize_t(PyObject *);
PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLong(PyObject *); PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLong(PyObject *);
PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLongMask(PyObject *); PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLongMask(PyObject *);
/* Used by socketmodule.c */
#if SIZEOF_SOCKET_T <= SIZEOF_LONG
#define PyLong_FromSocket_t(fd) PyLong_FromLong((SOCKET_T)(fd))
#define PyLong_AsSocket_t(fd) (SOCKET_T)PyLong_AsLong(fd)
#else
#define PyLong_FromSocket_t(fd) PyLong_FromLongLong(((SOCKET_T)(fd));
#define PyLong_AsSocket_t(fd) (SOCKET_T)PyLong_AsLongLong(fd)
#endif
/* For use by intobject.c only */ /* For use by intobject.c only */
PyAPI_DATA(int) _PyLong_DigitValue[256]; PyAPI_DATA(int) _PyLong_DigitValue[256];

View File

@ -79,27 +79,13 @@
__all__.append("errorTab") __all__.append("errorTab")
# True if os.dup() can duplicate socket descriptors.
# (On Windows at least, os.dup only works on files)
_can_dup_socket = hasattr(_socket.socket, "dup")
if _can_dup_socket:
def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
nfd = os.dup(fd)
return socket(family, type, proto, fileno=nfd)
class socket(_socket.socket): class socket(_socket.socket):
"""A subclass of _socket.socket adding the makefile() method.""" """A subclass of _socket.socket adding the makefile() method."""
__slots__ = ["__weakref__", "_io_refs", "_closed"] __slots__ = ["__weakref__", "_io_refs", "_closed"]
if not _can_dup_socket:
__slots__.append("_base")
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
if fileno is None:
_socket.socket.__init__(self, family, type, proto)
else:
_socket.socket.__init__(self, family, type, proto, fileno) _socket.socket.__init__(self, family, type, proto, fileno)
self._io_refs = 0 self._io_refs = 0
self._closed = False self._closed = False
@ -114,23 +100,29 @@ def __repr__(self):
s[7:]) s[7:])
return s return s
def dup(self):
"""dup() -> socket object
Return a new socket object connected to the same system resource.
"""
fd = dup(self.fileno())
sock = self.__class__(self.family, self.type, self.proto, fileno=fd)
sock.settimeout(self.gettimeout())
return sock
def accept(self): def accept(self):
"""Wrap accept() to give the connection the right type.""" """accept() -> (socket object, address info)
conn, addr = _socket.socket.accept(self)
fd = conn.fileno() Wait for an incoming connection. Return a new socket
nfd = fd representing the connection, and the address of the client.
if _can_dup_socket: For IP sockets, the address info is a pair (hostaddr, port).
nfd = os.dup(fd) """
wrapper = socket(self.family, self.type, self.proto, fileno=nfd) fd, addr = self._accept()
if fd == nfd: return socket(self.family, self.type, self.proto, fileno=fd), addr
wrapper._base = conn # Keep the base alive
else:
conn.close()
return wrapper, addr
def makefile(self, mode="r", buffering=None, *, def makefile(self, mode="r", buffering=None, *,
encoding=None, newline=None): encoding=None, newline=None):
"""Return an I/O stream connected to the socket. """makefile(...) -> an I/O stream connected to the socket
The arguments are as for io.open() after the filename, The arguments are as for io.open() after the filename,
except the only mode characters supported are 'r', 'w' and 'b'. except the only mode characters supported are 'r', 'w' and 'b'.
@ -184,23 +176,20 @@ def _decref_socketios(self):
def close(self): def close(self):
self._closed = True self._closed = True
if self._io_refs < 1: if self._io_refs <= 0:
self._real_close()
# _real_close calls close on the _socket.socket base class.
if not _can_dup_socket:
def _real_close(self):
_socket.socket.close(self)
base = getattr(self, "_base", None)
if base is not None:
self._base = None
base.close()
else:
def _real_close(self):
_socket.socket.close(self) _socket.socket.close(self)
def fromfd(fd, family, type, proto=0):
""" fromfd(fd, family, type[, proto]) -> socket object
Create a socket object from a duplicate of the given file
descriptor. The remaining arguments are the same as for socket().
"""
nfd = dup(fd)
return socket(family, type, proto, nfd)
class SocketIO(io.RawIOBase): class SocketIO(io.RawIOBase):
"""Raw I/O implementation for stream sockets. """Raw I/O implementation for stream sockets.

View File

@ -78,8 +78,8 @@
from socket import socket, AF_INET, SOCK_STREAM, error from socket import socket, AF_INET, SOCK_STREAM, error
from socket import getnameinfo as _getnameinfo from socket import getnameinfo as _getnameinfo
from socket import error as socket_error from socket import error as socket_error
from socket import dup as _dup
import base64 # for DER-to-PEM translation import base64 # for DER-to-PEM translation
_can_dup_socket = hasattr(socket, "dup")
class SSLSocket(socket): class SSLSocket(socket):
@ -99,20 +99,11 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
if sock is not None: if sock is not None:
# copied this code from socket.accept() # copied this code from socket.accept()
fd = sock.fileno() fd = sock.fileno()
nfd = fd nfd = _dup(fd)
if _can_dup_socket:
nfd = os.dup(fd)
try:
socket.__init__(self, family=sock.family, type=sock.type, socket.__init__(self, family=sock.family, type=sock.type,
proto=sock.proto, fileno=nfd) proto=sock.proto, fileno=nfd)
except:
if nfd != fd:
os.close(nfd)
else:
if fd != nfd:
sock.close() sock.close()
sock = None sock = None
elif fileno is not None: elif fileno is not None:
socket.__init__(self, fileno=fileno) socket.__init__(self, fileno=fileno)
else: else:

View File

@ -575,6 +575,15 @@ def testFromFd(self):
def _testFromFd(self): def _testFromFd(self):
self.serv_conn.send(MSG) self.serv_conn.send(MSG)
def testDup(self):
# Testing dup()
sock = self.cli_conn.dup()
msg = sock.recv(1024)
self.assertEqual(msg, MSG)
def _testDup(self):
self.serv_conn.send(MSG)
def testShutdown(self): def testShutdown(self):
# Testing shutdown() # Testing shutdown()
msg = self.cli_conn.recv(1024) msg = self.cli_conn.recv(1024)

View File

@ -89,12 +89,12 @@ A socket object represents one endpoint of a network connection.\n\
\n\ \n\
Methods of socket objects (keyword arguments not allowed):\n\ Methods of socket objects (keyword arguments not allowed):\n\
\n\ \n\
accept() -- accept a connection, returning new socket and client address\n\ _accept() -- accept connection, returning new socket fd and client address\n\
bind(addr) -- bind the socket to a local address\n\ bind(addr) -- bind the socket to a local address\n\
close() -- close the socket\n\ close() -- close the socket\n\
connect(addr) -- connect the socket to a remote address\n\ connect(addr) -- connect the socket to a remote address\n\
connect_ex(addr) -- connect, return an error code instead of an exception\n\ connect_ex(addr) -- connect, return an error code instead of an exception\n\
dup() -- return a new socket object identical to the current one [*]\n\ _dup() -- return a new socket fd duplicated from fileno()\n\
fileno() -- return underlying file descriptor\n\ fileno() -- return underlying file descriptor\n\
getpeername() -- return remote address [*]\n\ getpeername() -- return remote address [*]\n\
getsockname() -- return local address\n\ getsockname() -- return local address\n\
@ -327,10 +327,26 @@ const char *inet_ntop(int af, const void *src, char *dst, socklen_t size);
#include "getnameinfo.c" #include "getnameinfo.c"
#endif #endif
#if defined(MS_WINDOWS) #ifdef MS_WINDOWS
/* seem to be a few differences in the API */ /* On Windows a socket is really a handle not an fd */
static SOCKET
dup_socket(SOCKET handle)
{
HANDLE newhandle;
if (!DuplicateHandle(GetCurrentProcess(), (HANDLE)handle,
GetCurrentProcess(), &newhandle,
0, FALSE, DUPLICATE_SAME_ACCESS))
{
WSASetLastError(GetLastError());
return INVALID_SOCKET;
}
return (SOCKET)newhandle;
}
#define SOCKETCLOSE closesocket #define SOCKETCLOSE closesocket
#define NO_DUP /* Actually it exists on NT 3.5, but what the heck... */ #else
/* On Unix we can use dup to duplicate the file descriptor of a socket*/
#define dup_socket(fd) dup(fd)
#endif #endif
#ifdef MS_WIN32 #ifdef MS_WIN32
@ -1423,7 +1439,7 @@ getsockaddrlen(PySocketSockObject *s, socklen_t *len_ret)
} }
/* s.accept() method */ /* s._accept() -> (fd, address) */
static PyObject * static PyObject *
sock_accept(PySocketSockObject *s) sock_accept(PySocketSockObject *s)
@ -1457,17 +1473,12 @@ sock_accept(PySocketSockObject *s)
if (newfd == INVALID_SOCKET) if (newfd == INVALID_SOCKET)
return s->errorhandler(); return s->errorhandler();
/* Create the new object with unspecified family, sock = PyLong_FromSocket_t(newfd);
to avoid calls to bind() etc. on it. */
sock = (PyObject *) new_sockobject(newfd,
s->sock_family,
s->sock_type,
s->sock_proto);
if (sock == NULL) { if (sock == NULL) {
SOCKETCLOSE(newfd); SOCKETCLOSE(newfd);
goto finally; goto finally;
} }
addr = makesockaddr(s->sock_fd, SAS2SA(&addrbuf), addr = makesockaddr(s->sock_fd, SAS2SA(&addrbuf),
addrlen, s->sock_proto); addrlen, s->sock_proto);
if (addr == NULL) if (addr == NULL)
@ -1482,11 +1493,11 @@ sock_accept(PySocketSockObject *s)
} }
PyDoc_STRVAR(accept_doc, PyDoc_STRVAR(accept_doc,
"accept() -> (socket object, address info)\n\ "_accept() -> (integer, address info)\n\
\n\ \n\
Wait for an incoming connection. Return a new socket representing the\n\ Wait for an incoming connection. Return a new socket file descriptor\n\
connection, and the address of the client. For IP sockets, the address\n\ representing the connection, and the address of the client.\n\
info is a pair (hostaddr, port)."); For IP sockets, the address info is a pair (hostaddr, port).");
/* s.setblocking(flag) method. Argument: /* s.setblocking(flag) method. Argument:
False -- non-blocking mode; same as settimeout(0) False -- non-blocking mode; same as settimeout(0)
@ -1882,11 +1893,7 @@ instead of raising an exception when an error occurs.");
static PyObject * static PyObject *
sock_fileno(PySocketSockObject *s) sock_fileno(PySocketSockObject *s)
{ {
#if SIZEOF_SOCKET_T <= SIZEOF_LONG return PyLong_FromSocket_t(s->sock_fd);
return PyInt_FromLong((long) s->sock_fd);
#else
return PyLong_FromLongLong((PY_LONG_LONG)s->sock_fd);
#endif
} }
PyDoc_STRVAR(fileno_doc, PyDoc_STRVAR(fileno_doc,
@ -1895,35 +1902,6 @@ PyDoc_STRVAR(fileno_doc,
Return the integer file descriptor of the socket."); Return the integer file descriptor of the socket.");
#ifndef NO_DUP
/* s.dup() method */
static PyObject *
sock_dup(PySocketSockObject *s)
{
SOCKET_T newfd;
PyObject *sock;
newfd = dup(s->sock_fd);
if (newfd < 0)
return s->errorhandler();
sock = (PyObject *) new_sockobject(newfd,
s->sock_family,
s->sock_type,
s->sock_proto);
if (sock == NULL)
SOCKETCLOSE(newfd);
return sock;
}
PyDoc_STRVAR(dup_doc,
"dup() -> socket object\n\
\n\
Return a new socket object connected to the same system resource.");
#endif
/* s.getsockname() method */ /* s.getsockname() method */
static PyObject * static PyObject *
@ -2542,7 +2520,7 @@ of the socket (flag == SHUT_WR), or both ends (flag == SHUT_RDWR).");
/* List of methods for socket objects */ /* List of methods for socket objects */
static PyMethodDef sock_methods[] = { static PyMethodDef sock_methods[] = {
{"accept", (PyCFunction)sock_accept, METH_NOARGS, {"_accept", (PyCFunction)sock_accept, METH_NOARGS,
accept_doc}, accept_doc},
{"bind", (PyCFunction)sock_bind, METH_O, {"bind", (PyCFunction)sock_bind, METH_O,
bind_doc}, bind_doc},
@ -2552,10 +2530,6 @@ static PyMethodDef sock_methods[] = {
connect_doc}, connect_doc},
{"connect_ex", (PyCFunction)sock_connect_ex, METH_O, {"connect_ex", (PyCFunction)sock_connect_ex, METH_O,
connect_ex_doc}, connect_ex_doc},
#ifndef NO_DUP
{"dup", (PyCFunction)sock_dup, METH_NOARGS,
dup_doc},
#endif
{"fileno", (PyCFunction)sock_fileno, METH_NOARGS, {"fileno", (PyCFunction)sock_fileno, METH_NOARGS,
fileno_doc}, fileno_doc},
#ifdef HAVE_GETPEERNAME #ifdef HAVE_GETPEERNAME
@ -2672,8 +2646,8 @@ sock_initobj(PyObject *self, PyObject *args, PyObject *kwds)
&family, &type, &proto, &fdobj)) &family, &type, &proto, &fdobj))
return -1; return -1;
if (fdobj != NULL) { if (fdobj != NULL && fdobj != Py_None) {
fd = PyLong_AsLongLong(fdobj); fd = PyLong_AsSocket_t(fdobj);
if (fd == (SOCKET_T)(-1) && PyErr_Occurred()) if (fd == (SOCKET_T)(-1) && PyErr_Occurred())
return -1; return -1;
if (fd == INVALID_SOCKET) { if (fd == INVALID_SOCKET) {
@ -3172,6 +3146,38 @@ PyDoc_STRVAR(getprotobyname_doc,
Return the protocol number for the named protocol. (Rarely used.)"); Return the protocol number for the named protocol. (Rarely used.)");
#ifndef NO_DUP
/* dup() function for socket fds */
static PyObject *
socket_dup(PyObject *self, PyObject *fdobj)
{
SOCKET_T fd, newfd;
PyObject *newfdobj;
fd = PyLong_AsSocket_t(fdobj);
if (fd == (SOCKET_T)(-1) && PyErr_Occurred())
return NULL;
newfd = dup_socket(fd);
if (newfd == INVALID_SOCKET)
return set_error();
newfdobj = PyLong_FromSocket_t(newfd);
if (newfdobj == NULL)
SOCKETCLOSE(newfd);
return newfdobj;
}
PyDoc_STRVAR(dup_doc,
"dup(integer) -> integer\n\
\n\
Duplicate an integer socket file descriptor. This is like os.dup(), but for\n\
sockets; on some platforms os.dup() won't work for socket file descriptors.");
#endif
#ifdef HAVE_SOCKETPAIR #ifdef HAVE_SOCKETPAIR
/* Create a pair of sockets using the socketpair() function. /* Create a pair of sockets using the socketpair() function.
Arguments as for socket() except the default family is AF_UNIX if Arguments as for socket() except the default family is AF_UNIX if
@ -3811,6 +3817,10 @@ static PyMethodDef socket_methods[] = {
METH_VARARGS, getservbyport_doc}, METH_VARARGS, getservbyport_doc},
{"getprotobyname", socket_getprotobyname, {"getprotobyname", socket_getprotobyname,
METH_VARARGS, getprotobyname_doc}, METH_VARARGS, getprotobyname_doc},
#ifndef NO_DUP
{"dup", socket_dup,
METH_O, dup_doc},
#endif
#ifdef HAVE_SOCKETPAIR #ifdef HAVE_SOCKETPAIR
{"socketpair", socket_socketpair, {"socketpair", socket_socketpair,
METH_VARARGS, socketpair_doc}, METH_VARARGS, socketpair_doc},