Check that incoming peer request chunk lengths don't exceed the upload rate limiter burst size

Should fix #759.
This commit is contained in:
Matt Joiner 2022-06-25 23:16:58 +10:00
parent ae4eb8569b
commit 12279621e4
No known key found for this signature in database
GPG Key ID: 6B990B8185E7F782
4 changed files with 75 additions and 15 deletions

View File

@ -973,7 +973,7 @@ func (cl *Client) runHandshookConn(c *PeerConn, t *Torrent) error {
return fmt.Errorf("adding connection: %w", err) return fmt.Errorf("adding connection: %w", err)
} }
defer t.dropConnection(c) defer t.dropConnection(c)
c.startWriter() c.startMessageWriter()
cl.sendInitialMessages(c, t) cl.sendInitialMessages(c, t)
c.initUpdateRequestsTimer() c.initUpdateRequestsTimer()
err := c.mainReadLoop() err := c.mainReadLoop()

View File

@ -12,7 +12,7 @@ import (
pp "github.com/anacrolix/torrent/peer_protocol" pp "github.com/anacrolix/torrent/peer_protocol"
) )
func (pc *PeerConn) startWriter() { func (pc *PeerConn) initMessageWriter() {
w := &pc.messageWriter w := &pc.messageWriter
*w = peerConnMsgWriter{ *w = peerConnMsgWriter{
fillWriteBuffer: func() { fillWriteBuffer: func() {
@ -33,12 +33,18 @@ func (pc *PeerConn) startWriter() {
}, },
writeBuffer: new(bytes.Buffer), writeBuffer: new(bytes.Buffer),
} }
go func() { }
func (pc *PeerConn) startMessageWriter() {
pc.initMessageWriter()
go pc.messageWriterRunner()
}
func (pc *PeerConn) messageWriterRunner() {
defer pc.locker().Unlock() defer pc.locker().Unlock()
defer pc.close() defer pc.close()
defer pc.locker().Lock() defer pc.locker().Lock()
pc.messageWriter.run(pc.t.cl.config.KeepAliveTimeout) pc.messageWriter.run(pc.t.cl.config.KeepAliveTimeout)
}()
} }
type peerConnMsgWriter struct { type peerConnMsgWriter struct {

View File

@ -5,6 +5,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/time/rate"
"io" "io"
"math/rand" "math/rand"
"net" "net"
@ -986,10 +987,22 @@ func (c *PeerConn) reject(r Request) {
delete(c.peerRequests, r) delete(c.peerRequests, r)
} }
func (c *PeerConn) onReadRequest(r Request) error { func (c *PeerConn) maximumPeerRequestChunkLength() (_ Option[int]) {
uploadRateLimiter := c.t.cl.config.UploadRateLimiter
if uploadRateLimiter.Limit() == rate.Inf {
return
}
return Some(uploadRateLimiter.Burst())
}
// startFetch is for testing purposes currently.
func (c *PeerConn) onReadRequest(r Request, startFetch bool) error {
requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1) requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1)
if _, ok := c.peerRequests[r]; ok { if _, ok := c.peerRequests[r]; ok {
torrent.Add("duplicate requests received", 1) torrent.Add("duplicate requests received", 1)
if c.fastEnabled() {
return errors.New("received duplicate request with fast enabled")
}
return nil return nil
} }
if c.choking { if c.choking {
@ -1009,10 +1022,18 @@ func (c *PeerConn) onReadRequest(r Request) error {
// BEP 6 says we may close here if we choose. // BEP 6 says we may close here if we choose.
return nil return nil
} }
if opt := c.maximumPeerRequestChunkLength(); opt.Ok && int(r.Length) > opt.Value {
err := fmt.Errorf("peer requested chunk too long (%v)", r.Length)
c.logger.Levelf(log.Warning, err.Error())
if c.fastEnabled() {
c.reject(r)
return nil
} else {
return err
}
}
if !c.t.havePiece(pieceIndex(r.Index)) { if !c.t.havePiece(pieceIndex(r.Index)) {
// This isn't necessarily them screwing up. We can drop pieces // TODO: Tell the peer we don't have the piece, and reject this request.
// from our storage, and can't communicate this to peers
// except by reconnecting.
requestsReceivedForMissingPieces.Add(1) requestsReceivedForMissingPieces.Add(1)
return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int()) return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int())
} }
@ -1026,7 +1047,10 @@ func (c *PeerConn) onReadRequest(r Request) error {
} }
value := &peerRequestState{} value := &peerRequestState{}
c.peerRequests[r] = value c.peerRequests[r] = value
if startFetch {
// TODO: Limit peer request data read concurrency.
go c.peerRequestDataReader(r, value) go c.peerRequestDataReader(r, value)
}
return nil return nil
} }
@ -1222,7 +1246,7 @@ func (c *PeerConn) mainReadLoop() (err error) {
err = c.peerSentBitfield(msg.Bitfield) err = c.peerSentBitfield(msg.Bitfield)
case pp.Request: case pp.Request:
r := newRequestFromMessage(&msg) r := newRequestFromMessage(&msg)
err = c.onReadRequest(r) err = c.onReadRequest(r, true)
case pp.Piece: case pp.Piece:
c.doChunkReadStats(int64(len(msg.Piece))) c.doChunkReadStats(int64(len(msg.Piece)))
err = c.receiveChunk(&msg) err = c.receiveChunk(&msg)

View File

@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/time/rate"
"io" "io"
"net" "net"
"sync" "sync"
@ -32,7 +33,7 @@ func TestSendBitfieldThenHave(t *testing.T) {
r, w := io.Pipe() r, w := io.Pipe()
// c.r = r // c.r = r
c.w = w c.w = w
c.startWriter() c.startMessageWriter()
c.locker().Lock() c.locker().Lock()
c.t._completedPieces.Add(1) c.t._completedPieces.Add(1)
c.postBitfield( /*[]bool{false, true, false}*/ ) c.postBitfield( /*[]bool{false, true, false}*/ )
@ -282,3 +283,32 @@ func TestPreferredNetworkDirection(t *testing.T) {
// No difference // No difference
c.Assert(pc(1, 2, false, false, false).hasPreferredNetworkOver(pc(1, 2, false, false, false)), qt.IsFalse) c.Assert(pc(1, 2, false, false, false).hasPreferredNetworkOver(pc(1, 2, false, false, false)), qt.IsFalse)
} }
func TestReceiveLargeRequest(t *testing.T) {
c := qt.New(t)
cl := newTestingClient(t)
pc := cl.newConnection(nil, false, nil, "test", "")
tor := cl.newTorrentForTesting()
tor.info = &metainfo.Info{PieceLength: 3 << 20}
pc.setTorrent(tor)
tor._completedPieces.Add(0)
pc.PeerExtensionBytes.SetBit(pp.ExtensionBitFast, true)
pc.choking = false
pc.initMessageWriter()
req := Request{}
req.Length = defaultChunkSize
c.Assert(pc.fastEnabled(), qt.IsTrue)
c.Check(pc.onReadRequest(req, false), qt.IsNil)
c.Check(pc.peerRequests, qt.HasLen, 1)
req.Length = 2 << 20
c.Check(pc.onReadRequest(req, false), qt.IsNil)
c.Check(pc.peerRequests, qt.HasLen, 2)
pc.peerRequests = nil
pc.t.cl.config.UploadRateLimiter = rate.NewLimiter(1, defaultChunkSize)
req.Length = defaultChunkSize
c.Check(pc.onReadRequest(req, false), qt.IsNil)
c.Check(pc.peerRequests, qt.HasLen, 1)
req.Length = 2 << 20
c.Check(pc.onReadRequest(req, false), qt.IsNil)
c.Check(pc.messageWriter.writeBuffer.Len(), qt.Equals, 17)
}