From 91730696cfe5dd9f088e21a4f3b3a1e915348758 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sat, 14 Jul 2018 11:50:43 +1000 Subject: [PATCH] Rewrite piece data decoding and relax test --- peer_protocol/decoder.go | 28 +++++++++++++--------------- peer_protocol/decoder_test.go | 2 +- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/peer_protocol/decoder.go b/peer_protocol/decoder.go index 47a8b65e..b7ccab14 100644 --- a/peer_protocol/decoder.go +++ b/peer_protocol/decoder.go @@ -3,11 +3,12 @@ package peer_protocol import ( "bufio" "encoding/binary" - "errors" "fmt" "io" "io/ioutil" "sync" + + "github.com/pkg/errors" ) type Decoder struct { @@ -69,24 +70,21 @@ func (d *Decoder) Decode(msg *Message) (err error) { msg.Bitfield = unmarshalBitfield(b) case Piece: for _, pi := range []*Integer{&msg.Index, &msg.Begin} { - err = pi.Read(r) + err := pi.Read(r) if err != nil { - break - } - } - if err != nil { - break - } - //msg.Piece, err = ioutil.ReadAll(r) - b := *d.Pool.Get().(*[]byte) - n, err := io.ReadFull(r, b) - if err != nil { - if err != io.ErrUnexpectedEOF || n != int(length-9) { return err } - b = b[0:n] } - msg.Piece = b + dataLen := r.N + msg.Piece = (*d.Pool.Get().(*[]byte)) + if int64(cap(msg.Piece)) < dataLen { + return errors.New("piece data longer than expected") + } + msg.Piece = msg.Piece[:dataLen] + _, err := io.ReadFull(r, msg.Piece) + if err != nil { + return errors.Wrap(err, "reading piece data") + } case Extended: b, err := readByte(r) if err != nil { diff --git a/peer_protocol/decoder_test.go b/peer_protocol/decoder_test.go index bbd8194d..33909cdc 100644 --- a/peer_protocol/decoder_test.go +++ b/peer_protocol/decoder_test.go @@ -89,5 +89,5 @@ func TestDecodeOverlongPiece(t *testing.T) { }}, } var m Message - require.EqualError(t, d.Decode(&m), "piece data longer than expected") + require.Error(t, d.Decode(&m)) }