dht: Concurrency improvements and fixes to bootstrapping and getting peers
This commit is contained in:
parent
ae45175015
commit
ba83f65ddf
143
dht/dht.go
143
dht/dht.go
|
@ -24,6 +24,11 @@ type Server struct {
|
||||||
transactionIDInt uint64
|
transactionIDInt uint64
|
||||||
nodes map[string]*Node
|
nodes map[string]*Node
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
closed chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) String() string {
|
||||||
|
return fmt.Sprintf("dht server on %s", s.Socket.LocalAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
type Node struct {
|
type Node struct {
|
||||||
|
@ -55,7 +60,15 @@ type transaction struct {
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
t string
|
t string
|
||||||
Response chan Msg
|
Response chan Msg
|
||||||
response chan Msg
|
onResponse func(Msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transaction) handleResponse(m Msg) {
|
||||||
|
if t.onResponse != nil {
|
||||||
|
t.onResponse(m)
|
||||||
|
}
|
||||||
|
t.Response <- m
|
||||||
|
close(t.Response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) setDefaults() (err error) {
|
func (s *Server) setDefaults() (err error) {
|
||||||
|
@ -91,9 +104,13 @@ func (s *Server) setDefaults() (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Init() error {
|
func (s *Server) Init() (err error) {
|
||||||
return s.setDefaults()
|
err = s.setDefaults()
|
||||||
//s.nodes = make(map[string]*Node)
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.closed = make(chan struct{})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Serve() error {
|
func (s *Server) Serve() error {
|
||||||
|
@ -106,7 +123,7 @@ func (s *Server) Serve() error {
|
||||||
var d map[string]interface{}
|
var d map[string]interface{}
|
||||||
err = bencode.Unmarshal(b[:n], &d)
|
err = bencode.Unmarshal(b[:n], &d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("bad krpc message: %s: %q", err, b[:n])
|
log.Printf("%s: received bad krpc message: %s: %q", s, err, b[:n])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
@ -121,7 +138,7 @@ func (s *Server) Serve() error {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t.response <- d
|
t.handleResponse(d)
|
||||||
s.removeTransaction(t)
|
s.removeTransaction(t)
|
||||||
id := ""
|
id := ""
|
||||||
if d["y"] == "r" {
|
if d["y"] == "r" {
|
||||||
|
@ -143,8 +160,8 @@ func (s *Server) AddNode(ni NodeInfo) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
|
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
|
||||||
log.Print(m["q"])
|
|
||||||
if m["q"] != "ping" {
|
if m["q"] != "ping" {
|
||||||
|
log.Printf("%s: not handling received query: q=%s", s, m["q"])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
|
s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
|
||||||
|
@ -264,7 +281,6 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
|
||||||
t: tid,
|
t: tid,
|
||||||
Response: make(chan Msg, 1),
|
Response: make(chan Msg, 1),
|
||||||
}
|
}
|
||||||
t.response = t.Response
|
|
||||||
s.addTransaction(t)
|
s.addTransaction(t)
|
||||||
err = s.writeToNode(b, node)
|
err = s.writeToNode(b, node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -346,18 +362,54 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *transaction) onResponse(f func(m Msg)) {
|
func (t *transaction) setOnResponse(f func(m Msg)) {
|
||||||
ch := make(chan Msg)
|
if t.onResponse != nil {
|
||||||
t.response = ch
|
panic(t.onResponse)
|
||||||
go func() {
|
}
|
||||||
d, ok := <-t.response
|
t.onResponse = f
|
||||||
if !ok {
|
}
|
||||||
close(t.Response)
|
|
||||||
|
func unmarshalNodeInfoBinary(b []byte) (ret []NodeInfo, err error) {
|
||||||
|
if len(b)%26 != 0 {
|
||||||
|
err = errors.New("bad buffer length")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ret = make([]NodeInfo, 0, len(b)/26)
|
||||||
|
for i := 0; i < len(b); i += 26 {
|
||||||
|
var ni NodeInfo
|
||||||
|
err = ni.UnmarshalCompact(b[i : i+26])
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f(d)
|
ret = append(ret, ni)
|
||||||
t.Response <- d
|
}
|
||||||
}()
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractNodes(d Msg) (nodes []NodeInfo, err error) {
|
||||||
|
if d["y"] != "r" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r, ok := d["r"]
|
||||||
|
if !ok {
|
||||||
|
err = errors.New("missing r dict")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rd, ok := r.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
err = errors.New("bad r value type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n, ok := rd["nodes"]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ns, ok := n.(string)
|
||||||
|
if !ok {
|
||||||
|
err = errors.New("bad nodes value type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return unmarshalNodeInfoBinary([]byte(ns))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) liftNodes(d Msg) {
|
func (s *Server) liftNodes(d Msg) {
|
||||||
|
@ -369,25 +421,23 @@ func (s *Server) liftNodes(d Msg) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// log.Print(err)
|
// log.Print(err)
|
||||||
} else {
|
} else {
|
||||||
s.mu.Lock()
|
|
||||||
for _, cni := range r.Nodes {
|
for _, cni := range r.Nodes {
|
||||||
n := s.getNode(cni.Addr)
|
n := s.getNode(cni.Addr)
|
||||||
n.id = string(cni.ID[:])
|
n.id = string(cni.ID[:])
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
|
||||||
// log.Printf("lifted %d nodes", len(r.Nodes))
|
// log.Printf("lifted %d nodes", len(r.Nodes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sends a find_node query to addr. targetID is the node we're looking for.
|
// Sends a find_node query to addr. targetID is the node we're looking for.
|
||||||
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
func (s *Server) findNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
||||||
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
|
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Scrape peers from the response to put in the server's table before
|
// Scrape peers from the response to put in the server's table before
|
||||||
// handing the response back to the caller.
|
// handing the response back to the caller.
|
||||||
t.onResponse(func(d Msg) {
|
t.setOnResponse(func(d Msg) {
|
||||||
s.liftNodes(d)
|
s.liftNodes(d)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -471,9 +521,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
|
||||||
case m := <-t.Response:
|
case m := <-t.Response:
|
||||||
vs := extractValues(m)
|
vs := extractValues(m)
|
||||||
if vs != nil {
|
if vs != nil {
|
||||||
ps.Values <- vs
|
select {
|
||||||
// } else {
|
case ps.Values <- vs:
|
||||||
// log.Print("get_peers response had no values")
|
case <-ps.stop:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case <-ps.stop:
|
case <-ps.stop:
|
||||||
}
|
}
|
||||||
|
@ -484,7 +535,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
go func() {
|
go func() {
|
||||||
for ; pending > 0; pending-- {
|
for ; pending > 0; pending-- {
|
||||||
<-done
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-s.closed:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ps.Close()
|
ps.Close()
|
||||||
}()
|
}()
|
||||||
|
@ -500,7 +554,7 @@ func (s *Server) getPeers(addr *net.UDPAddr, infoHash string) (t *transaction, e
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.onResponse(func(m Msg) {
|
t.setOnResponse(func(m Msg) {
|
||||||
s.liftNodes(m)
|
s.liftNodes(m)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -523,24 +577,38 @@ func (s *Server) Bootstrap() (err error) {
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if len(s.nodes) == 0 {
|
if len(s.nodes) == 0 {
|
||||||
err = s.addRootNode()
|
err = s.addRootNode()
|
||||||
if err != nil {
|
}
|
||||||
return
|
if err != nil {
|
||||||
}
|
return
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
|
var outstanding sync.WaitGroup
|
||||||
for _, node := range s.nodes {
|
for _, node := range s.nodes {
|
||||||
var t *transaction
|
var t *transaction
|
||||||
s.mu.Unlock()
|
t, err = s.findNode(node.addr, s.ID)
|
||||||
t, err = s.FindNode(node.addr, s.ID)
|
|
||||||
s.mu.Lock()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
outstanding.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
<-t.Response
|
<-t.Response
|
||||||
|
outstanding.Done()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
time.Sleep(5 * time.Second)
|
noOutstanding := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
outstanding.Wait()
|
||||||
|
close(noOutstanding)
|
||||||
|
}()
|
||||||
|
s.mu.Unlock()
|
||||||
|
select {
|
||||||
|
case <-s.closed:
|
||||||
|
s.mu.Lock()
|
||||||
|
return
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
case <-noOutstanding:
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
log.Printf("now have %d nodes", len(s.nodes))
|
log.Printf("now have %d nodes", len(s.nodes))
|
||||||
if len(s.nodes) >= 8*160 {
|
if len(s.nodes) >= 8*160 {
|
||||||
break
|
break
|
||||||
|
@ -569,6 +637,13 @@ func (s *Server) Nodes() (nis []NodeInfo) {
|
||||||
|
|
||||||
func (s *Server) StopServing() {
|
func (s *Server) StopServing() {
|
||||||
s.Socket.Close()
|
s.Socket.Close()
|
||||||
|
s.mu.Lock()
|
||||||
|
select {
|
||||||
|
case <-s.closed:
|
||||||
|
default:
|
||||||
|
close(s.closed)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func idDistance(a, b string) (ret int) {
|
func idDistance(a, b string) (ret int) {
|
||||||
|
|
Loading…
Reference in New Issue