Check for chunks overflowing piece bounds on request read

Test for integer overflow in when checking read requests are within the bounds of the associated piece. Another fix is required to limit the amount of memory that can be allocated for such requests.
This commit is contained in:
Matt Joiner 2023-02-13 23:27:15 +11:00
parent 60fd7581e7
commit abb5cbc96e
No known key found for this signature in database
GPG Key ID: 6B990B8185E7F782
3 changed files with 40 additions and 4 deletions

View File

@ -3,11 +3,18 @@ package peer_protocol
import ( import (
"encoding/binary" "encoding/binary"
"io" "io"
"math"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type Integer uint32 type (
// An alias for the underlying type of Integer. This is needed for fuzzing.
IntegerKind = uint32
Integer IntegerKind
)
const IntegerMax = math.MaxUint32
func (i *Integer) UnmarshalBinary(b []byte) error { func (i *Integer) UnmarshalBinary(b []byte) error {
if len(b) != 4 { if len(b) != 4 {

View File

@ -1003,6 +1003,18 @@ func (c *PeerConn) maximumPeerRequestChunkLength() (_ Option[int]) {
return Some(uploadRateLimiter.Burst()) return Some(uploadRateLimiter.Burst())
} }
// Returns whether any part of the chunk would lie outside a piece of the given length.
func chunkOverflowsPiece(cs ChunkSpec, pieceLength pp.Integer) bool {
switch {
default:
return false
case cs.Begin+cs.Length > pieceLength:
// Check for integer overflow
case cs.Begin > pp.IntegerMax-cs.Length:
}
return true
}
// startFetch is for testing purposes currently. // startFetch is for testing purposes currently.
func (c *PeerConn) onReadRequest(r Request, startFetch bool) error { func (c *PeerConn) onReadRequest(r Request, startFetch bool) error {
requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1) requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1)
@ -1045,10 +1057,11 @@ func (c *PeerConn) onReadRequest(r Request, startFetch bool) error {
requestsReceivedForMissingPieces.Add(1) requestsReceivedForMissingPieces.Add(1)
return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int()) return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int())
} }
pieceLength := c.t.pieceLength(pieceIndex(r.Index))
// Check this after we know we have the piece, so that the piece length will be known. // Check this after we know we have the piece, so that the piece length will be known.
if r.Begin+r.Length > c.t.pieceLength(pieceIndex(r.Index)) { if chunkOverflowsPiece(r.ChunkSpec, pieceLength) {
torrent.Add("bad requests received", 1) torrent.Add("bad requests received", 1)
return errors.New("bad Request") return errors.New("chunk overflows piece")
} }
if c.peerRequests == nil { if c.peerRequests == nil {
c.peerRequests = make(map[Request]*peerRequestState, localClientReqq) c.peerRequests = make(map[Request]*peerRequestState, localClientReqq)
@ -1255,6 +1268,9 @@ func (c *PeerConn) mainReadLoop() (err error) {
case pp.Request: case pp.Request:
r := newRequestFromMessage(&msg) r := newRequestFromMessage(&msg)
err = c.onReadRequest(r, true) err = c.onReadRequest(r, true)
if err != nil {
err = fmt.Errorf("on reading request %v: %w", r, err)
}
case pp.Piece: case pp.Piece:
c.doChunkReadStats(int64(len(msg.Piece))) c.doChunkReadStats(int64(len(msg.Piece)))
err = c.receiveChunk(&msg) err = c.receiveChunk(&msg)

View File

@ -4,12 +4,13 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/time/rate"
"io" "io"
"net" "net"
"sync" "sync"
"testing" "testing"
"golang.org/x/time/rate"
"github.com/frankban/quicktest" "github.com/frankban/quicktest"
qt "github.com/frankban/quicktest" qt "github.com/frankban/quicktest"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -317,3 +318,15 @@ func TestReceiveLargeRequest(t *testing.T) {
c.Check(pc.onReadRequest(req, false), qt.IsNil) c.Check(pc.onReadRequest(req, false), qt.IsNil)
c.Check(pc.messageWriter.writeBuffer.Len(), qt.Equals, 17) c.Check(pc.messageWriter.writeBuffer.Len(), qt.Equals, 17)
} }
func TestChunkOverflowsPiece(t *testing.T) {
c := qt.New(t)
check := func(begin, length, limit pp.Integer, expected bool) {
c.Check(chunkOverflowsPiece(ChunkSpec{begin, length}, limit), qt.Equals, expected)
}
check(2, 3, 1, true)
check(2, pp.IntegerMax, 1, true)
check(2, pp.IntegerMax, 3, true)
check(2, pp.IntegerMax, pp.IntegerMax, true)
check(2, pp.IntegerMax-2, pp.IntegerMax, false)
}