Break up peer_protocol into several files

This commit is contained in:
Matt Joiner 2018-02-02 21:29:57 +11:00
parent 6441e98f62
commit b610107d8d
4 changed files with 249 additions and 241 deletions

124
peer_protocol/decoder.go Normal file
View File

@ -0,0 +1,124 @@
package peer_protocol
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"sync"
)
type Decoder struct {
R *bufio.Reader
Pool *sync.Pool
MaxLength Integer // TODO: Should this include the length header or not?
}
// io.EOF is returned if the source terminates cleanly on a message boundary.
func (d *Decoder) Decode(msg *Message) (err error) {
var length Integer
err = binary.Read(d.R, binary.BigEndian, &length)
if err != nil {
if err != io.EOF {
err = fmt.Errorf("error reading message length: %s", err)
}
return
}
if length > d.MaxLength {
return errors.New("message too long")
}
if length == 0 {
msg.Keepalive = true
return
}
msg.Keepalive = false
r := &io.LimitedReader{d.R, int64(length)}
// Check that all of r was utilized.
defer func() {
if err != nil {
return
}
if r.N != 0 {
err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
}
}()
msg.Keepalive = false
c, err := readByte(r)
if err != nil {
return
}
msg.Type = MessageType(c)
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
return
case Have:
err = msg.Index.Read(r)
case Request, Cancel, Reject:
for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
err = data.Read(r)
if err != nil {
break
}
}
case Bitfield:
b := make([]byte, length-1)
_, err = io.ReadFull(r, b)
msg.Bitfield = unmarshalBitfield(b)
case Piece:
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
err = pi.Read(r)
if err != nil {
break
}
}
if err != nil {
break
}
//msg.Piece, err = ioutil.ReadAll(r)
b := *d.Pool.Get().(*[]byte)
n, err := io.ReadFull(r, b)
if err != nil {
if err != io.ErrUnexpectedEOF || n != int(length-9) {
return err
}
b = b[0:n]
}
msg.Piece = b
case Extended:
msg.ExtendedID, err = readByte(r)
if err != nil {
break
}
msg.ExtendedPayload, err = ioutil.ReadAll(r)
case Port:
err = binary.Read(r, binary.BigEndian, &msg.Port)
default:
err = fmt.Errorf("unknown message type %#v", c)
}
return
}
func readByte(r io.Reader) (b byte, err error) {
var arr [1]byte
n, err := r.Read(arr[:])
b = arr[0]
if n == 1 {
err = nil
return
}
if err == nil {
panic(err)
}
return
}
func unmarshalBitfield(b []byte) (bf []bool) {
for _, c := range b {
for i := 7; i >= 0; i-- {
bf = append(bf, (c>>uint(i))&1 == 1)
}
}
return
}

21
peer_protocol/int.go Normal file
View File

@ -0,0 +1,21 @@
package peer_protocol
import (
"encoding/binary"
"io"
)
type Integer uint32
func (i *Integer) Read(r io.Reader) error {
return binary.Read(r, binary.BigEndian, i)
}
// It's perfectly fine to cast these to an int. TODO: Or is it?
func (i Integer) Int() int {
return int(i)
}
func (i Integer) Uint64() uint64 {
return uint64(i)
}

102
peer_protocol/msg.go Normal file
View File

@ -0,0 +1,102 @@
package peer_protocol
import (
"bytes"
"encoding/binary"
"fmt"
)
type Message struct {
Keepalive bool
Type MessageType
Index, Begin, Length Integer
Piece []byte
Bitfield []bool
ExtendedID byte
ExtendedPayload []byte
Port uint16
}
func MakeCancelMessage(piece, offset, length Integer) Message {
return Message{
Type: Cancel,
Index: piece,
Begin: offset,
Length: length,
}
}
func (msg Message) MustMarshalBinary() []byte {
b, err := msg.MarshalBinary()
if err != nil {
panic(err)
}
return b
}
func (msg Message) MarshalBinary() (data []byte, err error) {
buf := &bytes.Buffer{}
if !msg.Keepalive {
err = buf.WriteByte(byte(msg.Type))
if err != nil {
return
}
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
case Have:
err = binary.Write(buf, binary.BigEndian, msg.Index)
case Request, Cancel, Reject:
for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
err = binary.Write(buf, binary.BigEndian, i)
if err != nil {
break
}
}
case Bitfield:
_, err = buf.Write(marshalBitfield(msg.Bitfield))
case Piece:
for _, i := range []Integer{msg.Index, msg.Begin} {
err = binary.Write(buf, binary.BigEndian, i)
if err != nil {
return
}
}
n, err := buf.Write(msg.Piece)
if err != nil {
break
}
if n != len(msg.Piece) {
panic(n)
}
case Extended:
err = buf.WriteByte(msg.ExtendedID)
if err != nil {
return
}
_, err = buf.Write(msg.ExtendedPayload)
case Port:
err = binary.Write(buf, binary.BigEndian, msg.Port)
default:
err = fmt.Errorf("unknown message type: %v", msg.Type)
}
}
data = make([]byte, 4+buf.Len())
binary.BigEndian.PutUint32(data, uint32(buf.Len()))
if buf.Len() != copy(data[4:], buf.Bytes()) {
panic("bad copy")
}
return
}
func marshalBitfield(bf []bool) (b []byte) {
b = make([]byte, (len(bf)+7)/8)
for i, have := range bf {
if !have {
continue
}
c := b[i/8]
c |= 1 << uint(7-i%8)
b[i/8] = c
}
return
}

View File

@ -1,36 +1,11 @@
package peer_protocol
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"sync"
const (
Protocol = "\x13BitTorrent protocol"
)
type (
MessageType byte
Integer uint32
)
func (i *Integer) Read(r io.Reader) error {
return binary.Read(r, binary.BigEndian, i)
}
// It's perfectly fine to cast these to an int. TODO: Or is it?
func (i Integer) Int() int {
return int(i)
}
func (i Integer) Uint64() uint64 {
return uint64(i)
}
const (
Protocol = "\x13BitTorrent protocol"
)
const (
@ -60,217 +35,3 @@ const (
DataMetadataExtensionMsgType = 1
RejectMetadataExtensionMsgType = 2
)
type Message struct {
Keepalive bool
Type MessageType
Index, Begin, Length Integer
Piece []byte
Bitfield []bool
ExtendedID byte
ExtendedPayload []byte
Port uint16
}
func (msg Message) MustMarshalBinary() []byte {
b, err := msg.MarshalBinary()
if err != nil {
panic(err)
}
return b
}
func (msg Message) MarshalBinary() (data []byte, err error) {
buf := &bytes.Buffer{}
if !msg.Keepalive {
err = buf.WriteByte(byte(msg.Type))
if err != nil {
return
}
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
case Have:
err = binary.Write(buf, binary.BigEndian, msg.Index)
case Request, Cancel, Reject:
for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
err = binary.Write(buf, binary.BigEndian, i)
if err != nil {
break
}
}
case Bitfield:
_, err = buf.Write(marshalBitfield(msg.Bitfield))
case Piece:
for _, i := range []Integer{msg.Index, msg.Begin} {
err = binary.Write(buf, binary.BigEndian, i)
if err != nil {
return
}
}
n, err := buf.Write(msg.Piece)
if err != nil {
break
}
if n != len(msg.Piece) {
panic(n)
}
case Extended:
err = buf.WriteByte(msg.ExtendedID)
if err != nil {
return
}
_, err = buf.Write(msg.ExtendedPayload)
case Port:
err = binary.Write(buf, binary.BigEndian, msg.Port)
default:
err = fmt.Errorf("unknown message type: %v", msg.Type)
}
}
data = make([]byte, 4+buf.Len())
binary.BigEndian.PutUint32(data, uint32(buf.Len()))
if buf.Len() != copy(data[4:], buf.Bytes()) {
panic("bad copy")
}
return
}
type Decoder struct {
R *bufio.Reader
Pool *sync.Pool
MaxLength Integer // TODO: Should this include the length header or not?
}
func readByte(r io.Reader) (b byte, err error) {
var arr [1]byte
n, err := r.Read(arr[:])
b = arr[0]
if n == 1 {
err = nil
return
}
if err == nil {
panic(err)
}
return
}
// io.EOF is returned if the source terminates cleanly on a message boundary.
func (d *Decoder) Decode(msg *Message) (err error) {
var length Integer
err = binary.Read(d.R, binary.BigEndian, &length)
if err != nil {
if err != io.EOF {
err = fmt.Errorf("error reading message length: %s", err)
}
return
}
if length > d.MaxLength {
return errors.New("message too long")
}
if length == 0 {
msg.Keepalive = true
return
}
msg.Keepalive = false
r := &io.LimitedReader{d.R, int64(length)}
// Check that all of r was utilized.
defer func() {
if err != nil {
return
}
if r.N != 0 {
err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
}
}()
msg.Keepalive = false
c, err := readByte(r)
if err != nil {
return
}
msg.Type = MessageType(c)
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
return
case Have:
err = msg.Index.Read(r)
case Request, Cancel, Reject:
for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
err = data.Read(r)
if err != nil {
break
}
}
case Bitfield:
b := make([]byte, length-1)
_, err = io.ReadFull(r, b)
msg.Bitfield = unmarshalBitfield(b)
case Piece:
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
err = pi.Read(r)
if err != nil {
break
}
}
if err != nil {
break
}
//msg.Piece, err = ioutil.ReadAll(r)
b := *d.Pool.Get().(*[]byte)
n, err := io.ReadFull(r, b)
if err != nil {
if err != io.ErrUnexpectedEOF || n != int(length-9) {
return err
}
b = b[0:n]
}
msg.Piece = b
case Extended:
msg.ExtendedID, err = readByte(r)
if err != nil {
break
}
msg.ExtendedPayload, err = ioutil.ReadAll(r)
case Port:
err = binary.Read(r, binary.BigEndian, &msg.Port)
default:
err = fmt.Errorf("unknown message type %#v", c)
}
return
}
type Bytes []byte
func (b Bytes) MarshalBinary() ([]byte, error) {
return b, nil
}
func unmarshalBitfield(b []byte) (bf []bool) {
for _, c := range b {
for i := 7; i >= 0; i-- {
bf = append(bf, (c>>uint(i))&1 == 1)
}
}
return
}
func marshalBitfield(bf []bool) (b []byte) {
b = make([]byte, (len(bf)+7)/8)
for i, have := range bf {
if !have {
continue
}
c := b[i/8]
c |= 1 << uint(7-i%8)
b[i/8] = c
}
return
}
func MakeCancelMessage(piece, offset, length Integer) Message {
return Message{
Type: Cancel,
Index: piece,
Begin: offset,
Length: length,
}
}