Got dht-server working nicely

This commit is contained in:
Matt Joiner 2014-05-27 16:28:56 +10:00
parent 92b77a7cae
commit 83a02420a5
4 changed files with 210 additions and 121 deletions

View File

@ -23,7 +23,7 @@ func main() {
}
s := dht.Server{}
var err error
s.Socket, err = net.ListenPacket("udp4", "")
s.Socket, err = net.ListenUDP("udp4", nil)
if err != nil {
log.Fatal(err)
}

View File

@ -2,9 +2,13 @@ package main
import (
"bitbucket.org/anacrolix/go.torrent/dht"
"flag"
"fmt"
"io"
"log"
"net"
"os"
"os/signal"
)
type pingResponse struct {
@ -12,47 +16,119 @@ type pingResponse struct {
krpc dht.Msg
}
func main() {
var (
tableFileName = flag.String("tableFile", "", "name of file for storing node info")
serveAddr = flag.String("serveAddr", ":0", "local UDP address")
s dht.Server
)
func loadTable() error {
if *tableFileName == "" {
return nil
}
f, err := os.Open(*tableFileName)
if os.IsNotExist(err) {
return nil
}
if err != nil {
return fmt.Errorf("error opening table file: %s", err)
}
defer f.Close()
added := 0
for {
b := make([]byte, dht.CompactNodeInfoLen)
_, err := io.ReadFull(f, b)
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("error reading table file: %s", err)
}
var ni dht.NodeInfo
err = ni.UnmarshalCompact(b)
if err != nil {
return fmt.Errorf("error unmarshaling compact node info: %s", err)
}
s.AddNode(ni)
added++
}
log.Printf("loaded %d nodes from table file", added)
return nil
}
func init() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
s := dht.Server{}
var err error
s.Socket, err = net.ListenUDP("udp4", nil)
flag.Parse()
err := loadTable()
if err != nil {
log.Fatalf("error loading table: %s", err)
}
s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr {
addr, err := net.ResolveUDPAddr("udp4", *serveAddr)
if err != nil {
log.Fatalf("error resolving serve addr: %s", err)
}
return addr
}())
if err != nil {
log.Fatal(err)
}
s.Init()
func() {
f, err := os.Open("nodes")
if os.IsNotExist(err) {
return
}
if err != nil {
log.Fatal(err)
}
defer f.Close()
err = s.ReadNodes(f)
if err != nil {
log.Fatal(err)
}
}()
log.Printf("dht server on %s", s.Socket.LocalAddr())
go func() {
err := s.Serve()
if err != nil {
log.Fatal(err)
s.Init()
setupSignals()
}
func saveTable() error {
goodNodes := s.GoodNodes()
if *tableFileName == "" {
if len(goodNodes) != 0 {
log.Printf("discarding %d good nodes!", len(goodNodes))
}
}()
err = s.Bootstrap()
func() {
f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
if err != nil {
log.Print(err)
return
}
defer f.Close()
s.WriteNodes(f)
}()
return nil
}
f, err := os.OpenFile(*tableFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
if err != nil {
log.Fatal(err)
return fmt.Errorf("error opening table file: %s", err)
}
defer f.Close()
for _, nodeInfo := range goodNodes {
var b [dht.CompactNodeInfoLen]byte
err := nodeInfo.PutCompact(b[:])
if err != nil {
return fmt.Errorf("error compacting node info: %s", err)
}
_, err = f.Write(b[:])
if err != nil {
return fmt.Errorf("error writing compact node info: %s", err)
}
}
log.Printf("saved %d nodes to table file", len(goodNodes))
return nil
}
func setupSignals() {
ch := make(chan os.Signal)
signal.Notify(ch)
go func() {
<-ch
s.StopServing()
}()
}
func main() {
go func() {
err := s.Bootstrap()
if err != nil {
log.Printf("error bootstrapping: %s", err)
s.StopServing()
}
}()
err := s.Serve()
if err := saveTable(); err != nil {
log.Printf("error saving node table: %s", err)
}
if err != nil {
log.Fatalf("error serving dht: %s", err)
}
}

View File

@ -9,6 +9,7 @@ import (
"io"
"log"
"net"
"sync"
"time"
)
@ -18,6 +19,7 @@ type Server struct {
transactions []*transaction
transactionIDInt uint64
nodes map[string]*Node
mu sync.Mutex
}
type Node struct {
@ -27,6 +29,16 @@ type Node struct {
lastSentTo time.Time
}
func (n *Node) Good() bool {
if len(n.id) != 20 {
return false
}
if time.Now().Sub(n.lastHeardFrom) >= 15*time.Minute {
return false
}
return true
}
type Msg map[string]interface{}
var _ fmt.Stringer = Msg{}
@ -42,46 +54,6 @@ type transaction struct {
response chan Msg
}
func (s *Server) ReadNodes(r io.Reader) error {
for {
var b [compactNodeInfoLen]byte
_, err := io.ReadFull(r, b[:])
if err == io.EOF {
return nil
}
if err != nil {
return err
}
var cni compactNodeInfo
err = cni.UnmarshalBinary(b[:])
if err != nil {
return err
}
n := s.getNode(cni.Addr)
n.id = string(cni.ID[:])
}
}
func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
for _, node := range s.nodes {
cni := compactNodeInfo{
Addr: node.addr,
}
if n := copy(cni.ID[:], node.id); n != 20 {
panic(n)
}
var b [26]byte
cni.PutBinary(b[:])
var nn int
nn, err = w.Write(b[:])
if err != nil {
return
}
n += nn
}
return
}
func (s *Server) setDefaults() {
if s.ID == "" {
var id [20]byte
@ -95,7 +67,6 @@ func (s *Server) setDefaults() {
func (s *Server) Init() {
s.setDefaults()
s.nodes = make(map[string]*Node, 1000)
}
func (s *Server) Serve() error {
@ -111,13 +82,16 @@ func (s *Server) Serve() error {
log.Printf("bad krpc message: %s", err)
continue
}
s.mu.Lock()
if d["y"] == "q" {
s.handleQuery(addr, d)
s.mu.Unlock()
continue
}
t := s.findResponseTransaction(d["t"].(string), addr)
if t == nil {
log.Printf("unexpected message: %#v", d)
s.mu.Unlock()
continue
}
t.response <- d
@ -127,14 +101,26 @@ func (s *Server) Serve() error {
id = d["r"].(map[string]interface{})["id"].(string)
}
s.heardFromNode(addr, id)
s.mu.Unlock()
}
}
func (s *Server) AddNode(ni NodeInfo) {
if s.nodes == nil {
s.nodes = make(map[string]*Node)
}
n := s.getNode(ni.Addr)
if n.id == "" {
n.id = string(ni.ID[:])
}
}
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
log.Print(m["q"])
if m["q"] != "ping" {
return
}
s.heardFromNode(source, m["a"].(map[string]string)["id"])
s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
s.reply(source, m["t"].(string))
}
@ -254,30 +240,29 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
return
}
const compactNodeInfoLen = 26
const CompactNodeInfoLen = 26
type compactAddrInfo *net.UDPAddr
type compactNodeInfo struct {
type NodeInfo struct {
ID [20]byte
Addr compactAddrInfo
Addr *net.UDPAddr
}
func (cni *compactNodeInfo) PutBinary(b []byte) {
if n := copy(b[:], cni.ID[:]); n != 20 {
func (ni *NodeInfo) PutCompact(b []byte) error {
if n := copy(b[:], ni.ID[:]); n != 20 {
panic(n)
}
ip := cni.Addr.IP.To4()
ip := ni.Addr.IP.To4()
if len(ip) != 4 {
panic(ip)
}
if n := copy(b[20:], ip); n != 4 {
panic(n)
}
binary.BigEndian.PutUint16(b[24:], uint16(cni.Addr.Port))
binary.BigEndian.PutUint16(b[24:], uint16(ni.Addr.Port))
return nil
}
func (cni *compactNodeInfo) UnmarshalBinary(b []byte) error {
func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
if len(b) != 26 {
return errors.New("expected 26 bytes")
}
@ -297,7 +282,7 @@ func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
}
type findNodeResponse struct {
Nodes []compactNodeInfo
Nodes []NodeInfo
}
func getResponseNodes(m Msg) (s string, err error) {
@ -318,8 +303,8 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
return err
}
for i := 0; i < len(b); i += 26 {
var n compactNodeInfo
err := n.UnmarshalBinary([]byte(b[i : i+26]))
var n NodeInfo
err := n.UnmarshalCompact([]byte(b[i : i+26]))
if err != nil {
return err
}
@ -329,7 +314,6 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
}
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
// log.Print(addr)
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
if err != nil {
return
@ -348,10 +332,12 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e
if err != nil {
log.Print(err)
} else {
s.mu.Lock()
for _, cni := range r.Nodes {
n := s.getNode(cni.Addr)
n.id = string(cni.ID[:])
}
s.mu.Unlock()
}
}
t.Response <- d
@ -359,33 +345,60 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e
return
}
func (s *Server) Bootstrap() error {
if len(s.nodes) == 0 {
addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
if err != nil {
return err
}
s.nodes[addr.String()] = &Node{
addr: addr,
}
func (s *Server) addRootNode() error {
addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
if err != nil {
return err
}
queriedNodes := make(map[string]bool, 1000)
for i := 0; i < 3; i++ {
log.Printf("node table length: %d", len(s.nodes))
for _, node := range s.nodes {
if queriedNodes[node.addr.String()] {
continue
}
t, err := s.FindNode(node.addr, s.ID)
if err != nil {
return err
}
queriedNodes[node.addr.String()] = true
go func() {
<-t.Response
}()
}
time.Sleep(3 * time.Second)
s.nodes[addr.String()] = &Node{
addr: addr,
}
return nil
}
// Populates the node table.
func (s *Server) Bootstrap() (err error) {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.nodes) == 0 {
err = s.addRootNode()
if err != nil {
return
}
}
for _, node := range s.nodes {
var t *transaction
s.mu.Unlock()
t, err = s.FindNode(node.addr, s.ID)
s.mu.Lock()
if err != nil {
return
}
go func() {
<-t.Response
}()
}
return
}
func (s *Server) GoodNodes() (nis []NodeInfo) {
s.mu.Lock()
defer s.mu.Unlock()
for _, node := range s.nodes {
if !node.Good() {
continue
}
ni := NodeInfo{
Addr: node.addr,
}
if n := copy(ni.ID[:], node.id); n != 20 {
panic(n)
}
nis = append(nis, ni)
}
return
}
func (s *Server) StopServing() {
s.Socket.Close()
}

View File

@ -6,7 +6,7 @@ import (
)
func TestMarshalCompactNodeInfo(t *testing.T) {
cni := compactNodeInfo{
cni := NodeInfo{
ID: [20]byte{'a', 'b', 'c'},
}
var err error
@ -14,8 +14,8 @@ func TestMarshalCompactNodeInfo(t *testing.T) {
if err != nil {
t.Fatal(err)
}
var b [compactAddrInfoLen]byte
cni.PutBinary(b[:])
var b [CompactNodeInfoLen]byte
cni.PutCompact(b[:])
if err != nil {
t.Fatal(err)
}