Split connection.rw into separate Reader and Writer

This will make it easier to write hooks for Read and Write separately.
This commit is contained in:
Matt Joiner 2016-10-10 16:30:51 +11:00
parent 099fb9546e
commit c34234bf34
3 changed files with 44 additions and 33 deletions

View File

@ -809,18 +809,16 @@ func (r deadlineReader) Read(b []byte) (n int, err error) {
return
}
type readWriter struct {
io.Reader
io.Writer
}
func maybeReceiveEncryptedHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, encrypted bool, err error) {
var protocol [len(pp.Protocol)]byte
_, err = io.ReadFull(rw, protocol[:])
if err != nil {
return
}
ret = readWriter{
ret = struct {
io.Reader
io.Writer
}{
io.MultiReader(bytes.NewReader(protocol[:]), rw),
rw,
}
@ -841,7 +839,12 @@ func (cl *Client) receiveSkeys() (ret [][]byte) {
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
if c.encrypted {
c.rw, err = mse.InitiateHandshake(c.rw, t.infoHash[:], nil)
var rw io.ReadWriter
rw, err = mse.InitiateHandshake(struct {
io.Reader
io.Writer
}{c.r, c.w}, t.infoHash[:], nil)
c.setRW(rw)
if err != nil {
return
}
@ -859,7 +862,9 @@ func (cl *Client) receiveHandshakes(c *connection) (t *Torrent, err error) {
skeys := cl.receiveSkeys()
cl.mu.Unlock()
if !cl.config.DisableEncryption {
c.rw, c.encrypted, err = maybeReceiveEncryptedHandshake(c.rw, skeys)
var rw io.ReadWriter
rw, c.encrypted, err = maybeReceiveEncryptedHandshake(c.rw(), skeys)
c.setRW(rw)
if err != nil {
if err == mse.ErrNoSecretKeyMatch {
err = nil
@ -887,7 +892,7 @@ func (cl *Client) receiveHandshakes(c *connection) (t *Torrent, err error) {
// Returns !ok if handshake failed for valid reasons.
func (cl *Client) connBTHandshake(c *connection, ih *metainfo.Hash) (ret metainfo.Hash, ok bool, err error) {
res, ok, err := handshake(c.rw, ih, cl.peerID, cl.extensionBytes)
res, ok, err := handshake(c.rw(), ih, cl.peerID, cl.extensionBytes)
if err != nil || !ok {
return
}
@ -937,10 +942,7 @@ func (cl *Client) runReceivedConn(c *connection) {
func (cl *Client) runHandshookConn(c *connection, t *Torrent) {
c.conn.SetWriteDeadline(time.Time{})
c.rw = readWriter{
deadlineReader{c.conn, c.rw},
c.rw,
}
c.r = deadlineReader{c.conn, c.r}
completedHandshakeConnectionFlags.Add(c.connectionFlags(), 1)
if !t.addConnection(c) {
return

View File

@ -38,9 +38,14 @@ const (
// Maintains the state of a connection with a peer.
type connection struct {
t *Torrent
conn net.Conn
rw io.ReadWriter // The real slim shady
t *Torrent
// The actual Conn, used for closing, and setting socket options.
conn net.Conn
// The Reader and Writer for this Conn, with hooks installed for stats,
// limiting, deadlines etc.
w io.Writer
r io.Reader
// True if the connection is operating over MSE obfuscation.
encrypted bool
Discovery peerSource
uTP bool
@ -109,7 +114,7 @@ func newConnection(nc net.Conn, l sync.Locker) (c *connection) {
PeerChoked: true,
PeerMaxRequests: 250,
}
c.rw = connStatsReadWriter{nc, l, c}
c.setRW(connStatsReadWriter{nc, l, c})
return
}
@ -407,7 +412,7 @@ func (cn *connection) writer(keepAliveTimeout time.Duration) {
cn.Close()
}()
// Reduce write syscalls.
buf := bufio.NewWriter(cn.rw)
buf := bufio.NewWriter(cn.w)
keepAliveTimer := time.NewTimer(keepAliveTimeout)
for {
cn.mu().Lock()
@ -700,7 +705,7 @@ func (c *connection) mainReadLoop() error {
cl := t.cl
decoder := pp.Decoder{
R: bufio.NewReader(c.rw),
R: bufio.NewReader(c.r),
MaxLength: 256 * 1024,
Pool: t.chunkPool,
}
@ -907,3 +912,17 @@ func (c *connection) mainReadLoop() error {
}
}
}
// Set both the Reader and Writer for the connection from a single ReadWriter.
func (cn *connection) setRW(rw io.ReadWriter) {
cn.r = rw
cn.w = rw
}
// Returns the Reader and Writer as a combined ReadWriter.
func (cn *connection) rw() io.ReadWriter {
return struct {
io.Reader
io.Writer
}{cn.r, cn.w}
}

View File

@ -29,12 +29,7 @@ func TestCancelRequestOptimized(t *testing.T) {
bm.Set(1, true)
return bm
}(),
rw: struct {
io.Reader
io.Writer
}{
Writer: w,
},
w: w,
conn: new(net.TCPConn),
// For the locks
t: &Torrent{cl: &Client{}},
@ -74,10 +69,8 @@ func TestSendBitfieldThenHave(t *testing.T) {
t: &Torrent{
cl: &Client{},
},
rw: struct {
io.Reader
io.Writer
}{r, w},
r: r,
w: w,
outgoingUnbufferedMessages: list.New(),
}
go c.writer(time.Minute)
@ -153,10 +146,7 @@ func BenchmarkConnectionMainReadLoop(b *testing.B) {
r, w := io.Pipe()
cn := &connection{
t: t,
rw: struct {
io.Reader
io.Writer
}{r, nil},
r: r,
}
mrlErr := make(chan error)
cl.mu.Lock()