From 92b77a7cae0a1e2b5f895d17eabb47faae70495a Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sun, 25 May 2014 23:04:55 +1000 Subject: [PATCH] dht-server: Save and load node table between invocations --- cmd/dht-server/main.go | 25 ++++++++++++- dht/dht.go | 82 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 99 insertions(+), 8 deletions(-) diff --git a/cmd/dht-server/main.go b/cmd/dht-server/main.go index 1e6abe72..c0ba99b3 100644 --- a/cmd/dht-server/main.go +++ b/cmd/dht-server/main.go @@ -4,6 +4,7 @@ import ( "bitbucket.org/anacrolix/go.torrent/dht" "log" "net" + "os" ) type pingResponse struct { @@ -20,6 +21,20 @@ func main() { log.Fatal(err) } s.Init() + func() { + f, err := os.Open("nodes") + if os.IsNotExist(err) { + return + } + if err != nil { + log.Fatal(err) + } + defer f.Close() + err = s.ReadNodes(f) + if err != nil { + log.Fatal(err) + } + }() log.Printf("dht server on %s", s.Socket.LocalAddr()) go func() { err := s.Serve() @@ -28,8 +43,16 @@ func main() { } }() err = s.Bootstrap() + func() { + f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) + if err != nil { + log.Print(err) + return + } + defer f.Close() + s.WriteNodes(f) + }() if err != nil { log.Fatal(err) } - select {} } diff --git a/dht/dht.go b/dht/dht.go index 27c1a79e..cfc123ba 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -42,6 +42,26 @@ type transaction struct { response chan Msg } +func (s *Server) ReadNodes(r io.Reader) error { + for { + var b [compactNodeInfoLen]byte + _, err := io.ReadFull(r, b[:]) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + var cni compactNodeInfo + err = cni.UnmarshalBinary(b[:]) + if err != nil { + return err + } + n := s.getNode(cni.Addr) + n.id = string(cni.ID[:]) + } +} + func (s *Server) WriteNodes(w io.Writer) (n int, err error) { for _, node := range s.nodes { cni := compactNodeInfo{ @@ -91,7 +111,15 @@ func (s *Server) Serve() error { log.Printf("bad krpc message: %s", err) continue } + if d["y"] == "q" { + s.handleQuery(addr, d) + continue + } t := s.findResponseTransaction(d["t"].(string), addr) + if t == nil { + log.Printf("unexpected message: %#v", d) + continue + } t.response <- d s.removeTransaction(t) id := "" @@ -102,6 +130,32 @@ func (s *Server) Serve() error { } } +func (s *Server) handleQuery(source *net.UDPAddr, m Msg) { + if m["q"] != "ping" { + return + } + s.heardFromNode(source, m["a"].(map[string]string)["id"]) + s.reply(source, m["t"].(string)) +} + +func (s *Server) reply(addr *net.UDPAddr, t string) { + m := map[string]interface{}{ + "t": t, + "y": "r", + "r": map[string]string{ + "id": s.IDString(), + }, + } + b, err := bencode.Marshal(m) + if err != nil { + panic(err) + } + _, err = s.Socket.WriteTo(b, addr) + if err != nil { + panic(err) + } +} + func (s *Server) heardFromNode(addr *net.UDPAddr, id string) { n := s.getNode(addr) n.id = id @@ -200,7 +254,7 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra return } -const compactAddrInfoLen = 26 +const compactNodeInfoLen = 26 type compactAddrInfo *net.UDPAddr @@ -246,9 +300,23 @@ type findNodeResponse struct { Nodes []compactNodeInfo } +func getResponseNodes(m Msg) (s string, err error) { + defer func() { + r := recover() + if r == nil { + return + } + err = fmt.Errorf("couldn't get response nodes: %s: %#v", r, m) + }() + s = m["r"].(map[string]interface{})["nodes"].(string) + return +} + func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error { - b := m["r"].(map[string]interface{})["nodes"].(string) - log.Printf("%q", b) + b, err := getResponseNodes(m) + if err != nil { + return err + } for i := 0; i < len(b); i += 26 { var n compactNodeInfo err := n.UnmarshalBinary([]byte(b[i : i+26])) @@ -261,7 +329,7 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error { } func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) { - log.Print(addr) + // log.Print(addr) t, err = s.query(addr, "find_node", map[string]string{"target": targetID}) if err != nil { return @@ -302,10 +370,10 @@ func (s *Server) Bootstrap() error { } } queriedNodes := make(map[string]bool, 1000) - for { + for i := 0; i < 3; i++ { + log.Printf("node table length: %d", len(s.nodes)) for _, node := range s.nodes { if queriedNodes[node.addr.String()] { - log.Printf("skipping already queried: %s", node.addr) continue } t, err := s.FindNode(node.addr, s.ID) @@ -314,7 +382,7 @@ func (s *Server) Bootstrap() error { } queriedNodes[node.addr.String()] = true go func() { - log.Print(<-t.Response) + <-t.Response }() } time.Sleep(3 * time.Second)