Break up peer_protocol into several files
This commit is contained in:
parent
6441e98f62
commit
b610107d8d
|
@ -0,0 +1,124 @@
|
|||
package peer_protocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Decoder struct {
|
||||
R *bufio.Reader
|
||||
Pool *sync.Pool
|
||||
MaxLength Integer // TODO: Should this include the length header or not?
|
||||
}
|
||||
|
||||
// io.EOF is returned if the source terminates cleanly on a message boundary.
|
||||
func (d *Decoder) Decode(msg *Message) (err error) {
|
||||
var length Integer
|
||||
err = binary.Read(d.R, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
err = fmt.Errorf("error reading message length: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if length > d.MaxLength {
|
||||
return errors.New("message too long")
|
||||
}
|
||||
if length == 0 {
|
||||
msg.Keepalive = true
|
||||
return
|
||||
}
|
||||
msg.Keepalive = false
|
||||
r := &io.LimitedReader{d.R, int64(length)}
|
||||
// Check that all of r was utilized.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if r.N != 0 {
|
||||
err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
|
||||
}
|
||||
}()
|
||||
msg.Keepalive = false
|
||||
c, err := readByte(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg.Type = MessageType(c)
|
||||
switch msg.Type {
|
||||
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
|
||||
return
|
||||
case Have:
|
||||
err = msg.Index.Read(r)
|
||||
case Request, Cancel, Reject:
|
||||
for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
|
||||
err = data.Read(r)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
case Bitfield:
|
||||
b := make([]byte, length-1)
|
||||
_, err = io.ReadFull(r, b)
|
||||
msg.Bitfield = unmarshalBitfield(b)
|
||||
case Piece:
|
||||
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
|
||||
err = pi.Read(r)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
//msg.Piece, err = ioutil.ReadAll(r)
|
||||
b := *d.Pool.Get().(*[]byte)
|
||||
n, err := io.ReadFull(r, b)
|
||||
if err != nil {
|
||||
if err != io.ErrUnexpectedEOF || n != int(length-9) {
|
||||
return err
|
||||
}
|
||||
b = b[0:n]
|
||||
}
|
||||
msg.Piece = b
|
||||
case Extended:
|
||||
msg.ExtendedID, err = readByte(r)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
msg.ExtendedPayload, err = ioutil.ReadAll(r)
|
||||
case Port:
|
||||
err = binary.Read(r, binary.BigEndian, &msg.Port)
|
||||
default:
|
||||
err = fmt.Errorf("unknown message type %#v", c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func readByte(r io.Reader) (b byte, err error) {
|
||||
var arr [1]byte
|
||||
n, err := r.Read(arr[:])
|
||||
b = arr[0]
|
||||
if n == 1 {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func unmarshalBitfield(b []byte) (bf []bool) {
|
||||
for _, c := range b {
|
||||
for i := 7; i >= 0; i-- {
|
||||
bf = append(bf, (c>>uint(i))&1 == 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package peer_protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Integer uint32
|
||||
|
||||
func (i *Integer) Read(r io.Reader) error {
|
||||
return binary.Read(r, binary.BigEndian, i)
|
||||
}
|
||||
|
||||
// It's perfectly fine to cast these to an int. TODO: Or is it?
|
||||
func (i Integer) Int() int {
|
||||
return int(i)
|
||||
}
|
||||
|
||||
func (i Integer) Uint64() uint64 {
|
||||
return uint64(i)
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
package peer_protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Keepalive bool
|
||||
Type MessageType
|
||||
Index, Begin, Length Integer
|
||||
Piece []byte
|
||||
Bitfield []bool
|
||||
ExtendedID byte
|
||||
ExtendedPayload []byte
|
||||
Port uint16
|
||||
}
|
||||
|
||||
func MakeCancelMessage(piece, offset, length Integer) Message {
|
||||
return Message{
|
||||
Type: Cancel,
|
||||
Index: piece,
|
||||
Begin: offset,
|
||||
Length: length,
|
||||
}
|
||||
}
|
||||
|
||||
func (msg Message) MustMarshalBinary() []byte {
|
||||
b, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (msg Message) MarshalBinary() (data []byte, err error) {
|
||||
buf := &bytes.Buffer{}
|
||||
if !msg.Keepalive {
|
||||
err = buf.WriteByte(byte(msg.Type))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch msg.Type {
|
||||
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
|
||||
case Have:
|
||||
err = binary.Write(buf, binary.BigEndian, msg.Index)
|
||||
case Request, Cancel, Reject:
|
||||
for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
|
||||
err = binary.Write(buf, binary.BigEndian, i)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
case Bitfield:
|
||||
_, err = buf.Write(marshalBitfield(msg.Bitfield))
|
||||
case Piece:
|
||||
for _, i := range []Integer{msg.Index, msg.Begin} {
|
||||
err = binary.Write(buf, binary.BigEndian, i)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
n, err := buf.Write(msg.Piece)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if n != len(msg.Piece) {
|
||||
panic(n)
|
||||
}
|
||||
case Extended:
|
||||
err = buf.WriteByte(msg.ExtendedID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buf.Write(msg.ExtendedPayload)
|
||||
case Port:
|
||||
err = binary.Write(buf, binary.BigEndian, msg.Port)
|
||||
default:
|
||||
err = fmt.Errorf("unknown message type: %v", msg.Type)
|
||||
}
|
||||
}
|
||||
data = make([]byte, 4+buf.Len())
|
||||
binary.BigEndian.PutUint32(data, uint32(buf.Len()))
|
||||
if buf.Len() != copy(data[4:], buf.Bytes()) {
|
||||
panic("bad copy")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func marshalBitfield(bf []bool) (b []byte) {
|
||||
b = make([]byte, (len(bf)+7)/8)
|
||||
for i, have := range bf {
|
||||
if !have {
|
||||
continue
|
||||
}
|
||||
c := b[i/8]
|
||||
c |= 1 << uint(7-i%8)
|
||||
b[i/8] = c
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,36 +1,11 @@
|
|||
package peer_protocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
const (
|
||||
Protocol = "\x13BitTorrent protocol"
|
||||
)
|
||||
|
||||
type (
|
||||
MessageType byte
|
||||
Integer uint32
|
||||
)
|
||||
|
||||
func (i *Integer) Read(r io.Reader) error {
|
||||
return binary.Read(r, binary.BigEndian, i)
|
||||
}
|
||||
|
||||
// It's perfectly fine to cast these to an int. TODO: Or is it?
|
||||
func (i Integer) Int() int {
|
||||
return int(i)
|
||||
}
|
||||
|
||||
func (i Integer) Uint64() uint64 {
|
||||
return uint64(i)
|
||||
}
|
||||
|
||||
const (
|
||||
Protocol = "\x13BitTorrent protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -60,217 +35,3 @@ const (
|
|||
DataMetadataExtensionMsgType = 1
|
||||
RejectMetadataExtensionMsgType = 2
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Keepalive bool
|
||||
Type MessageType
|
||||
Index, Begin, Length Integer
|
||||
Piece []byte
|
||||
Bitfield []bool
|
||||
ExtendedID byte
|
||||
ExtendedPayload []byte
|
||||
Port uint16
|
||||
}
|
||||
|
||||
func (msg Message) MustMarshalBinary() []byte {
|
||||
b, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (msg Message) MarshalBinary() (data []byte, err error) {
|
||||
buf := &bytes.Buffer{}
|
||||
if !msg.Keepalive {
|
||||
err = buf.WriteByte(byte(msg.Type))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch msg.Type {
|
||||
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
|
||||
case Have:
|
||||
err = binary.Write(buf, binary.BigEndian, msg.Index)
|
||||
case Request, Cancel, Reject:
|
||||
for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
|
||||
err = binary.Write(buf, binary.BigEndian, i)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
case Bitfield:
|
||||
_, err = buf.Write(marshalBitfield(msg.Bitfield))
|
||||
case Piece:
|
||||
for _, i := range []Integer{msg.Index, msg.Begin} {
|
||||
err = binary.Write(buf, binary.BigEndian, i)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
n, err := buf.Write(msg.Piece)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if n != len(msg.Piece) {
|
||||
panic(n)
|
||||
}
|
||||
case Extended:
|
||||
err = buf.WriteByte(msg.ExtendedID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buf.Write(msg.ExtendedPayload)
|
||||
case Port:
|
||||
err = binary.Write(buf, binary.BigEndian, msg.Port)
|
||||
default:
|
||||
err = fmt.Errorf("unknown message type: %v", msg.Type)
|
||||
}
|
||||
}
|
||||
data = make([]byte, 4+buf.Len())
|
||||
binary.BigEndian.PutUint32(data, uint32(buf.Len()))
|
||||
if buf.Len() != copy(data[4:], buf.Bytes()) {
|
||||
panic("bad copy")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Decoder struct {
|
||||
R *bufio.Reader
|
||||
Pool *sync.Pool
|
||||
MaxLength Integer // TODO: Should this include the length header or not?
|
||||
}
|
||||
|
||||
func readByte(r io.Reader) (b byte, err error) {
|
||||
var arr [1]byte
|
||||
n, err := r.Read(arr[:])
|
||||
b = arr[0]
|
||||
if n == 1 {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// io.EOF is returned if the source terminates cleanly on a message boundary.
|
||||
func (d *Decoder) Decode(msg *Message) (err error) {
|
||||
var length Integer
|
||||
err = binary.Read(d.R, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
err = fmt.Errorf("error reading message length: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if length > d.MaxLength {
|
||||
return errors.New("message too long")
|
||||
}
|
||||
if length == 0 {
|
||||
msg.Keepalive = true
|
||||
return
|
||||
}
|
||||
msg.Keepalive = false
|
||||
r := &io.LimitedReader{d.R, int64(length)}
|
||||
// Check that all of r was utilized.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if r.N != 0 {
|
||||
err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
|
||||
}
|
||||
}()
|
||||
msg.Keepalive = false
|
||||
c, err := readByte(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg.Type = MessageType(c)
|
||||
switch msg.Type {
|
||||
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
|
||||
return
|
||||
case Have:
|
||||
err = msg.Index.Read(r)
|
||||
case Request, Cancel, Reject:
|
||||
for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
|
||||
err = data.Read(r)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
case Bitfield:
|
||||
b := make([]byte, length-1)
|
||||
_, err = io.ReadFull(r, b)
|
||||
msg.Bitfield = unmarshalBitfield(b)
|
||||
case Piece:
|
||||
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
|
||||
err = pi.Read(r)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
//msg.Piece, err = ioutil.ReadAll(r)
|
||||
b := *d.Pool.Get().(*[]byte)
|
||||
n, err := io.ReadFull(r, b)
|
||||
if err != nil {
|
||||
if err != io.ErrUnexpectedEOF || n != int(length-9) {
|
||||
return err
|
||||
}
|
||||
b = b[0:n]
|
||||
}
|
||||
msg.Piece = b
|
||||
case Extended:
|
||||
msg.ExtendedID, err = readByte(r)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
msg.ExtendedPayload, err = ioutil.ReadAll(r)
|
||||
case Port:
|
||||
err = binary.Read(r, binary.BigEndian, &msg.Port)
|
||||
default:
|
||||
err = fmt.Errorf("unknown message type %#v", c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Bytes []byte
|
||||
|
||||
func (b Bytes) MarshalBinary() ([]byte, error) {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func unmarshalBitfield(b []byte) (bf []bool) {
|
||||
for _, c := range b {
|
||||
for i := 7; i >= 0; i-- {
|
||||
bf = append(bf, (c>>uint(i))&1 == 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func marshalBitfield(bf []bool) (b []byte) {
|
||||
b = make([]byte, (len(bf)+7)/8)
|
||||
for i, have := range bf {
|
||||
if !have {
|
||||
continue
|
||||
}
|
||||
c := b[i/8]
|
||||
c |= 1 << uint(7-i%8)
|
||||
b[i/8] = c
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func MakeCancelMessage(piece, offset, length Integer) Message {
|
||||
return Message{
|
||||
Type: Cancel,
|
||||
Index: piece,
|
||||
Begin: offset,
|
||||
Length: length,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue