FedP2P/bencode/encode.go

252 lines
5.0 KiB
Go
Raw Normal View History

2012-06-20 21:21:32 +08:00
package bencode
2015-10-18 22:25:56 +08:00
import (
"io"
2015-10-18 22:25:56 +08:00
"reflect"
"runtime"
"sort"
"strconv"
"sync"
"github.com/anacrolix/missinggo"
2015-10-18 22:25:56 +08:00
)
2012-06-20 21:21:32 +08:00
func isEmptyValue(v reflect.Value) bool {
return missinggo.IsEmptyValue(v)
2012-06-20 21:21:32 +08:00
}
// Encoder writes bencoded values to an output sink.
type Encoder struct {
	// w is the output sink. Besides plain byte writes it must support
	// WriteString, and Flush, which Encode calls once a value has been
	// written in full.
	w interface {
		Flush() error
		io.Writer
		WriteString(string) (int, error)
	}

	// scratch is reusable space for formatting integers (values and string
	// length prefixes) without allocating on every write.
	scratch [64]byte
}
func (e *Encoder) Encode(v interface{}) (err error) {
if v == nil {
return
}
2012-06-20 21:21:32 +08:00
defer func() {
if e := recover(); e != nil {
if _, ok := e.(runtime.Error); ok {
panic(e)
}
var ok bool
err, ok = e.(error)
if !ok {
panic(e)
}
2012-06-20 21:21:32 +08:00
}
}()
e.reflectValue(reflect.ValueOf(v))
return e.w.Flush()
2012-06-20 21:21:32 +08:00
}
// string_values is a sort.Interface over reflect.Values, ordered by the
// string each element yields from its String method.
type string_values []reflect.Value

// get returns the String form of the element at index i.
func (sv string_values) get(i int) string {
	return sv[i].String()
}

// Len returns the number of elements.
func (sv string_values) Len() int {
	return len(sv)
}

// Less reports whether element i sorts before element j.
func (sv string_values) Less(i, j int) bool {
	left, right := sv.get(i), sv.get(j)
	return left < right
}

// Swap exchanges the elements at indices i and j.
func (sv string_values) Swap(i, j int) {
	tmp := sv[i]
	sv[i] = sv[j]
	sv[j] = tmp
}
// write sends the bytes in s to the sink. A write failure is raised as a
// panic carrying the error.
func (e *Encoder) write(s []byte) {
	if _, err := e.w.Write(s); err != nil {
		panic(err)
	}
}
// writeString sends s to the sink through its WriteString method. A write
// failure is raised as a panic carrying the error.
func (e *Encoder) writeString(s string) {
	if _, err := e.w.WriteString(s); err != nil {
		panic(err)
	}
}
func (e *Encoder) reflectString(s string) {
2012-06-20 21:21:32 +08:00
b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
e.write(b)
e.writeString(":")
e.writeString(s)
2012-06-20 21:21:32 +08:00
}
func (e *Encoder) reflectByteSlice(s []byte) {
2012-06-20 21:21:32 +08:00
b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
e.write(b)
e.writeString(":")
e.write(s)
2012-06-20 21:21:32 +08:00
}
// reflectMarshaler tries to encode v through the Marshaler interface.
//
// It first checks whether v itself is a Marshaler; if not, and v is an
// addressable non-pointer, it checks whether the address of v is one. When a
// Marshaler is found and the value it is bound to is not a nil pointer, the
// bytes returned by MarshalBencode are written verbatim and true is returned.
// In every other case nothing is written and false is returned.
//
// An error from MarshalBencode is raised as a panic with a *MarshalerError.
func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
	m, found := v.Interface().(Marshaler)
	if !found && v.Kind() != reflect.Ptr && v.CanAddr() {
		// T is not a Marshaler; see whether *T is.
		if pm, isMarshaler := v.Addr().Interface().(Marshaler); isMarshaler {
			m, found = pm, true
			v = v.Addr()
		}
	}
	if !found {
		return false
	}
	// Never call MarshalBencode on a nil pointer receiver.
	if v.Kind() == reflect.Ptr && v.IsNil() {
		return false
	}

	data, err := m.MarshalBencode()
	if err != nil {
		panic(&MarshalerError{v.Type(), err})
	}
	e.write(data)
	return true
}
func (e *Encoder) reflectValue(v reflect.Value) {
if e.reflectMarshaler(v) {
2012-06-24 19:10:53 +08:00
return
}
2012-06-20 21:21:32 +08:00
switch v.Kind() {
case reflect.Bool:
if v.Bool() {
e.writeString("i1e")
2012-06-20 21:21:32 +08:00
} else {
e.writeString("i0e")
2012-06-20 21:21:32 +08:00
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
e.writeString("i")
e.write(b)
e.writeString("e")
2012-06-20 21:21:32 +08:00
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
e.writeString("i")
e.write(b)
e.writeString("e")
2012-06-20 21:21:32 +08:00
case reflect.String:
e.reflectString(v.String())
2012-06-20 21:21:32 +08:00
case reflect.Struct:
e.writeString("d")
for _, ef := range encodeFields(v.Type()) {
2012-06-20 21:21:32 +08:00
field_value := v.Field(ef.i)
if ef.omit_empty && isEmptyValue(field_value) {
2012-06-20 21:21:32 +08:00
continue
}
e.reflectString(ef.tag)
e.reflectValue(field_value)
2012-06-20 21:21:32 +08:00
}
e.writeString("e")
2012-06-20 21:21:32 +08:00
case reflect.Map:
if v.Type().Key().Kind() != reflect.String {
panic(&MarshalTypeError{v.Type()})
}
if v.IsNil() {
e.writeString("de")
2012-06-20 21:21:32 +08:00
break
}
e.writeString("d")
2012-06-20 21:21:32 +08:00
sv := string_values(v.MapKeys())
sort.Sort(sv)
for _, key := range sv {
e.reflectString(key.String())
e.reflectValue(v.MapIndex(key))
2012-06-20 21:21:32 +08:00
}
e.writeString("e")
2012-06-20 21:21:32 +08:00
case reflect.Slice:
if v.IsNil() {
e.writeString("le")
2012-06-20 21:21:32 +08:00
break
}
if v.Type().Elem().Kind() == reflect.Uint8 {
s := v.Bytes()
e.reflectByteSlice(s)
2012-06-20 21:21:32 +08:00
break
}
fallthrough
case reflect.Array:
e.writeString("l")
2012-06-20 21:21:32 +08:00
for i, n := 0, v.Len(); i < n; i++ {
e.reflectValue(v.Index(i))
2012-06-20 21:21:32 +08:00
}
e.writeString("e")
case reflect.Interface:
e.reflectValue(v.Elem())
case reflect.Ptr:
2012-06-20 21:21:32 +08:00
if v.IsNil() {
v = reflect.Zero(v.Type().Elem())
} else {
v = v.Elem()
2012-06-20 21:21:32 +08:00
}
e.reflectValue(v)
2012-06-20 21:21:32 +08:00
default:
panic(&MarshalTypeError{v.Type()})
}
}
// encodeField describes how one struct field is encoded.
type encodeField struct {
	// i is the index of the field within its struct type.
	i int
	// tag is the dictionary key written for the field.
	tag string
	// omit_empty marks a field that is skipped when its value is empty.
	omit_empty bool
}

// encodeFieldsSortType is a sort.Interface that orders encodeFields by
// ascending tag.
type encodeFieldsSortType []encodeField

// Len returns the number of fields.
func (ef encodeFieldsSortType) Len() int { return len(ef) }

// Swap exchanges the fields at indices i and j.
func (ef encodeFieldsSortType) Swap(i, j int) { ef[i], ef[j] = ef[j], ef[i] }

// Less reports whether the tag of field i sorts before the tag of field j.
func (ef encodeFieldsSortType) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
2012-06-20 21:21:32 +08:00
var (
typeCacheLock sync.RWMutex
encodeFieldsCache = make(map[reflect.Type][]encodeField)
2012-06-20 21:21:32 +08:00
)
func encodeFields(t reflect.Type) []encodeField {
typeCacheLock.RLock()
fs, ok := encodeFieldsCache[t]
typeCacheLock.RUnlock()
2012-06-20 21:21:32 +08:00
if ok {
return fs
}
typeCacheLock.Lock()
defer typeCacheLock.Unlock()
fs, ok = encodeFieldsCache[t]
2012-06-20 21:21:32 +08:00
if ok {
return fs
}
for i, n := 0, t.NumField(); i < n; i++ {
f := t.Field(i)
if f.PkgPath != "" {
continue
}
if f.Anonymous {
continue
}
var ef encodeField
2012-06-20 21:21:32 +08:00
ef.i = i
ef.tag = f.Name
tv := f.Tag.Get("bencode")
if tv != "" {
if tv == "-" {
continue
}
2016-08-26 12:51:38 +08:00
name, opts := parseTag(tv)
if name != "" {
ef.tag = name
}
2012-06-20 21:21:32 +08:00
ef.omit_empty = opts.contains("omitempty")
}
fs = append(fs, ef)
}
fss := encodeFieldsSortType(fs)
2012-06-20 21:21:32 +08:00
sort.Sort(fss)
encodeFieldsCache[t] = fs
2012-06-20 21:21:32 +08:00
return fs
}