dht-server: Save and load node table between invocations

This commit is contained in:
Matt Joiner 2014-05-25 23:04:55 +10:00
parent 1b69e69461
commit 92b77a7cae
2 changed files with 99 additions and 8 deletions

View File

@ -4,6 +4,7 @@ import (
"bitbucket.org/anacrolix/go.torrent/dht" "bitbucket.org/anacrolix/go.torrent/dht"
"log" "log"
"net" "net"
"os"
) )
type pingResponse struct { type pingResponse struct {
@ -20,6 +21,20 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
s.Init() s.Init()
func() {
f, err := os.Open("nodes")
if os.IsNotExist(err) {
return
}
if err != nil {
log.Fatal(err)
}
defer f.Close()
err = s.ReadNodes(f)
if err != nil {
log.Fatal(err)
}
}()
log.Printf("dht server on %s", s.Socket.LocalAddr()) log.Printf("dht server on %s", s.Socket.LocalAddr())
go func() { go func() {
err := s.Serve() err := s.Serve()
@ -28,8 +43,16 @@ func main() {
} }
}() }()
err = s.Bootstrap() err = s.Bootstrap()
func() {
f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
if err != nil {
log.Print(err)
return
}
defer f.Close()
s.WriteNodes(f)
}()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
select {}
} }

View File

@ -42,6 +42,26 @@ type transaction struct {
response chan Msg response chan Msg
} }
func (s *Server) ReadNodes(r io.Reader) error {
for {
var b [compactNodeInfoLen]byte
_, err := io.ReadFull(r, b[:])
if err == io.EOF {
return nil
}
if err != nil {
return err
}
var cni compactNodeInfo
err = cni.UnmarshalBinary(b[:])
if err != nil {
return err
}
n := s.getNode(cni.Addr)
n.id = string(cni.ID[:])
}
}
func (s *Server) WriteNodes(w io.Writer) (n int, err error) { func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
for _, node := range s.nodes { for _, node := range s.nodes {
cni := compactNodeInfo{ cni := compactNodeInfo{
@ -91,7 +111,15 @@ func (s *Server) Serve() error {
log.Printf("bad krpc message: %s", err) log.Printf("bad krpc message: %s", err)
continue continue
} }
if d["y"] == "q" {
s.handleQuery(addr, d)
continue
}
t := s.findResponseTransaction(d["t"].(string), addr) t := s.findResponseTransaction(d["t"].(string), addr)
if t == nil {
log.Printf("unexpected message: %#v", d)
continue
}
t.response <- d t.response <- d
s.removeTransaction(t) s.removeTransaction(t)
id := "" id := ""
@ -102,6 +130,32 @@ func (s *Server) Serve() error {
} }
} }
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
if m["q"] != "ping" {
return
}
s.heardFromNode(source, m["a"].(map[string]string)["id"])
s.reply(source, m["t"].(string))
}
func (s *Server) reply(addr *net.UDPAddr, t string) {
m := map[string]interface{}{
"t": t,
"y": "r",
"r": map[string]string{
"id": s.IDString(),
},
}
b, err := bencode.Marshal(m)
if err != nil {
panic(err)
}
_, err = s.Socket.WriteTo(b, addr)
if err != nil {
panic(err)
}
}
func (s *Server) heardFromNode(addr *net.UDPAddr, id string) { func (s *Server) heardFromNode(addr *net.UDPAddr, id string) {
n := s.getNode(addr) n := s.getNode(addr)
n.id = id n.id = id
@ -200,7 +254,7 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
return return
} }
const compactAddrInfoLen = 26 const compactNodeInfoLen = 26
type compactAddrInfo *net.UDPAddr type compactAddrInfo *net.UDPAddr
@ -246,9 +300,23 @@ type findNodeResponse struct {
Nodes []compactNodeInfo Nodes []compactNodeInfo
} }
func getResponseNodes(m Msg) (s string, err error) {
defer func() {
r := recover()
if r == nil {
return
}
err = fmt.Errorf("couldn't get response nodes: %s: %#v", r, m)
}()
s = m["r"].(map[string]interface{})["nodes"].(string)
return
}
func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error { func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
b := m["r"].(map[string]interface{})["nodes"].(string) b, err := getResponseNodes(m)
log.Printf("%q", b) if err != nil {
return err
}
for i := 0; i < len(b); i += 26 { for i := 0; i < len(b); i += 26 {
var n compactNodeInfo var n compactNodeInfo
err := n.UnmarshalBinary([]byte(b[i : i+26])) err := n.UnmarshalBinary([]byte(b[i : i+26]))
@ -261,7 +329,7 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
} }
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) { func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
log.Print(addr) // log.Print(addr)
t, err = s.query(addr, "find_node", map[string]string{"target": targetID}) t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
if err != nil { if err != nil {
return return
@ -302,10 +370,10 @@ func (s *Server) Bootstrap() error {
} }
} }
queriedNodes := make(map[string]bool, 1000) queriedNodes := make(map[string]bool, 1000)
for { for i := 0; i < 3; i++ {
log.Printf("node table length: %d", len(s.nodes))
for _, node := range s.nodes { for _, node := range s.nodes {
if queriedNodes[node.addr.String()] { if queriedNodes[node.addr.String()] {
log.Printf("skipping already queried: %s", node.addr)
continue continue
} }
t, err := s.FindNode(node.addr, s.ID) t, err := s.FindNode(node.addr, s.ID)
@ -314,7 +382,7 @@ func (s *Server) Bootstrap() error {
} }
queriedNodes[node.addr.String()] = true queriedNodes[node.addr.String()] = true
go func() { go func() {
log.Print(<-t.Response) <-t.Response
}() }()
} }
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)