Avoid rebuffering in peer_protocol.Decode

This commit is contained in:
Matt Joiner 2014-05-23 00:36:47 +10:00
parent 65fa317244
commit dd30d144ae
2 changed files with 16 additions and 6 deletions

View File

@ -105,7 +105,21 @@ func (d *Decoder) Decode(msg *Message) (err error) {
if length > d.MaxLength {
return errors.New("message too long")
}
r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
if length == 0 {
msg.Keepalive = true
return
}
msg.Keepalive = false
b := make([]byte, length)
_, err = io.ReadFull(d.R, b)
if err == io.EOF {
err = io.ErrUnexpectedEOF
return
}
if err != nil {
return
}
r := bytes.NewReader(b)
defer func() {
written, _ := io.Copy(ioutil.Discard, r)
if written != 0 && err == nil {
@ -114,10 +128,6 @@ func (d *Decoder) Decode(msg *Message) (err error) {
err = io.ErrUnexpectedEOF
}
}()
if length == 0 {
msg.Keepalive = true
return
}
msg.Keepalive = false
c, err := r.ReadByte()
if err != nil {

View File

@ -104,7 +104,7 @@ func TestUnexpectedEOF(t *testing.T) {
}
err := dec.Decode(msg)
if err != io.ErrUnexpectedEOF {
t.Fatal(err)
t.Fatalf("expected ErrUnexpectedEOF decoding %q, got %s", stream, err)
}
}
}