diff --git a/bencode/api.go b/bencode/api.go index b5f026c5..1b77f588 100644 --- a/bencode/api.go +++ b/bencode/api.go @@ -129,7 +129,7 @@ func Unmarshal(data []byte, v interface{}) error { } func NewDecoder(r io.Reader) *Decoder { - return &Decoder{r: bufio.NewReader(r)} + return &Decoder{r: &scanner{r: r}} } func NewEncoder(w io.Writer) *Encoder { diff --git a/bencode/decode_test.go b/bencode/decode_test.go index 1db0c7e4..572740b8 100644 --- a/bencode/decode_test.go +++ b/bencode/decode_test.go @@ -72,16 +72,28 @@ func TestDecoderConsecutive(t *testing.T) { func TestDecoderConsecutiveDicts(t *testing.T) { bb := bytes.NewBufferString("d4:herp4:derped3:wat1:ke17:oh baby a triple!") + d := NewDecoder(bb) + assert.EqualValues(t, "d4:herp4:derped3:wat1:ke17:oh baby a triple!", bb.Bytes()) + assert.EqualValues(t, 0, d.offset) + var m map[string]interface{} + require.NoError(t, d.Decode(&m)) assert.Len(t, m, 1) assert.Equal(t, "derp", m["herp"]) + assert.Equal(t, "d3:wat1:ke17:oh baby a triple!", bb.String()) + assert.EqualValues(t, 14, d.offset) + require.NoError(t, d.Decode(&m)) assert.Equal(t, "k", m["wat"]) + assert.Equal(t, "17:oh baby a triple!", bb.String()) + assert.EqualValues(t, 24, d.offset) + var s string require.NoError(t, d.Decode(&s)) assert.Equal(t, "oh baby a triple!", s) + assert.EqualValues(t, 44, d.offset) } func check_error(t *testing.T, err error) { diff --git a/bencode/scanner.go b/bencode/scanner.go new file mode 100644 index 00000000..eaf22dd4 --- /dev/null +++ b/bencode/scanner.go @@ -0,0 +1,41 @@ +package bencode + +import ( + "errors" + "io" +) + +// Implements io.ByteScanner over io.Reader, for use in Decoder, to ensure +// that as little as the undecoded input Reader is consumed as possible. +type scanner struct { + r io.Reader + b [1]byte // Buffer for ReadByte + unread bool // True if b has been unread, and so should be returned next +} + +func (me *scanner) Read(b []byte) (int, error) { + return me.r.Read(b) +} + +func (me *scanner) ReadByte() (byte, error) { + if me.unread { + me.unread = false + return me.b[0], nil + } + n, err := me.r.Read(me.b[:]) + if err != nil { + return me.b[0], err + } + if n != 1 { + panic(n) + } + return me.b[0], err +} + +func (me *scanner) UnreadByte() error { + if me.unread { + return errors.New("byte already unread") + } + me.unread = true + return nil +}