Fix request/chunk confusion, missing outgoing message prefix, protocol tests; improve request triggering

This commit is contained in:
Matt Joiner 2013-10-01 18:43:18 +10:00
parent 081a6805c5
commit 28531a4fcc
4 changed files with 116 additions and 63 deletions

116
client.go
View File

@ -49,17 +49,18 @@ const (
)
type piece struct {
State pieceState
Hash pieceSum
PendingChunks map[chunk]struct{}
State pieceState
Hash pieceSum
PendingChunkSpecs map[chunkSpec]struct{}
}
type chunk struct {
type chunkSpec struct {
Begin, Length peer_protocol.Integer
}
type request struct {
Index, Begin, Length peer_protocol.Integer
Index peer_protocol.Integer
chunkSpec
}
type connection struct {
@ -102,6 +103,9 @@ func (c *connection) Request(chunk request) bool {
Length: chunk.Length,
})
}
if c.Requests == nil {
c.Requests = make(map[request]struct{}, maxRequests)
}
c.Requests[chunk] = struct{}{}
return true
}
@ -114,8 +118,9 @@ func (c *connection) SetInterested(interested bool) {
Type: func() peer_protocol.MessageType {
if interested {
return peer_protocol.Interested
} else {
return peer_protocol.NotInterested
}
return peer_protocol.NotInterested
}(),
})
c.Interested = interested
@ -124,7 +129,6 @@ func (c *connection) SetInterested(interested bool) {
func (conn *connection) writer() {
for {
b := <-conn.write
log.Printf("writing %#v", string(b))
n, err := conn.Socket.Write(b)
if err != nil {
log.Print(err)
@ -134,6 +138,7 @@ func (conn *connection) writer() {
if n != len(b) {
panic("didn't write all bytes")
}
log.Printf("wrote %#v", string(b))
}
}
@ -144,7 +149,6 @@ func (conn *connection) writeOptimizer() {
write := conn.write
if pending.Len() == 0 {
write = nil
nextWrite = nil
} else {
var err error
nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary()
@ -177,9 +181,9 @@ func (t *torrent) bitfield() (bf []bool) {
return
}
func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) {
cs = make(map[chunk]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
c := chunk{
func (t *torrent) pieceChunkSpecs(index int) (cs map[chunkSpec]struct{}) {
cs = make(map[chunkSpec]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
c := chunkSpec{
Begin: 0,
}
for left := peer_protocol.Integer(t.PieceSize(index)); left > 0; left -= c.Length {
@ -193,7 +197,7 @@ func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) {
return
}
func (t *torrent) chunkHeat() (ret map[request]int) {
func (t *torrent) requestHeat() (ret map[request]int) {
ret = make(map[request]int)
for _, conn := range t.Conns {
for req, _ := range conn.Requests {
@ -352,6 +356,15 @@ func (me *client) initiateConn(peer Peer, torrent *torrent) {
}()
}
func (me *torrent) haveAnyPieces() bool {
for _, piece := range me.Pieces {
if piece.State == pieceStateComplete {
return true
}
}
return false
}
func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
conn := &connection{
Socket: sock,
@ -400,10 +413,12 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
}
me.withContext(func() {
me.addConnection(torrent, conn)
conn.Post(peer_protocol.Message{
Type: peer_protocol.Bitfield,
Bitfield: torrent.bitfield(),
})
if torrent.haveAnyPieces() {
conn.Post(peer_protocol.Message{
Type: peer_protocol.Bitfield,
Bitfield: torrent.bitfield(),
})
}
go func() {
defer me.withContext(func() {
me.dropConnection(torrent, conn)
@ -417,29 +432,22 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
}
func (me *client) peerGotPiece(torrent *torrent, conn *connection, piece int) {
if torrent.Pieces[piece].State != pieceStateIncomplete {
return
if conn.PeerPieces == nil {
conn.PeerPieces = make([]bool, len(torrent.Pieces))
}
conn.SetInterested(true)
conn.PeerPieces[piece] = true
if torrent.wantPiece(piece) {
conn.SetInterested(true)
me.replenishConnRequests(torrent, conn)
}
}
func (t *torrent) wantPiece(index int) bool {
return t.Pieces[index].State == pieceStateIncomplete
}
func (me *client) peerUnchoked(torrent *torrent, conn *connection) {
chunkHeatMap := torrent.chunkHeat()
for index, has := range conn.PeerPieces {
if !has {
continue
}
for chunk, _ := range torrent.Pieces[index].PendingChunks {
if _, ok := chunkHeatMap[chunk]; ok {
continue
}
conn.SetInterested(true)
if !conn.Request(chunk) {
return
}
}
}
conn.SetInterested(false)
me.replenishConnRequests(torrent, conn)
}
func (me *client) runConnection(torrent *torrent, conn *connection) error {
@ -457,6 +465,7 @@ func (me *client) runConnection(torrent *torrent, conn *connection) error {
continue
}
go me.withContext(func() {
log.Print(msg)
var err error
switch msg.Type {
case peer_protocol.Choke:
@ -470,12 +479,10 @@ func (me *client) runConnection(torrent *torrent, conn *connection) error {
conn.PeerInterested = false
case peer_protocol.Have:
me.peerGotPiece(torrent, conn, int(msg.Index))
conn.PeerPieces[msg.Index] = true
case peer_protocol.Request:
conn.PeerRequests[request{
Index: msg.Index,
Begin: msg.Begin,
Length: msg.Length,
Index: msg.Index,
chunkSpec: chunkSpec{msg.Begin, msg.Length},
}] = struct{}{}
case peer_protocol.Bitfield:
if len(msg.Bitfield) < len(torrent.Pieces) {
@ -589,6 +596,33 @@ func (me *client) withContext(f func()) {
me.actorTask <- f
}
func (me *client) replenishConnRequests(torrent *torrent, conn *connection) {
if len(conn.Requests) >= maxRequests {
return
}
if conn.PeerChoked {
return
}
requestHeatMap := torrent.requestHeat()
for index, has := range conn.PeerPieces {
if !has {
continue
}
for chunkSpec, _ := range torrent.Pieces[index].PendingChunkSpecs {
request := request{peer_protocol.Integer(index), chunkSpec}
if heat := requestHeatMap[request]; heat > 0 {
continue
}
conn.SetInterested(true)
if !conn.Request(request) {
return
}
}
}
//conn.SetInterested(false)
}
func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
torrent := me.torrents[ih]
newState := func() pieceState {
@ -604,7 +638,7 @@ func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
}
torrent.Pieces[piece].State = newState
if newState == pieceStateIncomplete {
torrent.Pieces[piece].PendingChunks = torrent.pieceChunks(piece)
torrent.Pieces[piece].PendingChunkSpecs = torrent.pieceChunkSpecs(piece)
}
for _, conn := range torrent.Conns {
if correct {
@ -614,7 +648,7 @@ func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
})
} else {
if conn.PeerHasPiece(piece) {
conn.SetInterested(true)
me.replenishConnRequests(torrent, conn)
}
}
}

View File

@ -30,7 +30,7 @@ func main() {
}
err = client.AddPeers(torrent.BytesInfoHash(metaInfo.InfoHash), []torrent.Peer{{
IP: net.IPv4(127, 0, 0, 1),
Port: 53219,
Port: 50933,
}})
if err != nil {
log.Fatal(err)

View File

@ -68,7 +68,11 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
default:
err = errors.New("unknown message type")
}
data = buf.Bytes()
data = make([]byte, 4+buf.Len())
binary.BigEndian.PutUint32(data, uint32(buf.Len()))
if buf.Len() != copy(data[4:], buf.Bytes()) {
panic("bad copy")
}
return
}
@ -113,16 +117,6 @@ func (d *Decoder) Decode(msg *Message) (err error) {
return
}
func encodeMessage(type_ MessageType, data interface{}) []byte {
w := &bytes.Buffer{}
w.WriteByte(byte(type_))
err := binary.Write(w, binary.BigEndian, data)
if err != nil {
panic(err)
}
return w.Bytes()
}
type Bytes []byte
func (b Bytes) MarshalBinary() ([]byte, error) {

View File

@ -10,22 +10,47 @@ func TestConstants(t *testing.T) {
t.FailNow()
}
}
func TestBitfieldEncode(t *testing.T) {
bm := make(Bitfield, 37)
bm[2] = true
bm[7] = true
bm[32] = true
s := string(bm.Encode())
bf := make([]bool, 37)
bf[2] = true
bf[7] = true
bf[32] = true
s := string(marshalBitfield(bf))
const expected = "\x21\x00\x00\x00\x80"
if s != expected {
t.Fatalf("got %#v, expected %#v", s, expected)
}
}
func TestHaveEncode(t *testing.T) {
actual := string(Have(42).Encode())
expected := "\x04\x00\x00\x00\x2a"
if actual != expected {
t.Fatalf("expected %#v, got %#v", expected, actual)
func TestBitfieldUnmarshal(t *testing.T) {
bf := unmarshalBitfield([]byte("\x81\x06"))
expected := make([]bool, 16)
expected[0] = true
expected[7] = true
expected[13] = true
expected[14] = true
if len(bf) != len(expected) {
t.FailNow()
}
for i := range expected {
if bf[i] != expected[i] {
t.FailNow()
}
}
}
func TestHaveEncode(t *testing.T) {
actualBytes, err := Message{
Type: Have,
Index: 42,
}.MarshalBinary()
if err != nil {
t.Fatal(err)
}
actualString := string(actualBytes)
expected := "\x04\x00\x00\x00\x2a"
if actualString != expected {
t.Fatalf("expected %#v, got %#v", expected, actualString)
}
}