dht-server: Save and load node table between invocations
This commit is contained in:
parent
1b69e69461
commit
92b77a7cae
|
@ -4,6 +4,7 @@ import (
|
||||||
"bitbucket.org/anacrolix/go.torrent/dht"
|
"bitbucket.org/anacrolix/go.torrent/dht"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
type pingResponse struct {
|
type pingResponse struct {
|
||||||
|
@ -20,6 +21,20 @@ func main() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
s.Init()
|
s.Init()
|
||||||
|
func() {
|
||||||
|
f, err := os.Open("nodes")
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
err = s.ReadNodes(f)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
log.Printf("dht server on %s", s.Socket.LocalAddr())
|
log.Printf("dht server on %s", s.Socket.LocalAddr())
|
||||||
go func() {
|
go func() {
|
||||||
err := s.Serve()
|
err := s.Serve()
|
||||||
|
@ -28,8 +43,16 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
err = s.Bootstrap()
|
err = s.Bootstrap()
|
||||||
|
func() {
|
||||||
|
f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
s.WriteNodes(f)
|
||||||
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
select {}
|
|
||||||
}
|
}
|
||||||
|
|
82
dht/dht.go
82
dht/dht.go
|
@ -42,6 +42,26 @@ type transaction struct {
|
||||||
response chan Msg
|
response chan Msg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) ReadNodes(r io.Reader) error {
|
||||||
|
for {
|
||||||
|
var b [compactNodeInfoLen]byte
|
||||||
|
_, err := io.ReadFull(r, b[:])
|
||||||
|
if err == io.EOF {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var cni compactNodeInfo
|
||||||
|
err = cni.UnmarshalBinary(b[:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n := s.getNode(cni.Addr)
|
||||||
|
n.id = string(cni.ID[:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
|
func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
|
||||||
for _, node := range s.nodes {
|
for _, node := range s.nodes {
|
||||||
cni := compactNodeInfo{
|
cni := compactNodeInfo{
|
||||||
|
@ -91,7 +111,15 @@ func (s *Server) Serve() error {
|
||||||
log.Printf("bad krpc message: %s", err)
|
log.Printf("bad krpc message: %s", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if d["y"] == "q" {
|
||||||
|
s.handleQuery(addr, d)
|
||||||
|
continue
|
||||||
|
}
|
||||||
t := s.findResponseTransaction(d["t"].(string), addr)
|
t := s.findResponseTransaction(d["t"].(string), addr)
|
||||||
|
if t == nil {
|
||||||
|
log.Printf("unexpected message: %#v", d)
|
||||||
|
continue
|
||||||
|
}
|
||||||
t.response <- d
|
t.response <- d
|
||||||
s.removeTransaction(t)
|
s.removeTransaction(t)
|
||||||
id := ""
|
id := ""
|
||||||
|
@ -102,6 +130,32 @@ func (s *Server) Serve() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
|
||||||
|
if m["q"] != "ping" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.heardFromNode(source, m["a"].(map[string]string)["id"])
|
||||||
|
s.reply(source, m["t"].(string))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) reply(addr *net.UDPAddr, t string) {
|
||||||
|
m := map[string]interface{}{
|
||||||
|
"t": t,
|
||||||
|
"y": "r",
|
||||||
|
"r": map[string]string{
|
||||||
|
"id": s.IDString(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, err := bencode.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
_, err = s.Socket.WriteTo(b, addr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) heardFromNode(addr *net.UDPAddr, id string) {
|
func (s *Server) heardFromNode(addr *net.UDPAddr, id string) {
|
||||||
n := s.getNode(addr)
|
n := s.getNode(addr)
|
||||||
n.id = id
|
n.id = id
|
||||||
|
@ -200,7 +254,7 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const compactAddrInfoLen = 26
|
const compactNodeInfoLen = 26
|
||||||
|
|
||||||
type compactAddrInfo *net.UDPAddr
|
type compactAddrInfo *net.UDPAddr
|
||||||
|
|
||||||
|
@ -246,9 +300,23 @@ type findNodeResponse struct {
|
||||||
Nodes []compactNodeInfo
|
Nodes []compactNodeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getResponseNodes(m Msg) (s string, err error) {
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("couldn't get response nodes: %s: %#v", r, m)
|
||||||
|
}()
|
||||||
|
s = m["r"].(map[string]interface{})["nodes"].(string)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
||||||
b := m["r"].(map[string]interface{})["nodes"].(string)
|
b, err := getResponseNodes(m)
|
||||||
log.Printf("%q", b)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
for i := 0; i < len(b); i += 26 {
|
for i := 0; i < len(b); i += 26 {
|
||||||
var n compactNodeInfo
|
var n compactNodeInfo
|
||||||
err := n.UnmarshalBinary([]byte(b[i : i+26]))
|
err := n.UnmarshalBinary([]byte(b[i : i+26]))
|
||||||
|
@ -261,7 +329,7 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
||||||
log.Print(addr)
|
// log.Print(addr)
|
||||||
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
|
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -302,10 +370,10 @@ func (s *Server) Bootstrap() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
queriedNodes := make(map[string]bool, 1000)
|
queriedNodes := make(map[string]bool, 1000)
|
||||||
for {
|
for i := 0; i < 3; i++ {
|
||||||
|
log.Printf("node table length: %d", len(s.nodes))
|
||||||
for _, node := range s.nodes {
|
for _, node := range s.nodes {
|
||||||
if queriedNodes[node.addr.String()] {
|
if queriedNodes[node.addr.String()] {
|
||||||
log.Printf("skipping already queried: %s", node.addr)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t, err := s.FindNode(node.addr, s.ID)
|
t, err := s.FindNode(node.addr, s.ID)
|
||||||
|
@ -314,7 +382,7 @@ func (s *Server) Bootstrap() error {
|
||||||
}
|
}
|
||||||
queriedNodes[node.addr.String()] = true
|
queriedNodes[node.addr.String()] = true
|
||||||
go func() {
|
go func() {
|
||||||
log.Print(<-t.Response)
|
<-t.Response
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
Loading…
Reference in New Issue