diff --git a/bencode/both_test.go b/bencode/both_test.go index 837e6115..a000f35b 100644 --- a/bencode/both_test.go +++ b/bencode/both_test.go @@ -3,7 +3,6 @@ package bencode import "testing" import "bytes" import "io/ioutil" -import "time" func load_file(name string, t *testing.T) []byte { data, err := ioutil.ReadFile(name) @@ -13,8 +12,8 @@ func load_file(name string, t *testing.T) []byte { return data } -func TestBothInterface(t *testing.T) { - data1 := load_file("_testdata/archlinux-2011.08.19-netinstall-i686.iso.torrent", t) +func test_file_interface(t *testing.T, filename string) { + data1 := load_file(filename, t) var iface interface{} err := Unmarshal(data1, &iface) @@ -30,6 +29,12 @@ func TestBothInterface(t *testing.T) { if !bytes.Equal(data1, data2) { t.Fatalf("equality expected\n") } + +} + +func TestBothInterface(t *testing.T) { + test_file_interface(t, "_testdata/archlinux-2011.08.19-netinstall-i686.iso.torrent") + test_file_interface(t, "_testdata/continuum.torrent") } type torrent_file struct { @@ -50,8 +55,8 @@ type torrent_file struct { URLList interface{} `bencode:"url-list,omitempty"` } -func TestBoth(t *testing.T) { - data1 := load_file("_testdata/archlinux-2011.08.19-netinstall-i686.iso.torrent", t) +func test_file(t *testing.T, filename string) { + data1 := load_file(filename, t) var f torrent_file err := Unmarshal(data1, &f) @@ -59,19 +64,17 @@ func TestBoth(t *testing.T) { t.Fatal(err) } - t.Logf("Name: %s\n", f.Info.Name) - t.Logf("Length: %v bytes\n", f.Info.Length) - t.Logf("Announce: %s\n", f.Announce) - t.Logf("CreationDate: %s\n", time.Unix(f.CreationDate, 0).String()) - t.Logf("CreatedBy: %s\n", f.CreatedBy) - t.Logf("Comment: %s\n", f.Comment) - data2, err := Marshal(&f) if err != nil { t.Fatal(err) } if !bytes.Equal(data1, data2) { + println(string(data2)) t.Fatalf("equality expected") } } + +func TestBoth(t *testing.T) { + test_file(t, "_testdata/archlinux-2011.08.19-netinstall-i686.iso.torrent") +} diff --git a/bencode/decode.go b/bencode/decode.go index 903dfe4a..72251e50 100644 --- a/bencode/decode.go +++ b/bencode/decode.go @@ -319,20 +319,105 @@ func (d *decoder) parse_list(v reflect.Value) { } } +func (d *decoder) read_one_value() bool { + b, err := d.ReadByte() + if err != nil { + panic(err) + } + if b == 'e' { + d.UnreadByte() + return false + } else { + d.offset++ + d.buf.WriteByte(b) + } + + switch b { + case 'd', 'l': + // read until there is nothing to read + for d.read_one_value() {} + // consume 'e' as well + b = d.read_byte() + d.buf.WriteByte(b) + case 'i': + d.read_until('e') + d.buf.WriteString("e") + default: + if b >= '0' && b <= '9' { + start := d.buf.Len() - 1 + d.read_until(':') + length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64) + check_for_int_parse_error(err, d.offset - 1) + + d.buf.WriteString(":") + n, err := io.CopyN(&d.buf, d, length) + d.offset += n + if err != nil { + check_for_unexpected_eof(err, d.offset) + panic(&SyntaxError{ + Offset: d.offset, + what: "unexpected I/O error: " + err.Error(), + }) + } + break + } + + // unknown value + panic(&SyntaxError{ + Offset: d.offset - 1, + what: "unknown value type (invalid bencode?)", + }) + } + + return true + +} + +func (d *decoder) parse_unmarshaler(v reflect.Value) bool { + m, ok := v.Interface().(Unmarshaler) + if !ok { + // T doesn't work, try *T + if v.Kind() != reflect.Ptr && v.CanAddr() { + m, ok = v.Addr().Interface().(Unmarshaler) + if ok { + v = v.Addr() + } + } + } + if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { + if d.read_one_value() { + err := m.UnmarshalBencode(d.buf.Bytes()) + d.buf.Reset() + if err != nil { + panic(err) + } + return true + } + d.buf.Reset() + } + + return false +} + // returns true if there was a value and it's now stored in 'v', otherwise there // was an end symbol ("e") and no value was stored func (d *decoder) parse_value(v reflect.Value) bool { - if pv := v; pv.Kind() == reflect.Ptr { + // we support one level of indirection at the moment + if v.Kind() == reflect.Ptr { // if the pointer is nil, allocate a new element of the type it // points to - if pv.IsNil() { - pv.Set(reflect.New(pv.Type().Elem())) + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) } - v = pv.Elem() + v = v.Elem() } - // common case - if v.Kind() == reflect.Interface { + if d.parse_unmarshaler(v) { + return true + } + + // common case: interface{} + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { iface, _ := d.parse_value_interface() v.Set(reflect.ValueOf(iface)) return true diff --git a/bencode/decode_test.go b/bencode/decode_test.go index 4f06f391..b0714a74 100644 --- a/bencode/decode_test.go +++ b/bencode/decode_test.go @@ -34,3 +34,44 @@ func TestRandomDecode(t *testing.T) { } } } + +func check_error(t *testing.T, err error) { + if err != nil { + t.Error(err) + } +} + +func assert_equal(t *testing.T, x, y interface{}) { + if !reflect.DeepEqual(x, y) { + t.Errorf("got: %v (%T), expected: %v (%T)\n", x, x, y, y) + } +} + +type unmarshaler_int struct { + x int +} + +func (this *unmarshaler_int) UnmarshalBencode(data []byte) error { + return Unmarshal(data, &this.x) +} + +type unmarshaler_string struct { + x string +} + +func (this *unmarshaler_string) UnmarshalBencode(data []byte) error { + this.x = string(data) + return nil +} + +func TestUnmarshalerBencode(t *testing.T) { + var i unmarshaler_int + var ss []unmarshaler_string + check_error(t, Unmarshal([]byte("i71e"), &i)) + assert_equal(t, i.x, 71) + check_error(t, Unmarshal([]byte("l5:hello5:fruit3:waye"), &ss)) + assert_equal(t, ss[0].x, "5:hello") + assert_equal(t, ss[1].x, "5:fruit") + assert_equal(t, ss[2].x, "3:way") + +} diff --git a/bencode/encode.go b/bencode/encode.go index 4a6eea17..196aa095 100644 --- a/bencode/encode.go +++ b/bencode/encode.go @@ -78,13 +78,12 @@ func (e *encoder) reflect_byte_slice(s []byte) { e.write(s) } -func (e *encoder) reflect_value(v reflect.Value) { - if !v.IsValid() { - return - } - +// returns true if the value implements Marshaler interface and marshaling was +// done successfully +func (e *encoder) reflect_marshaler(v reflect.Value) bool { m, ok := v.Interface().(Marshaler) if !ok { + // T doesn't work, try *T if v.Kind() != reflect.Ptr && v.CanAddr() { m, ok = v.Addr().Interface().(Marshaler) if ok { @@ -98,6 +97,18 @@ func (e *encoder) reflect_value(v reflect.Value) { panic(&MarshalerError{v.Type(), err}) } e.write(data) + return true + } + + return false +} + +func (e *encoder) reflect_value(v reflect.Value) { + if !v.IsValid() { + return + } + + if e.reflect_marshaler(v) { return }