From 87c9697a26b9f7a6d53a9ba6818773845844567c Mon Sep 17 00:00:00 2001 From: Dmitry Borzov Date: Sun, 6 Dec 2015 17:56:46 +0300 Subject: [PATCH] Move dht.Server defs to a separate file --- dht/dht.go | 706 +------------------------------------------------- dht/doc.go | 2 +- dht/server.go | 701 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 709 insertions(+), 700 deletions(-) create mode 100644 dht/server.go diff --git a/dht/dht.go b/dht/dht.go index a1ceeb49..4bc4a5c6 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -1,28 +1,17 @@ package dht import ( - "crypto" _ "crypto/sha1" - "encoding/binary" "errors" "fmt" "hash/crc32" - "io" - "log" "math/big" "math/rand" "net" - "os" "strconv" "time" - "github.com/anacrolix/missinggo" - "github.com/anacrolix/sync" - "github.com/tylertreat/BoomFilters" - - "github.com/anacrolix/torrent/bencode" "github.com/anacrolix/torrent/iplist" - "github.com/anacrolix/torrent/logonce" ) const ( @@ -33,36 +22,19 @@ var ( queryResendEvery = 5 * time.Second ) +var maxDistance big.Int + +func init() { + var zero big.Int + maxDistance.SetBit(&zero, 160, 1) +} + // Uniquely identifies a transaction to us. type transactionKey struct { RemoteAddr string // host:port T string // The KRPC transaction ID. } -// A Server defines parameters for a DHT node server that is able to -// send queries, and respond to the ones from the network. -// Each node has a globally unique identifier known as the "node ID." -// Node IDs are chosen at random from the same 160-bit space -// as BitTorrent infohashes and define the behaviour of the node. -// Zero valued Server does not have a valid ID and thus -// is unable to function properly. Use `NewServer(nil)` -// to initialize a default node. -type Server struct { - id string - socket net.PacketConn - transactions map[transactionKey]*Transaction - transactionIDInt uint64 - nodes map[string]*node // Keyed by dHTAddr.String(). - mu sync.Mutex - closed chan struct{} - ipBlockList iplist.Ranger - badNodes *boom.BloomFilter - - numConfirmedAnnounces int - bootstrapNodes []string - config ServerConfig -} - // ServerConfig allows to set up a configuration of the `Server` instance // to be created with NewServer type ServerConfig struct { @@ -97,28 +69,6 @@ type ServerStats struct { BadNodes uint } -// Stats returns statistics for the server. -func (s *Server) Stats() (ss ServerStats) { - s.mu.Lock() - defer s.mu.Unlock() - for _, n := range s.nodes { - if n.DefinitelyGood() { - ss.GoodNodes++ - } - } - ss.Nodes = len(s.nodes) - ss.OutstandingTransactions = len(s.transactions) - ss.ConfirmedAnnounces = s.numConfirmedAnnounces - ss.BadNodes = s.badNodes.Count() - return -} - -// Addr returns the listen address for the server. Packets arriving to this address -// are processed by the server (unless aliens are involved). -func (s *Server) Addr() net.Addr { - return s.socket.LocalAddr() -} - func makeSocket(addr string) (socket *net.UDPConn, err error) { addr_, err := net.ResolveUDPAddr("", addr) if err != nil { @@ -128,58 +78,6 @@ func makeSocket(addr string) (socket *net.UDPConn, err error) { return } -// NewServer initializes a new DHT node server. -func NewServer(c *ServerConfig) (s *Server, err error) { - if c == nil { - c = &ServerConfig{} - } - s = &Server{ - config: *c, - ipBlockList: c.IPBlocklist, - badNodes: boom.NewBloomFilter(1000, 0.1), - } - if c.Conn != nil { - s.socket = c.Conn - } else { - s.socket, err = makeSocket(c.Addr) - if err != nil { - return - } - } - s.bootstrapNodes = c.BootstrapNodes - err = s.init() - if err != nil { - return - } - go func() { - err := s.serve() - select { - case <-s.closed: - return - default: - } - if err != nil { - panic(err) - } - }() - go func() { - err := s.bootstrap() - if err != nil { - select { - case <-s.closed: - default: - log.Printf("error bootstrapping DHT: %s", err) - } - } - }() - return -} - -// Returns a description of the Server. Python repr-style. -func (s *Server) String() string { - return fmt.Sprintf("dht server on %s", s.socket.LocalAddr()) -} - type nodeID struct { i big.Int set bool @@ -333,430 +231,6 @@ func NodeIdSecure(id string, ip net.IP) bool { return true } -func (s *Server) setDefaults() (err error) { - if s.id == "" { - var id [20]byte - h := crypto.SHA1.New() - ss, err := os.Hostname() - if err != nil { - log.Print(err) - } - ss += s.socket.LocalAddr().String() - h.Write([]byte(ss)) - if b := h.Sum(id[:0:20]); len(b) != 20 { - panic(len(b)) - } - if len(id) != 20 { - panic(len(id)) - } - publicIP := func() net.IP { - if s.config.PublicIP != nil { - return s.config.PublicIP - } else { - return missinggo.AddrIP(s.socket.LocalAddr()) - } - }() - SecureNodeId(id[:], publicIP) - s.id = string(id[:]) - } - s.nodes = make(map[string]*node, maxNodes) - return -} - -// Packets to and from any address matching a range in the list are dropped. -func (s *Server) SetIPBlockList(list iplist.Ranger) { - s.mu.Lock() - defer s.mu.Unlock() - s.ipBlockList = list -} - -func (s *Server) IPBlocklist() iplist.Ranger { - return s.ipBlockList -} - -func (s *Server) init() (err error) { - err = s.setDefaults() - if err != nil { - return - } - s.closed = make(chan struct{}) - s.transactions = make(map[transactionKey]*Transaction) - return -} - -func (s *Server) processPacket(b []byte, addr dHTAddr) { - if len(b) < 2 || b[0] != 'd' || b[len(b)-1] != 'e' { - // KRPC messages are bencoded dicts. - readNotKRPCDict.Add(1) - return - } - var d Msg - err := bencode.Unmarshal(b, &d) - if err != nil { - readUnmarshalError.Add(1) - func() { - if se, ok := err.(*bencode.SyntaxError); ok { - // The message was truncated. - if int(se.Offset) == len(b) { - return - } - // Some messages seem to drop to nul chars abrubtly. - if int(se.Offset) < len(b) && b[se.Offset] == 0 { - return - } - // The message isn't bencode from the first. - if se.Offset == 0 { - return - } - } - // if missinggo.CryHeard() { - // log.Printf("%s: received bad krpc message from %s: %s: %+q", s, addr, err, b) - // } - }() - return - } - s.mu.Lock() - defer s.mu.Unlock() - if d.Y == "q" { - readQuery.Add(1) - s.handleQuery(addr, d) - return - } - t := s.findResponseTransaction(d.T, addr) - if t == nil { - //log.Printf("unexpected message: %#v", d) - return - } - node := s.getNode(addr, d.SenderID()) - node.lastGotResponse = time.Now() - // TODO: Update node ID as this is an authoritative packet. - go t.handleResponse(d) - s.deleteTransaction(t) -} - -func (s *Server) serve() error { - var b [0x10000]byte - for { - n, addr, err := s.socket.ReadFrom(b[:]) - if err != nil { - return err - } - read.Add(1) - if n == len(b) { - logonce.Stderr.Printf("received dht packet exceeds buffer size") - continue - } - s.mu.Lock() - blocked := s.ipBlocked(missinggo.AddrIP(addr)) - s.mu.Unlock() - if blocked { - readBlocked.Add(1) - continue - } - s.processPacket(b[:n], newDHTAddr(addr)) - } -} - -func (s *Server) ipBlocked(ip net.IP) (blocked bool) { - if s.ipBlockList == nil { - return - } - _, blocked = s.ipBlockList.Lookup(ip) - return -} - -// Adds directly to the node table. -func (s *Server) AddNode(ni NodeInfo) { - s.mu.Lock() - defer s.mu.Unlock() - if s.nodes == nil { - s.nodes = make(map[string]*node) - } - s.getNode(ni.Addr, string(ni.ID[:])) -} - -func (s *Server) nodeByID(id string) *node { - for _, node := range s.nodes { - if node.idString() == id { - return node - } - } - return nil -} - -func (s *Server) handleQuery(source dHTAddr, m Msg) { - node := s.getNode(source, m.SenderID()) - node.lastGotQuery = time.Now() - // Don't respond. - if s.config.Passive { - return - } - args := m.A - switch m.Q { - case "ping": - s.reply(source, m.T, Return{}) - case "get_peers": // TODO: Extract common behaviour with find_node. - targetID := args.InfoHash - if len(targetID) != 20 { - break - } - var rNodes []NodeInfo - // TODO: Reply with "values" list if we have peers instead. - for _, node := range s.closestGoodNodes(8, targetID) { - rNodes = append(rNodes, node.NodeInfo()) - } - s.reply(source, m.T, Return{ - Nodes: rNodes, - // TODO: Generate this dynamically, and store it for the source. - Token: "hi", - }) - case "find_node": // TODO: Extract common behaviour with get_peers. - targetID := args.Target - if len(targetID) != 20 { - log.Printf("bad DHT query: %v", m) - return - } - var rNodes []NodeInfo - if node := s.nodeByID(targetID); node != nil { - rNodes = append(rNodes, node.NodeInfo()) - } else { - // This will probably cause a crash for IPv6, but meh. - for _, node := range s.closestGoodNodes(8, targetID) { - rNodes = append(rNodes, node.NodeInfo()) - } - } - s.reply(source, m.T, Return{ - Nodes: rNodes, - }) - case "announce_peer": - // TODO(anacrolix): Implement this lolz. - // log.Print(m) - case "vote": - // TODO(anacrolix): Or reject, I don't think I want this. - default: - log.Printf("%s: not handling received query: q=%s", s, m.Q) - return - } -} - -func (s *Server) reply(addr dHTAddr, t string, r Return) { - r.ID = s.ID() - m := Msg{ - T: t, - Y: "r", - R: &r, - } - b, err := bencode.Marshal(m) - if err != nil { - panic(err) - } - err = s.writeToNode(b, addr) - if err != nil { - log.Printf("error replying to %s: %s", addr, err) - } -} - -// Returns a node struct for the addr. It is taken from the table or created -// and possibly added if required and meets validity constraints. -func (s *Server) getNode(addr dHTAddr, id string) (n *node) { - addrStr := addr.String() - n = s.nodes[addrStr] - if n != nil { - if id != "" { - n.SetIDFromString(id) - } - return - } - n = &node{ - addr: addr, - } - if len(id) == 20 { - n.SetIDFromString(id) - } - if len(s.nodes) >= maxNodes { - return - } - if !s.config.NoSecurity && !n.IsSecure() { - return - } - if s.badNodes.Test([]byte(addrStr)) { - return - } - s.nodes[addrStr] = n - return -} - -func (s *Server) nodeTimedOut(addr dHTAddr) { - node, ok := s.nodes[addr.String()] - if !ok { - return - } - if node.DefinitelyGood() { - return - } - if len(s.nodes) < maxNodes { - return - } - delete(s.nodes, addr.String()) -} - -func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) { - if list := s.ipBlockList; list != nil { - if r, ok := list.Lookup(missinggo.AddrIP(node.UDPAddr())); ok { - err = fmt.Errorf("write to %s blocked: %s", node, r.Description) - return - } - } - n, err := s.socket.WriteTo(b, node.UDPAddr()) - if err != nil { - err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err) - return - } - if n != len(b) { - err = io.ErrShortWrite - return - } - return -} - -func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *Transaction { - return s.transactions[transactionKey{ - sourceNode.String(), - transactionID}] -} - -func (s *Server) nextTransactionID() string { - var b [binary.MaxVarintLen64]byte - n := binary.PutUvarint(b[:], s.transactionIDInt) - s.transactionIDInt++ - return string(b[:n]) -} - -func (s *Server) deleteTransaction(t *Transaction) { - delete(s.transactions, t.key()) -} - -func (s *Server) addTransaction(t *Transaction) { - if _, ok := s.transactions[t.key()]; ok { - panic("transaction not unique") - } - s.transactions[t.key()] = t -} - -// ID returns the 20-byte server ID. This is the ID used to communicate with the -// DHT network. -func (s *Server) ID() string { - if len(s.id) != 20 { - panic("bad node id") - } - return s.id -} - -func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *Transaction, err error) { - tid := s.nextTransactionID() - if a == nil { - a = make(map[string]interface{}, 1) - } - a["id"] = s.ID() - d := map[string]interface{}{ - "t": tid, - "y": "q", - "q": q, - "a": a, - } - // BEP 43. Outgoing queries from uncontactiable nodes should contain - // "ro":1 in the top level dictionary. - if s.config.Passive { - d["ro"] = 1 - } - b, err := bencode.Marshal(d) - if err != nil { - return - } - t = &Transaction{ - remoteAddr: node, - t: tid, - response: make(chan Msg, 1), - done: make(chan struct{}), - queryPacket: b, - s: s, - onResponse: onResponse, - } - err = t.sendQuery() - if err != nil { - return - } - s.getNode(node, "").lastSentQuery = time.Now() - t.startTimer() - s.addTransaction(t) - return -} - -// Sends a ping query to the address given. -func (s *Server) Ping(node *net.UDPAddr) (*Transaction, error) { - s.mu.Lock() - defer s.mu.Unlock() - return s.query(newDHTAddr(node), "ping", nil, nil) -} - -func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) { - if port == 0 && !impliedPort { - return errors.New("nothing to announce") - } - _, err = s.query(node, "announce_peer", map[string]interface{}{ - "implied_port": func() int { - if impliedPort { - return 1 - } else { - return 0 - } - }(), - "info_hash": infoHash, - "port": port, - "token": token, - }, func(m Msg) { - if err := m.Error(); err != nil { - announceErrors.Add(1) - // log.Print(token) - // logonce.Stderr.Printf("announce_peer response: %s", err) - return - } - s.numConfirmedAnnounces++ - }) - return -} - -// Add response nodes to node table. -func (s *Server) liftNodes(d Msg) { - if d.Y != "r" { - return - } - for _, cni := range d.R.Nodes { - if missinggo.AddrPort(cni.Addr) == 0 { - // TODO: Why would people even do this? - continue - } - if s.ipBlocked(missinggo.AddrIP(cni.Addr)) { - continue - } - n := s.getNode(cni.Addr, string(cni.ID[:])) - n.SetIDFromBytes(cni.ID[:]) - } -} - -// Sends a find_node query to addr. targetID is the node we're looking for. -func (s *Server) findNode(addr dHTAddr, targetID string) (t *Transaction, err error) { - t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) { - // Scrape peers from the response to put in the server's table before - // handing the response back to the caller. - s.liftNodes(d) - }) - if err != nil { - return - } - return -} - type Peer struct { IP net.IP Port int @@ -766,20 +240,6 @@ func (me *Peer) String() string { return net.JoinHostPort(me.IP.String(), strconv.FormatInt(int64(me.Port), 10)) } -func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) { - if len(infoHash) != 20 { - err = fmt.Errorf("infohash has bad length") - return - } - t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) { - s.liftNodes(m) - if m.R != nil && m.R.Token != "" { - s.getNode(addr, m.SenderID()).announceToken = m.R.Token - } - }) - return -} - func bootstrapAddrs(nodeAddrs []string) (addrs []*net.UDPAddr, err error) { bootstrapNodes := nodeAddrs if len(bootstrapNodes) == 0 { @@ -800,155 +260,3 @@ func bootstrapAddrs(nodeAddrs []string) (addrs []*net.UDPAddr, err error) { } return } - -// Adds bootstrap nodes directly to table, if there's room. Node ID security -// is bypassed, but the IP blocklist is not. -func (s *Server) addRootNodes() error { - addrs, err := bootstrapAddrs(s.bootstrapNodes) - if err != nil { - return err - } - for _, addr := range addrs { - if len(s.nodes) >= maxNodes { - break - } - if s.nodes[addr.String()] != nil { - continue - } - if s.ipBlocked(addr.IP) { - log.Printf("dht root node is in the blocklist: %s", addr.IP) - continue - } - s.nodes[addr.String()] = &node{ - addr: newDHTAddr(addr), - } - } - return nil -} - -// Populates the node table. -func (s *Server) bootstrap() (err error) { - s.mu.Lock() - defer s.mu.Unlock() - if len(s.nodes) == 0 { - err = s.addRootNodes() - } - if err != nil { - return - } - for { - var outstanding sync.WaitGroup - for _, node := range s.nodes { - var t *Transaction - t, err = s.findNode(node.addr, s.id) - if err != nil { - err = fmt.Errorf("error sending find_node: %s", err) - return - } - outstanding.Add(1) - t.SetResponseHandler(func(Msg) { - outstanding.Done() - }) - } - noOutstanding := make(chan struct{}) - go func() { - outstanding.Wait() - close(noOutstanding) - }() - s.mu.Unlock() - select { - case <-s.closed: - s.mu.Lock() - return - case <-time.After(15 * time.Second): - case <-noOutstanding: - } - s.mu.Lock() - // log.Printf("now have %d nodes", len(s.nodes)) - if s.numGoodNodes() >= 160 { - break - } - } - return -} - -func (s *Server) numGoodNodes() (num int) { - for _, n := range s.nodes { - if n.DefinitelyGood() { - num++ - } - } - return -} - -// Returns how many nodes are in the node table. -func (s *Server) NumNodes() int { - s.mu.Lock() - defer s.mu.Unlock() - return len(s.nodes) -} - -// Exports the current node table. -func (s *Server) Nodes() (nis []NodeInfo) { - s.mu.Lock() - defer s.mu.Unlock() - for _, node := range s.nodes { - // if !node.Good() { - // continue - // } - ni := NodeInfo{ - Addr: node.addr, - } - if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 { - panic(n) - } - nis = append(nis, ni) - } - return -} - -// Stops the server network activity. This is all that's required to clean-up a Server. -func (s *Server) Close() { - s.mu.Lock() - select { - case <-s.closed: - default: - close(s.closed) - s.socket.Close() - } - s.mu.Unlock() -} - -var maxDistance big.Int - -func init() { - var zero big.Int - maxDistance.SetBit(&zero, 160, 1) -} - -func (s *Server) closestGoodNodes(k int, targetID string) []*node { - return s.closestNodes(k, nodeIDFromString(targetID), func(n *node) bool { return n.DefinitelyGood() }) -} - -func (s *Server) closestNodes(k int, target nodeID, filter func(*node) bool) []*node { - sel := newKClosestNodesSelector(k, target) - idNodes := make(map[string]*node, len(s.nodes)) - for _, node := range s.nodes { - if !filter(node) { - continue - } - sel.Push(node.id) - idNodes[node.idString()] = node - } - ids := sel.IDs() - ret := make([]*node, 0, len(ids)) - for _, id := range ids { - ret = append(ret, idNodes[id.ByteString()]) - } - return ret -} - -func (me *Server) badNode(addr dHTAddr) { - me.badNodes.Add([]byte(addr.String())) - delete(me.nodes, addr.String()) -} diff --git a/dht/doc.go b/dht/doc.go index 007e1ac6..66fe6cdc 100644 --- a/dht/doc.go +++ b/dht/doc.go @@ -1,7 +1,7 @@ /* Package dht implements a Distributed Hash Table (DHT) part of the BitTorrent protocol, - as specified by BEP 5: http://www.bittorrent.org/beps/bep_0005.html. + as specified by BEP 5: http://www.bittorrent.org/beps/bep_0005.html BitTorrent uses a "distributed hash table" (DHT) for storing peer contact information for "trackerless" torrents. diff --git a/dht/server.go b/dht/server.go new file mode 100644 index 00000000..d203ed4b --- /dev/null +++ b/dht/server.go @@ -0,0 +1,701 @@ +package dht + +import ( + "crypto" + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "sync" + "time" + + "github.com/anacrolix/missinggo" + "github.com/anacrolix/torrent/bencode" + "github.com/anacrolix/torrent/iplist" + "github.com/anacrolix/torrent/logonce" + "github.com/tylertreat/BoomFilters" +) + +// A Server defines parameters for a DHT node server that is able to +// send queries, and respond to the ones from the network. +// Each node has a globally unique identifier known as the "node ID." +// Node IDs are chosen at random from the same 160-bit space +// as BitTorrent infohashes and define the behaviour of the node. +// Zero valued Server does not have a valid ID and thus +// is unable to function properly. Use `NewServer(nil)` +// to initialize a default node. +type Server struct { + id string + socket net.PacketConn + transactions map[transactionKey]*Transaction + transactionIDInt uint64 + nodes map[string]*node // Keyed by dHTAddr.String(). + mu sync.Mutex + closed chan struct{} + ipBlockList iplist.Ranger + badNodes *boom.BloomFilter + + numConfirmedAnnounces int + bootstrapNodes []string + config ServerConfig +} + +// Stats returns statistics for the server. +func (s *Server) Stats() (ss ServerStats) { + s.mu.Lock() + defer s.mu.Unlock() + for _, n := range s.nodes { + if n.DefinitelyGood() { + ss.GoodNodes++ + } + } + ss.Nodes = len(s.nodes) + ss.OutstandingTransactions = len(s.transactions) + ss.ConfirmedAnnounces = s.numConfirmedAnnounces + ss.BadNodes = s.badNodes.Count() + return +} + +// Addr returns the listen address for the server. Packets arriving to this address +// are processed by the server (unless aliens are involved). +func (s *Server) Addr() net.Addr { + return s.socket.LocalAddr() +} + +// NewServer initializes a new DHT node server. +func NewServer(c *ServerConfig) (s *Server, err error) { + if c == nil { + c = &ServerConfig{} + } + s = &Server{ + config: *c, + ipBlockList: c.IPBlocklist, + badNodes: boom.NewBloomFilter(1000, 0.1), + } + if c.Conn != nil { + s.socket = c.Conn + } else { + s.socket, err = makeSocket(c.Addr) + if err != nil { + return + } + } + s.bootstrapNodes = c.BootstrapNodes + err = s.init() + if err != nil { + return + } + go func() { + err := s.serve() + select { + case <-s.closed: + return + default: + } + if err != nil { + panic(err) + } + }() + go func() { + err := s.bootstrap() + if err != nil { + select { + case <-s.closed: + default: + log.Printf("error bootstrapping DHT: %s", err) + } + } + }() + return +} + +// Returns a description of the Server. Python repr-style. +func (s *Server) String() string { + return fmt.Sprintf("dht server on %s", s.socket.LocalAddr()) +} + +// Packets to and from any address matching a range in the list are dropped. +func (s *Server) SetIPBlockList(list iplist.Ranger) { + s.mu.Lock() + defer s.mu.Unlock() + s.ipBlockList = list +} + +func (s *Server) IPBlocklist() iplist.Ranger { + return s.ipBlockList +} + +func (s *Server) init() (err error) { + err = s.setDefaults() + if err != nil { + return + } + s.closed = make(chan struct{}) + s.transactions = make(map[transactionKey]*Transaction) + return +} + +func (s *Server) processPacket(b []byte, addr dHTAddr) { + if len(b) < 2 || b[0] != 'd' || b[len(b)-1] != 'e' { + // KRPC messages are bencoded dicts. + readNotKRPCDict.Add(1) + return + } + var d Msg + err := bencode.Unmarshal(b, &d) + if err != nil { + readUnmarshalError.Add(1) + func() { + if se, ok := err.(*bencode.SyntaxError); ok { + // The message was truncated. + if int(se.Offset) == len(b) { + return + } + // Some messages seem to drop to nul chars abrubtly. + if int(se.Offset) < len(b) && b[se.Offset] == 0 { + return + } + // The message isn't bencode from the first. + if se.Offset == 0 { + return + } + } + // if missinggo.CryHeard() { + // log.Printf("%s: received bad krpc message from %s: %s: %+q", s, addr, err, b) + // } + }() + return + } + s.mu.Lock() + defer s.mu.Unlock() + if d.Y == "q" { + readQuery.Add(1) + s.handleQuery(addr, d) + return + } + t := s.findResponseTransaction(d.T, addr) + if t == nil { + //log.Printf("unexpected message: %#v", d) + return + } + node := s.getNode(addr, d.SenderID()) + node.lastGotResponse = time.Now() + // TODO: Update node ID as this is an authoritative packet. + go t.handleResponse(d) + s.deleteTransaction(t) +} + +func (s *Server) serve() error { + var b [0x10000]byte + for { + n, addr, err := s.socket.ReadFrom(b[:]) + if err != nil { + return err + } + read.Add(1) + if n == len(b) { + logonce.Stderr.Printf("received dht packet exceeds buffer size") + continue + } + s.mu.Lock() + blocked := s.ipBlocked(missinggo.AddrIP(addr)) + s.mu.Unlock() + if blocked { + readBlocked.Add(1) + continue + } + s.processPacket(b[:n], newDHTAddr(addr)) + } +} + +func (s *Server) ipBlocked(ip net.IP) (blocked bool) { + if s.ipBlockList == nil { + return + } + _, blocked = s.ipBlockList.Lookup(ip) + return +} + +// Adds directly to the node table. +func (s *Server) AddNode(ni NodeInfo) { + s.mu.Lock() + defer s.mu.Unlock() + if s.nodes == nil { + s.nodes = make(map[string]*node) + } + s.getNode(ni.Addr, string(ni.ID[:])) +} + +func (s *Server) nodeByID(id string) *node { + for _, node := range s.nodes { + if node.idString() == id { + return node + } + } + return nil +} + +func (s *Server) handleQuery(source dHTAddr, m Msg) { + node := s.getNode(source, m.SenderID()) + node.lastGotQuery = time.Now() + // Don't respond. + if s.config.Passive { + return + } + args := m.A + switch m.Q { + case "ping": + s.reply(source, m.T, Return{}) + case "get_peers": // TODO: Extract common behaviour with find_node. + targetID := args.InfoHash + if len(targetID) != 20 { + break + } + var rNodes []NodeInfo + // TODO: Reply with "values" list if we have peers instead. + for _, node := range s.closestGoodNodes(8, targetID) { + rNodes = append(rNodes, node.NodeInfo()) + } + s.reply(source, m.T, Return{ + Nodes: rNodes, + // TODO: Generate this dynamically, and store it for the source. + Token: "hi", + }) + case "find_node": // TODO: Extract common behaviour with get_peers. + targetID := args.Target + if len(targetID) != 20 { + log.Printf("bad DHT query: %v", m) + return + } + var rNodes []NodeInfo + if node := s.nodeByID(targetID); node != nil { + rNodes = append(rNodes, node.NodeInfo()) + } else { + // This will probably cause a crash for IPv6, but meh. + for _, node := range s.closestGoodNodes(8, targetID) { + rNodes = append(rNodes, node.NodeInfo()) + } + } + s.reply(source, m.T, Return{ + Nodes: rNodes, + }) + case "announce_peer": + // TODO(anacrolix): Implement this lolz. + // log.Print(m) + case "vote": + // TODO(anacrolix): Or reject, I don't think I want this. + default: + log.Printf("%s: not handling received query: q=%s", s, m.Q) + return + } +} + +func (s *Server) reply(addr dHTAddr, t string, r Return) { + r.ID = s.ID() + m := Msg{ + T: t, + Y: "r", + R: &r, + } + b, err := bencode.Marshal(m) + if err != nil { + panic(err) + } + err = s.writeToNode(b, addr) + if err != nil { + log.Printf("error replying to %s: %s", addr, err) + } +} + +// Returns a node struct for the addr. It is taken from the table or created +// and possibly added if required and meets validity constraints. +func (s *Server) getNode(addr dHTAddr, id string) (n *node) { + addrStr := addr.String() + n = s.nodes[addrStr] + if n != nil { + if id != "" { + n.SetIDFromString(id) + } + return + } + n = &node{ + addr: addr, + } + if len(id) == 20 { + n.SetIDFromString(id) + } + if len(s.nodes) >= maxNodes { + return + } + if !s.config.NoSecurity && !n.IsSecure() { + return + } + if s.badNodes.Test([]byte(addrStr)) { + return + } + s.nodes[addrStr] = n + return +} + +func (s *Server) nodeTimedOut(addr dHTAddr) { + node, ok := s.nodes[addr.String()] + if !ok { + return + } + if node.DefinitelyGood() { + return + } + if len(s.nodes) < maxNodes { + return + } + delete(s.nodes, addr.String()) +} + +func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) { + if list := s.ipBlockList; list != nil { + if r, ok := list.Lookup(missinggo.AddrIP(node.UDPAddr())); ok { + err = fmt.Errorf("write to %s blocked: %s", node, r.Description) + return + } + } + n, err := s.socket.WriteTo(b, node.UDPAddr()) + if err != nil { + err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err) + return + } + if n != len(b) { + err = io.ErrShortWrite + return + } + return +} + +func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *Transaction { + return s.transactions[transactionKey{ + sourceNode.String(), + transactionID}] +} + +func (s *Server) nextTransactionID() string { + var b [binary.MaxVarintLen64]byte + n := binary.PutUvarint(b[:], s.transactionIDInt) + s.transactionIDInt++ + return string(b[:n]) +} + +func (s *Server) deleteTransaction(t *Transaction) { + delete(s.transactions, t.key()) +} + +func (s *Server) addTransaction(t *Transaction) { + if _, ok := s.transactions[t.key()]; ok { + panic("transaction not unique") + } + s.transactions[t.key()] = t +} + +// ID returns the 20-byte server ID. This is the ID used to communicate with the +// DHT network. +func (s *Server) ID() string { + if len(s.id) != 20 { + panic("bad node id") + } + return s.id +} + +func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *Transaction, err error) { + tid := s.nextTransactionID() + if a == nil { + a = make(map[string]interface{}, 1) + } + a["id"] = s.ID() + d := map[string]interface{}{ + "t": tid, + "y": "q", + "q": q, + "a": a, + } + // BEP 43. Outgoing queries from uncontactiable nodes should contain + // "ro":1 in the top level dictionary. + if s.config.Passive { + d["ro"] = 1 + } + b, err := bencode.Marshal(d) + if err != nil { + return + } + t = &Transaction{ + remoteAddr: node, + t: tid, + response: make(chan Msg, 1), + done: make(chan struct{}), + queryPacket: b, + s: s, + onResponse: onResponse, + } + err = t.sendQuery() + if err != nil { + return + } + s.getNode(node, "").lastSentQuery = time.Now() + t.startTimer() + s.addTransaction(t) + return +} + +// Sends a ping query to the address given. +func (s *Server) Ping(node *net.UDPAddr) (*Transaction, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.query(newDHTAddr(node), "ping", nil, nil) +} + +func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) { + if port == 0 && !impliedPort { + return errors.New("nothing to announce") + } + _, err = s.query(node, "announce_peer", map[string]interface{}{ + "implied_port": func() int { + if impliedPort { + return 1 + } else { + return 0 + } + }(), + "info_hash": infoHash, + "port": port, + "token": token, + }, func(m Msg) { + if err := m.Error(); err != nil { + announceErrors.Add(1) + // log.Print(token) + // logonce.Stderr.Printf("announce_peer response: %s", err) + return + } + s.numConfirmedAnnounces++ + }) + return +} + +// Add response nodes to node table. +func (s *Server) liftNodes(d Msg) { + if d.Y != "r" { + return + } + for _, cni := range d.R.Nodes { + if missinggo.AddrPort(cni.Addr) == 0 { + // TODO: Why would people even do this? + continue + } + if s.ipBlocked(missinggo.AddrIP(cni.Addr)) { + continue + } + n := s.getNode(cni.Addr, string(cni.ID[:])) + n.SetIDFromBytes(cni.ID[:]) + } +} + +// Sends a find_node query to addr. targetID is the node we're looking for. +func (s *Server) findNode(addr dHTAddr, targetID string) (t *Transaction, err error) { + t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) { + // Scrape peers from the response to put in the server's table before + // handing the response back to the caller. + s.liftNodes(d) + }) + if err != nil { + return + } + return +} + +// Adds bootstrap nodes directly to table, if there's room. Node ID security +// is bypassed, but the IP blocklist is not. +func (s *Server) addRootNodes() error { + addrs, err := bootstrapAddrs(s.bootstrapNodes) + if err != nil { + return err + } + for _, addr := range addrs { + if len(s.nodes) >= maxNodes { + break + } + if s.nodes[addr.String()] != nil { + continue + } + if s.ipBlocked(addr.IP) { + log.Printf("dht root node is in the blocklist: %s", addr.IP) + continue + } + s.nodes[addr.String()] = &node{ + addr: newDHTAddr(addr), + } + } + return nil +} + +// Populates the node table. +func (s *Server) bootstrap() (err error) { + s.mu.Lock() + defer s.mu.Unlock() + if len(s.nodes) == 0 { + err = s.addRootNodes() + } + if err != nil { + return + } + for { + var outstanding sync.WaitGroup + for _, node := range s.nodes { + var t *Transaction + t, err = s.findNode(node.addr, s.id) + if err != nil { + err = fmt.Errorf("error sending find_node: %s", err) + return + } + outstanding.Add(1) + t.SetResponseHandler(func(Msg) { + outstanding.Done() + }) + } + noOutstanding := make(chan struct{}) + go func() { + outstanding.Wait() + close(noOutstanding) + }() + s.mu.Unlock() + select { + case <-s.closed: + s.mu.Lock() + return + case <-time.After(15 * time.Second): + case <-noOutstanding: + } + s.mu.Lock() + // log.Printf("now have %d nodes", len(s.nodes)) + if s.numGoodNodes() >= 160 { + break + } + } + return +} + +func (s *Server) numGoodNodes() (num int) { + for _, n := range s.nodes { + if n.DefinitelyGood() { + num++ + } + } + return +} + +// Returns how many nodes are in the node table. +func (s *Server) NumNodes() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.nodes) +} + +// Exports the current node table. +func (s *Server) Nodes() (nis []NodeInfo) { + s.mu.Lock() + defer s.mu.Unlock() + for _, node := range s.nodes { + // if !node.Good() { + // continue + // } + ni := NodeInfo{ + Addr: node.addr, + } + if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 { + panic(n) + } + nis = append(nis, ni) + } + return +} + +// Stops the server network activity. This is all that's required to clean-up a Server. +func (s *Server) Close() { + s.mu.Lock() + select { + case <-s.closed: + default: + close(s.closed) + s.socket.Close() + } + s.mu.Unlock() +} + +func (s *Server) setDefaults() (err error) { + if s.id == "" { + var id [20]byte + h := crypto.SHA1.New() + ss, err := os.Hostname() + if err != nil { + log.Print(err) + } + ss += s.socket.LocalAddr().String() + h.Write([]byte(ss)) + if b := h.Sum(id[:0:20]); len(b) != 20 { + panic(len(b)) + } + if len(id) != 20 { + panic(len(id)) + } + publicIP := func() net.IP { + if s.config.PublicIP != nil { + return s.config.PublicIP + } else { + return missinggo.AddrIP(s.socket.LocalAddr()) + } + }() + SecureNodeId(id[:], publicIP) + s.id = string(id[:]) + } + s.nodes = make(map[string]*node, maxNodes) + return +} + +func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) { + if len(infoHash) != 20 { + err = fmt.Errorf("infohash has bad length") + return + } + t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) { + s.liftNodes(m) + if m.R != nil && m.R.Token != "" { + s.getNode(addr, m.SenderID()).announceToken = m.R.Token + } + }) + return +} + +func (s *Server) closestGoodNodes(k int, targetID string) []*node { + return s.closestNodes(k, nodeIDFromString(targetID), func(n *node) bool { return n.DefinitelyGood() }) +} + +func (s *Server) closestNodes(k int, target nodeID, filter func(*node) bool) []*node { + sel := newKClosestNodesSelector(k, target) + idNodes := make(map[string]*node, len(s.nodes)) + for _, node := range s.nodes { + if !filter(node) { + continue + } + sel.Push(node.id) + idNodes[node.idString()] = node + } + ids := sel.IDs() + ret := make([]*node, 0, len(ids)) + for _, id := range ids { + ret = append(ret, idNodes[id.ByteString()]) + } + return ret +} + +func (me *Server) badNode(addr dHTAddr) { + me.badNodes.Add([]byte(addr.String())) + delete(me.nodes, addr.String()) +}