Start a UDP server implementation

This commit is contained in:
Matt Joiner 2022-12-05 12:52:19 +11:00
parent 682c77fcb9
commit eb9c032f2b
No known key found for this signature in database
GPG Key ID: 6B990B8185E7F782
3 changed files with 225 additions and 0 deletions

View File

@ -1 +1,26 @@
package udp
import (
"encoding"
"github.com/anacrolix/dht/v2/krpc"
)
// Discriminates behaviours based on address family in use.
type AddrFamily int
const (
AddrFamilyIpv4 = iota + 1
AddrFamilyIpv6
)
// Returns a marshaler for the given node addrs for the specified family.
func GetNodeAddrsCompactMarshaler(nas []krpc.NodeAddr, family AddrFamily) encoding.BinaryMarshaler {
switch family {
case AddrFamilyIpv4:
return krpc.CompactIPv4NodeAddrs(nas)
case AddrFamilyIpv6:
return krpc.CompactIPv6NodeAddrs(nas)
}
return nil
}

View File

@ -1 +1,200 @@
package server
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"net"
"net/netip"
"github.com/anacrolix/dht/v2/krpc"
"github.com/anacrolix/log"
"github.com/anacrolix/torrent/tracker/udp"
)
type ConnectionTrackerAddr = string
type ConnectionTracker interface {
Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
}
type InfoHash = [20]byte
// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
// limiting return count, etc.
type GetPeersOpts struct{}
type PeerInfo struct {
netip.AddrPort
}
type AnnounceTracker interface {
TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
}
type Server struct {
ConnTracker ConnectionTracker
SendResponse func(data []byte, addr net.Addr) (int, error)
AnnounceTracker AnnounceTracker
}
type RequestSourceAddr = net.Addr
func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
var h udp.RequestHeader
var r bytes.Reader
r.Reset(body)
err := udp.Read(&r, &h)
if err != nil {
err = fmt.Errorf("reading request header: %w", err)
return err
}
switch h.Action {
case udp.ActionConnect:
err = me.handleConnect(ctx, source, h.TransactionId)
case udp.ActionAnnounce:
err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
default:
err = fmt.Errorf("unimplemented")
}
if err != nil {
err = fmt.Errorf("handling action %v: %w", h.Action, err)
}
return err
}
func (me *Server) handleAnnounce(
ctx context.Context,
addrFamily udp.AddrFamily,
source RequestSourceAddr,
connId udp.ConnectionId,
tid udp.TransactionId,
r *bytes.Reader,
) error {
ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
if err != nil {
err = fmt.Errorf("checking conn id: %w", err)
return err
}
if !ok {
return fmt.Errorf("invalid connection id: %v", connId)
}
var req udp.AnnounceRequest
err = udp.Read(r, &req)
if err != nil {
return err
}
// TODO: This should be done asynchronously to responding to the announce.
err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
if err != nil {
return err
}
peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
if err != nil {
return err
}
nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
for _, p := range peers {
var ip net.IP
switch addrFamily {
default:
continue
case udp.AddrFamilyIpv4:
if !p.Addr().Unmap().Is4() {
continue
}
ipBuf := p.Addr().As4()
ip = ipBuf[:]
case udp.AddrFamilyIpv6:
ipBuf := p.Addr().As16()
ip = ipBuf[:]
}
nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
IP: ip[:],
Port: int(p.Port()),
})
}
var buf bytes.Buffer
err = udp.Write(&buf, udp.ResponseHeader{
Action: udp.ActionAnnounce,
TransactionId: tid,
})
if err != nil {
return err
}
err = udp.Write(&buf, udp.AnnounceResponseHeader{})
if err != nil {
return err
}
b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
if err != nil {
err = fmt.Errorf("marshalling compact node addrs: %w", err)
return err
}
log.Print(nodeAddrs)
buf.Write(b)
n, err := me.SendResponse(buf.Bytes(), source)
if err != nil {
return err
}
if n < buf.Len() {
err = io.ErrShortWrite
}
return err
}
func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
connId := randomConnectionId()
err := me.ConnTracker.Add(ctx, source.String(), connId)
if err != nil {
err = fmt.Errorf("recording conn id: %w", err)
return err
}
var buf bytes.Buffer
udp.Write(&buf, udp.ResponseHeader{
Action: udp.ActionConnect,
TransactionId: tid,
})
udp.Write(&buf, udp.ConnectionResponse{connId})
n, err := me.SendResponse(buf.Bytes(), source)
if err != nil {
return err
}
if n < buf.Len() {
err = io.ErrShortWrite
}
return err
}
func randomConnectionId() udp.ConnectionId {
var b [8]byte
_, err := rand.Read(b[:])
if err != nil {
panic(err)
}
return int64(binary.BigEndian.Uint64(b[:]))
}
func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for {
var b [1500]byte
n, addr, err := pc.ReadFrom(b[:])
if err != nil {
return err
}
go func() {
err := s.HandleRequest(ctx, family, addr, b[:n])
if err != nil {
log.Printf("error handling %v byte request from %v: %v", n, addr, err)
}
}()
}
}

View File

@ -23,6 +23,7 @@ import (
var trackers = []string{
"udp://tracker.opentrackr.org:1337/announce",
"udp://tracker.openbittorrent.com:6969/announce",
"udp://localhost:42069",
}
func TestAnnounceLocalhost(t *testing.T) {