Got dht-server working nicely
This commit is contained in:
parent
92b77a7cae
commit
83a02420a5
|
@ -23,7 +23,7 @@ func main() {
|
|||
}
|
||||
s := dht.Server{}
|
||||
var err error
|
||||
s.Socket, err = net.ListenPacket("udp4", "")
|
||||
s.Socket, err = net.ListenUDP("udp4", nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -2,9 +2,13 @@ package main
|
|||
|
||||
import (
|
||||
"bitbucket.org/anacrolix/go.torrent/dht"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
)
|
||||
|
||||
type pingResponse struct {
|
||||
|
@ -12,47 +16,119 @@ type pingResponse struct {
|
|||
krpc dht.Msg
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
tableFileName = flag.String("tableFile", "", "name of file for storing node info")
|
||||
serveAddr = flag.String("serveAddr", ":0", "local UDP address")
|
||||
|
||||
s dht.Server
|
||||
)
|
||||
|
||||
func loadTable() error {
|
||||
if *tableFileName == "" {
|
||||
return nil
|
||||
}
|
||||
f, err := os.Open(*tableFileName)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening table file: %s", err)
|
||||
}
|
||||
defer f.Close()
|
||||
added := 0
|
||||
for {
|
||||
b := make([]byte, dht.CompactNodeInfoLen)
|
||||
_, err := io.ReadFull(f, b)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading table file: %s", err)
|
||||
}
|
||||
var ni dht.NodeInfo
|
||||
err = ni.UnmarshalCompact(b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling compact node info: %s", err)
|
||||
}
|
||||
s.AddNode(ni)
|
||||
added++
|
||||
}
|
||||
log.Printf("loaded %d nodes from table file", added)
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
s := dht.Server{}
|
||||
var err error
|
||||
s.Socket, err = net.ListenUDP("udp4", nil)
|
||||
flag.Parse()
|
||||
err := loadTable()
|
||||
if err != nil {
|
||||
log.Fatalf("error loading table: %s", err)
|
||||
}
|
||||
s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr {
|
||||
addr, err := net.ResolveUDPAddr("udp4", *serveAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("error resolving serve addr: %s", err)
|
||||
}
|
||||
return addr
|
||||
}())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
s.Init()
|
||||
func() {
|
||||
f, err := os.Open("nodes")
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
err = s.ReadNodes(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
log.Printf("dht server on %s", s.Socket.LocalAddr())
|
||||
go func() {
|
||||
err := s.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
s.Init()
|
||||
setupSignals()
|
||||
}
|
||||
|
||||
func saveTable() error {
|
||||
goodNodes := s.GoodNodes()
|
||||
if *tableFileName == "" {
|
||||
if len(goodNodes) != 0 {
|
||||
log.Printf("discarding %d good nodes!", len(goodNodes))
|
||||
}
|
||||
}()
|
||||
err = s.Bootstrap()
|
||||
func() {
|
||||
f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
s.WriteNodes(f)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
f, err := os.OpenFile(*tableFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return fmt.Errorf("error opening table file: %s", err)
|
||||
}
|
||||
defer f.Close()
|
||||
for _, nodeInfo := range goodNodes {
|
||||
var b [dht.CompactNodeInfoLen]byte
|
||||
err := nodeInfo.PutCompact(b[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("error compacting node info: %s", err)
|
||||
}
|
||||
_, err = f.Write(b[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing compact node info: %s", err)
|
||||
}
|
||||
}
|
||||
log.Printf("saved %d nodes to table file", len(goodNodes))
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupSignals() {
|
||||
ch := make(chan os.Signal)
|
||||
signal.Notify(ch)
|
||||
go func() {
|
||||
<-ch
|
||||
s.StopServing()
|
||||
}()
|
||||
}
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
err := s.Bootstrap()
|
||||
if err != nil {
|
||||
log.Printf("error bootstrapping: %s", err)
|
||||
s.StopServing()
|
||||
}
|
||||
}()
|
||||
err := s.Serve()
|
||||
if err := saveTable(); err != nil {
|
||||
log.Printf("error saving node table: %s", err)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("error serving dht: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
177
dht/dht.go
177
dht/dht.go
|
@ -9,6 +9,7 @@ import (
|
|||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -18,6 +19,7 @@ type Server struct {
|
|||
transactions []*transaction
|
||||
transactionIDInt uint64
|
||||
nodes map[string]*Node
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
|
@ -27,6 +29,16 @@ type Node struct {
|
|||
lastSentTo time.Time
|
||||
}
|
||||
|
||||
func (n *Node) Good() bool {
|
||||
if len(n.id) != 20 {
|
||||
return false
|
||||
}
|
||||
if time.Now().Sub(n.lastHeardFrom) >= 15*time.Minute {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type Msg map[string]interface{}
|
||||
|
||||
var _ fmt.Stringer = Msg{}
|
||||
|
@ -42,46 +54,6 @@ type transaction struct {
|
|||
response chan Msg
|
||||
}
|
||||
|
||||
func (s *Server) ReadNodes(r io.Reader) error {
|
||||
for {
|
||||
var b [compactNodeInfoLen]byte
|
||||
_, err := io.ReadFull(r, b[:])
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var cni compactNodeInfo
|
||||
err = cni.UnmarshalBinary(b[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n := s.getNode(cni.Addr)
|
||||
n.id = string(cni.ID[:])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
|
||||
for _, node := range s.nodes {
|
||||
cni := compactNodeInfo{
|
||||
Addr: node.addr,
|
||||
}
|
||||
if n := copy(cni.ID[:], node.id); n != 20 {
|
||||
panic(n)
|
||||
}
|
||||
var b [26]byte
|
||||
cni.PutBinary(b[:])
|
||||
var nn int
|
||||
nn, err = w.Write(b[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += nn
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) setDefaults() {
|
||||
if s.ID == "" {
|
||||
var id [20]byte
|
||||
|
@ -95,7 +67,6 @@ func (s *Server) setDefaults() {
|
|||
|
||||
func (s *Server) Init() {
|
||||
s.setDefaults()
|
||||
s.nodes = make(map[string]*Node, 1000)
|
||||
}
|
||||
|
||||
func (s *Server) Serve() error {
|
||||
|
@ -111,13 +82,16 @@ func (s *Server) Serve() error {
|
|||
log.Printf("bad krpc message: %s", err)
|
||||
continue
|
||||
}
|
||||
s.mu.Lock()
|
||||
if d["y"] == "q" {
|
||||
s.handleQuery(addr, d)
|
||||
s.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
t := s.findResponseTransaction(d["t"].(string), addr)
|
||||
if t == nil {
|
||||
log.Printf("unexpected message: %#v", d)
|
||||
s.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
t.response <- d
|
||||
|
@ -127,14 +101,26 @@ func (s *Server) Serve() error {
|
|||
id = d["r"].(map[string]interface{})["id"].(string)
|
||||
}
|
||||
s.heardFromNode(addr, id)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) AddNode(ni NodeInfo) {
|
||||
if s.nodes == nil {
|
||||
s.nodes = make(map[string]*Node)
|
||||
}
|
||||
n := s.getNode(ni.Addr)
|
||||
if n.id == "" {
|
||||
n.id = string(ni.ID[:])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
|
||||
log.Print(m["q"])
|
||||
if m["q"] != "ping" {
|
||||
return
|
||||
}
|
||||
s.heardFromNode(source, m["a"].(map[string]string)["id"])
|
||||
s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
|
||||
s.reply(source, m["t"].(string))
|
||||
}
|
||||
|
||||
|
@ -254,30 +240,29 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
|
|||
return
|
||||
}
|
||||
|
||||
const compactNodeInfoLen = 26
|
||||
const CompactNodeInfoLen = 26
|
||||
|
||||
type compactAddrInfo *net.UDPAddr
|
||||
|
||||
type compactNodeInfo struct {
|
||||
type NodeInfo struct {
|
||||
ID [20]byte
|
||||
Addr compactAddrInfo
|
||||
Addr *net.UDPAddr
|
||||
}
|
||||
|
||||
func (cni *compactNodeInfo) PutBinary(b []byte) {
|
||||
if n := copy(b[:], cni.ID[:]); n != 20 {
|
||||
func (ni *NodeInfo) PutCompact(b []byte) error {
|
||||
if n := copy(b[:], ni.ID[:]); n != 20 {
|
||||
panic(n)
|
||||
}
|
||||
ip := cni.Addr.IP.To4()
|
||||
ip := ni.Addr.IP.To4()
|
||||
if len(ip) != 4 {
|
||||
panic(ip)
|
||||
}
|
||||
if n := copy(b[20:], ip); n != 4 {
|
||||
panic(n)
|
||||
}
|
||||
binary.BigEndian.PutUint16(b[24:], uint16(cni.Addr.Port))
|
||||
binary.BigEndian.PutUint16(b[24:], uint16(ni.Addr.Port))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cni *compactNodeInfo) UnmarshalBinary(b []byte) error {
|
||||
func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
|
||||
if len(b) != 26 {
|
||||
return errors.New("expected 26 bytes")
|
||||
}
|
||||
|
@ -297,7 +282,7 @@ func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
|
|||
}
|
||||
|
||||
type findNodeResponse struct {
|
||||
Nodes []compactNodeInfo
|
||||
Nodes []NodeInfo
|
||||
}
|
||||
|
||||
func getResponseNodes(m Msg) (s string, err error) {
|
||||
|
@ -318,8 +303,8 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
|||
return err
|
||||
}
|
||||
for i := 0; i < len(b); i += 26 {
|
||||
var n compactNodeInfo
|
||||
err := n.UnmarshalBinary([]byte(b[i : i+26]))
|
||||
var n NodeInfo
|
||||
err := n.UnmarshalCompact([]byte(b[i : i+26]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -329,7 +314,6 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
|||
}
|
||||
|
||||
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
||||
// log.Print(addr)
|
||||
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -348,10 +332,12 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e
|
|||
if err != nil {
|
||||
log.Print(err)
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
for _, cni := range r.Nodes {
|
||||
n := s.getNode(cni.Addr)
|
||||
n.id = string(cni.ID[:])
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
t.Response <- d
|
||||
|
@ -359,33 +345,60 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) Bootstrap() error {
|
||||
if len(s.nodes) == 0 {
|
||||
addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.nodes[addr.String()] = &Node{
|
||||
addr: addr,
|
||||
}
|
||||
func (s *Server) addRootNode() error {
|
||||
addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
queriedNodes := make(map[string]bool, 1000)
|
||||
for i := 0; i < 3; i++ {
|
||||
log.Printf("node table length: %d", len(s.nodes))
|
||||
for _, node := range s.nodes {
|
||||
if queriedNodes[node.addr.String()] {
|
||||
continue
|
||||
}
|
||||
t, err := s.FindNode(node.addr, s.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
queriedNodes[node.addr.String()] = true
|
||||
go func() {
|
||||
<-t.Response
|
||||
}()
|
||||
}
|
||||
time.Sleep(3 * time.Second)
|
||||
s.nodes[addr.String()] = &Node{
|
||||
addr: addr,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Populates the node table.
|
||||
func (s *Server) Bootstrap() (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if len(s.nodes) == 0 {
|
||||
err = s.addRootNode()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, node := range s.nodes {
|
||||
var t *transaction
|
||||
s.mu.Unlock()
|
||||
t, err = s.FindNode(node.addr, s.ID)
|
||||
s.mu.Lock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
<-t.Response
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) GoodNodes() (nis []NodeInfo) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, node := range s.nodes {
|
||||
if !node.Good() {
|
||||
continue
|
||||
}
|
||||
ni := NodeInfo{
|
||||
Addr: node.addr,
|
||||
}
|
||||
if n := copy(ni.ID[:], node.id); n != 20 {
|
||||
panic(n)
|
||||
}
|
||||
nis = append(nis, ni)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) StopServing() {
|
||||
s.Socket.Close()
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
func TestMarshalCompactNodeInfo(t *testing.T) {
|
||||
cni := compactNodeInfo{
|
||||
cni := NodeInfo{
|
||||
ID: [20]byte{'a', 'b', 'c'},
|
||||
}
|
||||
var err error
|
||||
|
@ -14,8 +14,8 @@ func TestMarshalCompactNodeInfo(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var b [compactAddrInfoLen]byte
|
||||
cni.PutBinary(b[:])
|
||||
var b [CompactNodeInfoLen]byte
|
||||
cni.PutCompact(b[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue