dht-server: Save and load node table between invocations
This commit is contained in:
parent
1b69e69461
commit
92b77a7cae
|
@ -4,6 +4,7 @@ import (
|
|||
"bitbucket.org/anacrolix/go.torrent/dht"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
type pingResponse struct {
|
||||
|
@ -20,6 +21,20 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
s.Init()
|
||||
func() {
|
||||
f, err := os.Open("nodes")
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
err = s.ReadNodes(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
log.Printf("dht server on %s", s.Socket.LocalAddr())
|
||||
go func() {
|
||||
err := s.Serve()
|
||||
|
@ -28,8 +43,16 @@ func main() {
|
|||
}
|
||||
}()
|
||||
err = s.Bootstrap()
|
||||
func() {
|
||||
f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
s.WriteNodes(f)
|
||||
}()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
select {}
|
||||
}
|
||||
|
|
82
dht/dht.go
82
dht/dht.go
|
@ -42,6 +42,26 @@ type transaction struct {
|
|||
response chan Msg
|
||||
}
|
||||
|
||||
func (s *Server) ReadNodes(r io.Reader) error {
|
||||
for {
|
||||
var b [compactNodeInfoLen]byte
|
||||
_, err := io.ReadFull(r, b[:])
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var cni compactNodeInfo
|
||||
err = cni.UnmarshalBinary(b[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n := s.getNode(cni.Addr)
|
||||
n.id = string(cni.ID[:])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
|
||||
for _, node := range s.nodes {
|
||||
cni := compactNodeInfo{
|
||||
|
@ -91,7 +111,15 @@ func (s *Server) Serve() error {
|
|||
log.Printf("bad krpc message: %s", err)
|
||||
continue
|
||||
}
|
||||
if d["y"] == "q" {
|
||||
s.handleQuery(addr, d)
|
||||
continue
|
||||
}
|
||||
t := s.findResponseTransaction(d["t"].(string), addr)
|
||||
if t == nil {
|
||||
log.Printf("unexpected message: %#v", d)
|
||||
continue
|
||||
}
|
||||
t.response <- d
|
||||
s.removeTransaction(t)
|
||||
id := ""
|
||||
|
@ -102,6 +130,32 @@ func (s *Server) Serve() error {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
|
||||
if m["q"] != "ping" {
|
||||
return
|
||||
}
|
||||
s.heardFromNode(source, m["a"].(map[string]string)["id"])
|
||||
s.reply(source, m["t"].(string))
|
||||
}
|
||||
|
||||
func (s *Server) reply(addr *net.UDPAddr, t string) {
|
||||
m := map[string]interface{}{
|
||||
"t": t,
|
||||
"y": "r",
|
||||
"r": map[string]string{
|
||||
"id": s.IDString(),
|
||||
},
|
||||
}
|
||||
b, err := bencode.Marshal(m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = s.Socket.WriteTo(b, addr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) heardFromNode(addr *net.UDPAddr, id string) {
|
||||
n := s.getNode(addr)
|
||||
n.id = id
|
||||
|
@ -200,7 +254,7 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
|
|||
return
|
||||
}
|
||||
|
||||
const compactAddrInfoLen = 26
|
||||
const compactNodeInfoLen = 26
|
||||
|
||||
type compactAddrInfo *net.UDPAddr
|
||||
|
||||
|
@ -246,9 +300,23 @@ type findNodeResponse struct {
|
|||
Nodes []compactNodeInfo
|
||||
}
|
||||
|
||||
func getResponseNodes(m Msg) (s string, err error) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("couldn't get response nodes: %s: %#v", r, m)
|
||||
}()
|
||||
s = m["r"].(map[string]interface{})["nodes"].(string)
|
||||
return
|
||||
}
|
||||
|
||||
func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
||||
b := m["r"].(map[string]interface{})["nodes"].(string)
|
||||
log.Printf("%q", b)
|
||||
b, err := getResponseNodes(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := 0; i < len(b); i += 26 {
|
||||
var n compactNodeInfo
|
||||
err := n.UnmarshalBinary([]byte(b[i : i+26]))
|
||||
|
@ -261,7 +329,7 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
|||
}
|
||||
|
||||
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
||||
log.Print(addr)
|
||||
// log.Print(addr)
|
||||
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -302,10 +370,10 @@ func (s *Server) Bootstrap() error {
|
|||
}
|
||||
}
|
||||
queriedNodes := make(map[string]bool, 1000)
|
||||
for {
|
||||
for i := 0; i < 3; i++ {
|
||||
log.Printf("node table length: %d", len(s.nodes))
|
||||
for _, node := range s.nodes {
|
||||
if queriedNodes[node.addr.String()] {
|
||||
log.Printf("skipping already queried: %s", node.addr)
|
||||
continue
|
||||
}
|
||||
t, err := s.FindNode(node.addr, s.ID)
|
||||
|
@ -314,7 +382,7 @@ func (s *Server) Bootstrap() error {
|
|||
}
|
||||
queriedNodes[node.addr.String()] = true
|
||||
go func() {
|
||||
log.Print(<-t.Response)
|
||||
<-t.Response
|
||||
}()
|
||||
}
|
||||
time.Sleep(3 * time.Second)
|
||||
|
|
Loading…
Reference in New Issue