diff --git a/dht/dht.go b/dht/dht.go index 14a0af0f..fc703085 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -148,6 +148,14 @@ func (m Msg) T() (t string) { return } +func (m Msg) Nodes() []NodeInfo { + var r findNodeResponse + if err := r.UnmarshalKRPCMsg(m); err != nil { + return nil + } + return r.Nodes +} + type KRPCError struct { Code int Msg string @@ -159,12 +167,21 @@ func (me KRPCError) Error() string { var _ error = KRPCError{} -func (m Msg) Error() *KRPCError { +func (m Msg) Error() (ret *KRPCError) { if m["y"] != "e" { - return nil + return } - l := m["e"].([]interface{}) - return &KRPCError{int(l[0].(int64)), l[1].(string)} + ret = &KRPCError{} + switch e := m["e"].(type) { + case []interface{}: + ret.Code = int(e[0].(int64)) + ret.Msg = e[1].(string) + case string: + ret.Msg = e + default: + logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e) + } + return } // Returns the token given in response to a get_peers request for future @@ -175,6 +192,7 @@ func (m Msg) AnnounceToken() string { } type transaction struct { + mu sync.Mutex remoteAddr dHTAddr t string Response chan Msg @@ -183,12 +201,36 @@ type transaction struct { } func (t *transaction) timeout() { + t.Close() +} + +func (t *transaction) closing() bool { + select { + case <-t.done: + return true + default: + return false + } +} + +func (t *transaction) Close() { + t.mu.Lock() + defer t.mu.Unlock() + if t.closing() { + return + } close(t.Response) close(t.done) } func (t *transaction) handleResponse(m Msg) { + t.mu.Lock() + if t.closing() { + t.mu.Unlock() + return + } close(t.done) + t.mu.Unlock() if t.onResponse != nil { t.onResponse(m) } @@ -272,6 +314,8 @@ func (s *Server) serve() error { } func (s *Server) AddNode(ni NodeInfo) { + s.mu.Lock() + defer s.mu.Unlock() if s.nodes == nil { s.nodes = make(map[string]*Node) } @@ -697,27 +741,6 @@ func (s *Server) findNode(addr dHTAddr, targetID string) (t *transaction, err er return } -type peerStreamValue struct { - Peers []util.CompactPeer // Peers given in get_peers response. - NodeInfo // The node that gave the response. -} - -type peerStream struct { - mu sync.Mutex - Values chan peerStreamValue - stop chan struct{} -} - -func (ps *peerStream) Close() { - ps.mu.Lock() - select { - case <-ps.stop: - default: - close(ps.stop) - } - ps.mu.Unlock() -} - func extractValues(m Msg) (vs []util.CompactPeer) { r, ok := m["r"] if !ok { @@ -752,63 +775,6 @@ func extractValues(m Msg) (vs []util.CompactPeer) { return } -func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) { - ps = &peerStream{ - Values: make(chan peerStreamValue), - stop: make(chan struct{}), - } - done := make(chan struct{}) - pending := 0 - s.mu.Lock() - for _, n := range s.closestGoodNodes(160, infoHash) { - var t *transaction - t, err = s.getPeers(n.addr, infoHash) - if err != nil { - ps.Close() - break - } - go func() { - select { - case m := <-t.Response: - vs := extractValues(m) - if vs != nil { - nodeInfo := NodeInfo{ - Addr: t.remoteAddr, - } - id := func() string { - defer func() { - recover() - }() - return m["r"].(map[string]interface{})["id"].(string) - }() - copy(nodeInfo.ID[:], id) - select { - case ps.Values <- peerStreamValue{ - Peers: vs, - NodeInfo: nodeInfo, - }: - case <-ps.stop: - } - } - case <-ps.stop: - } - done <- struct{}{} - }() - pending++ - } - s.mu.Unlock() - go func() { - for ; pending > 0; pending-- { - select { - case <-done: - case <-s.closed: - } - } - close(ps.Values) - }() - return -} - func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err error) { if len(infoHash) != 20 { err = fmt.Errorf("infohash has bad length") @@ -825,6 +791,10 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err er return } +func bootstrapAddr() (net.Addr, error) { + return net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881") +} + func (s *Server) addRootNode() error { addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881") if err != nil { diff --git a/dht/getpeers.go b/dht/getpeers.go new file mode 100644 index 00000000..1f136c9c --- /dev/null +++ b/dht/getpeers.go @@ -0,0 +1,170 @@ +package dht + +import ( + "log" + "net" + "sync" + + "bitbucket.org/anacrolix/go.torrent/util" +) + +type peerDiscovery struct { + *peerStream + triedAddrs map[string]struct{} + contactAddrs chan net.Addr + pending int + transactionClosed chan struct{} + server *Server + infoHash string +} + +func (me *peerDiscovery) Close() { + me.peerStream.Close() + close(me.contactAddrs) +} + +func (s *Server) GetPeers(infoHash string) (*peerStream, error) { + disc := &peerDiscovery{ + peerStream: &peerStream{ + Values: make(chan peerStreamValue), + stop: make(chan struct{}), + }, + triedAddrs: make(map[string]struct{}, 500), + contactAddrs: make(chan net.Addr), + transactionClosed: make(chan struct{}), + server: s, + infoHash: infoHash, + } + go disc.loop() + s.mu.Lock() + startAddrs := func() (ret []net.Addr) { + for _, n := range s.closestGoodNodes(160, infoHash) { + ret = append(ret, n.addr) + } + return + }() + s.mu.Unlock() + for _, addr := range startAddrs { + disc.contact(addr) + } + if len(startAddrs) == 0 { + addr, err := bootstrapAddr() + if err != nil { + disc.Close() + return nil, err + } + disc.contact(addr) + } + return disc.peerStream, nil +} + +func (me *peerDiscovery) contact(addr net.Addr) { + select { + case me.contactAddrs <- addr: + case <-me.closingCh(): + } +} + +func (me *peerDiscovery) responseNode(node NodeInfo) { + me.contact(node.Addr) +} + +func (me *peerDiscovery) loop() { + for { + select { + case addr := <-me.contactAddrs: + if me.pending >= 160 { + break + } + if _, ok := me.triedAddrs[addr.String()]; ok { + break + } + me.triedAddrs[addr.String()] = struct{}{} + if err := me.getPeers(addr); err != nil { + log.Printf("error sending get_peers request to %s: %s", addr, err) + break + } + // log.Printf("contacting %s", addr) + me.pending++ + case <-me.transactionClosed: + me.pending-- + // log.Printf("pending: %d", me.pending) + if me.pending == 0 { + me.Close() + return + } + } + } +} + +func (me *peerDiscovery) closingCh() chan struct{} { + return me.peerStream.stop +} + +func (me *peerDiscovery) getPeers(addr net.Addr) error { + me.server.mu.Lock() + defer me.server.mu.Unlock() + t, err := me.server.getPeers(addr, me.infoHash) + if err != nil { + return err + } + go func() { + select { + case m := <-t.Response: + if nodes := m.Nodes(); len(nodes) != 0 { + for _, n := range nodes { + me.responseNode(n) + } + } + if vs := extractValues(m); vs != nil { + nodeInfo := NodeInfo{ + Addr: t.remoteAddr, + } + id := func() string { + defer func() { + recover() + }() + return m["r"].(map[string]interface{})["id"].(string) + }() + copy(nodeInfo.ID[:], id) + select { + case me.peerStream.Values <- peerStreamValue{ + Peers: vs, + NodeInfo: nodeInfo, + }: + case <-me.peerStream.stop: + } + } + case <-me.closingCh(): + } + t.Close() + me.transactionClosed <- struct{}{} + }() + return nil +} + +func (me *peerDiscovery) streamValue(psv peerStreamValue) { + me.peerStream.Values <- psv +} + +type peerStreamValue struct { + Peers []util.CompactPeer // Peers given in get_peers response. + NodeInfo // The node that gave the response. +} + +type peerStream struct { + mu sync.Mutex + Values chan peerStreamValue + stop chan struct{} +} + +func (ps *peerStream) Close() { + ps.mu.Lock() + select { + case <-ps.stop: + default: + close(ps.stop) + close(ps.Values) + } + ps.mu.Unlock() +}