Start a UDP server implementation
This commit is contained in:
parent
682c77fcb9
commit
eb9c032f2b
|
@ -1 +1,26 @@
|
||||||
package udp
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding"
|
||||||
|
|
||||||
|
"github.com/anacrolix/dht/v2/krpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Discriminates behaviours based on address family in use.
|
||||||
|
type AddrFamily int
|
||||||
|
|
||||||
|
const (
|
||||||
|
AddrFamilyIpv4 = iota + 1
|
||||||
|
AddrFamilyIpv6
|
||||||
|
)
|
||||||
|
|
||||||
|
// Returns a marshaler for the given node addrs for the specified family.
|
||||||
|
func GetNodeAddrsCompactMarshaler(nas []krpc.NodeAddr, family AddrFamily) encoding.BinaryMarshaler {
|
||||||
|
switch family {
|
||||||
|
case AddrFamilyIpv4:
|
||||||
|
return krpc.CompactIPv4NodeAddrs(nas)
|
||||||
|
case AddrFamilyIpv6:
|
||||||
|
return krpc.CompactIPv6NodeAddrs(nas)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1 +1,200 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/anacrolix/dht/v2/krpc"
|
||||||
|
"github.com/anacrolix/log"
|
||||||
|
"github.com/anacrolix/torrent/tracker/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectionTrackerAddr = string
|
||||||
|
|
||||||
|
type ConnectionTracker interface {
|
||||||
|
Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
|
||||||
|
Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type InfoHash = [20]byte
|
||||||
|
|
||||||
|
// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
|
||||||
|
// limiting return count, etc.
|
||||||
|
type GetPeersOpts struct{}
|
||||||
|
|
||||||
|
type PeerInfo struct {
|
||||||
|
netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
type AnnounceTracker interface {
|
||||||
|
TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
|
||||||
|
Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
|
||||||
|
GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
ConnTracker ConnectionTracker
|
||||||
|
SendResponse func(data []byte, addr net.Addr) (int, error)
|
||||||
|
AnnounceTracker AnnounceTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestSourceAddr = net.Addr
|
||||||
|
|
||||||
|
func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
|
||||||
|
var h udp.RequestHeader
|
||||||
|
var r bytes.Reader
|
||||||
|
r.Reset(body)
|
||||||
|
err := udp.Read(&r, &h)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("reading request header: %w", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch h.Action {
|
||||||
|
case udp.ActionConnect:
|
||||||
|
err = me.handleConnect(ctx, source, h.TransactionId)
|
||||||
|
case udp.ActionAnnounce:
|
||||||
|
err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("handling action %v: %w", h.Action, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (me *Server) handleAnnounce(
|
||||||
|
ctx context.Context,
|
||||||
|
addrFamily udp.AddrFamily,
|
||||||
|
source RequestSourceAddr,
|
||||||
|
connId udp.ConnectionId,
|
||||||
|
tid udp.TransactionId,
|
||||||
|
r *bytes.Reader,
|
||||||
|
) error {
|
||||||
|
ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("checking conn id: %w", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid connection id: %v", connId)
|
||||||
|
}
|
||||||
|
var req udp.AnnounceRequest
|
||||||
|
err = udp.Read(r, &req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// TODO: This should be done asynchronously to responding to the announce.
|
||||||
|
err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
|
||||||
|
for _, p := range peers {
|
||||||
|
var ip net.IP
|
||||||
|
switch addrFamily {
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
case udp.AddrFamilyIpv4:
|
||||||
|
if !p.Addr().Unmap().Is4() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ipBuf := p.Addr().As4()
|
||||||
|
ip = ipBuf[:]
|
||||||
|
case udp.AddrFamilyIpv6:
|
||||||
|
ipBuf := p.Addr().As16()
|
||||||
|
ip = ipBuf[:]
|
||||||
|
}
|
||||||
|
nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
|
||||||
|
IP: ip[:],
|
||||||
|
Port: int(p.Port()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err = udp.Write(&buf, udp.ResponseHeader{
|
||||||
|
Action: udp.ActionAnnounce,
|
||||||
|
TransactionId: tid,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = udp.Write(&buf, udp.AnnounceResponseHeader{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("marshalling compact node addrs: %w", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Print(nodeAddrs)
|
||||||
|
buf.Write(b)
|
||||||
|
n, err := me.SendResponse(buf.Bytes(), source)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n < buf.Len() {
|
||||||
|
err = io.ErrShortWrite
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
|
||||||
|
connId := randomConnectionId()
|
||||||
|
err := me.ConnTracker.Add(ctx, source.String(), connId)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("recording conn id: %w", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
udp.Write(&buf, udp.ResponseHeader{
|
||||||
|
Action: udp.ActionConnect,
|
||||||
|
TransactionId: tid,
|
||||||
|
})
|
||||||
|
udp.Write(&buf, udp.ConnectionResponse{connId})
|
||||||
|
n, err := me.SendResponse(buf.Bytes(), source)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n < buf.Len() {
|
||||||
|
err = io.ErrShortWrite
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomConnectionId() udp.ConnectionId {
|
||||||
|
var b [8]byte
|
||||||
|
_, err := rand.Read(b[:])
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return int64(binary.BigEndian.Uint64(b[:]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
for {
|
||||||
|
var b [1500]byte
|
||||||
|
n, addr, err := pc.ReadFrom(b[:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err := s.HandleRequest(ctx, family, addr, b[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error handling %v byte request from %v: %v", n, addr, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
var trackers = []string{
|
var trackers = []string{
|
||||||
"udp://tracker.opentrackr.org:1337/announce",
|
"udp://tracker.opentrackr.org:1337/announce",
|
||||||
"udp://tracker.openbittorrent.com:6969/announce",
|
"udp://tracker.openbittorrent.com:6969/announce",
|
||||||
|
"udp://localhost:42069",
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAnnounceLocalhost(t *testing.T) {
|
func TestAnnounceLocalhost(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue