Start a UDP server implementation
This commit is contained in:
parent
682c77fcb9
commit
eb9c032f2b
|
@ -1 +1,26 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
|
||||
"github.com/anacrolix/dht/v2/krpc"
|
||||
)
|
||||
|
||||
// Discriminates behaviours based on address family in use.
|
||||
type AddrFamily int
|
||||
|
||||
const (
|
||||
AddrFamilyIpv4 = iota + 1
|
||||
AddrFamilyIpv6
|
||||
)
|
||||
|
||||
// Returns a marshaler for the given node addrs for the specified family.
|
||||
func GetNodeAddrsCompactMarshaler(nas []krpc.NodeAddr, family AddrFamily) encoding.BinaryMarshaler {
|
||||
switch family {
|
||||
case AddrFamilyIpv4:
|
||||
return krpc.CompactIPv4NodeAddrs(nas)
|
||||
case AddrFamilyIpv6:
|
||||
return krpc.CompactIPv6NodeAddrs(nas)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1 +1,200 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/anacrolix/dht/v2/krpc"
|
||||
"github.com/anacrolix/log"
|
||||
"github.com/anacrolix/torrent/tracker/udp"
|
||||
)
|
||||
|
||||
type ConnectionTrackerAddr = string
|
||||
|
||||
type ConnectionTracker interface {
|
||||
Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
|
||||
Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
|
||||
}
|
||||
|
||||
type InfoHash = [20]byte
|
||||
|
||||
// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
|
||||
// limiting return count, etc.
|
||||
type GetPeersOpts struct{}
|
||||
|
||||
type PeerInfo struct {
|
||||
netip.AddrPort
|
||||
}
|
||||
|
||||
type AnnounceTracker interface {
|
||||
TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
|
||||
Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
|
||||
GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
ConnTracker ConnectionTracker
|
||||
SendResponse func(data []byte, addr net.Addr) (int, error)
|
||||
AnnounceTracker AnnounceTracker
|
||||
}
|
||||
|
||||
type RequestSourceAddr = net.Addr
|
||||
|
||||
func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
|
||||
var h udp.RequestHeader
|
||||
var r bytes.Reader
|
||||
r.Reset(body)
|
||||
err := udp.Read(&r, &h)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("reading request header: %w", err)
|
||||
return err
|
||||
}
|
||||
switch h.Action {
|
||||
case udp.ActionConnect:
|
||||
err = me.handleConnect(ctx, source, h.TransactionId)
|
||||
case udp.ActionAnnounce:
|
||||
err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
|
||||
default:
|
||||
err = fmt.Errorf("unimplemented")
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("handling action %v: %w", h.Action, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (me *Server) handleAnnounce(
|
||||
ctx context.Context,
|
||||
addrFamily udp.AddrFamily,
|
||||
source RequestSourceAddr,
|
||||
connId udp.ConnectionId,
|
||||
tid udp.TransactionId,
|
||||
r *bytes.Reader,
|
||||
) error {
|
||||
ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("checking conn id: %w", err)
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid connection id: %v", connId)
|
||||
}
|
||||
var req udp.AnnounceRequest
|
||||
err = udp.Read(r, &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: This should be done asynchronously to responding to the announce.
|
||||
err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
|
||||
for _, p := range peers {
|
||||
var ip net.IP
|
||||
switch addrFamily {
|
||||
default:
|
||||
continue
|
||||
case udp.AddrFamilyIpv4:
|
||||
if !p.Addr().Unmap().Is4() {
|
||||
continue
|
||||
}
|
||||
ipBuf := p.Addr().As4()
|
||||
ip = ipBuf[:]
|
||||
case udp.AddrFamilyIpv6:
|
||||
ipBuf := p.Addr().As16()
|
||||
ip = ipBuf[:]
|
||||
}
|
||||
nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
|
||||
IP: ip[:],
|
||||
Port: int(p.Port()),
|
||||
})
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
err = udp.Write(&buf, udp.ResponseHeader{
|
||||
Action: udp.ActionAnnounce,
|
||||
TransactionId: tid,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = udp.Write(&buf, udp.AnnounceResponseHeader{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("marshalling compact node addrs: %w", err)
|
||||
return err
|
||||
}
|
||||
log.Print(nodeAddrs)
|
||||
buf.Write(b)
|
||||
n, err := me.SendResponse(buf.Bytes(), source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n < buf.Len() {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
|
||||
connId := randomConnectionId()
|
||||
err := me.ConnTracker.Add(ctx, source.String(), connId)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("recording conn id: %w", err)
|
||||
return err
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
udp.Write(&buf, udp.ResponseHeader{
|
||||
Action: udp.ActionConnect,
|
||||
TransactionId: tid,
|
||||
})
|
||||
udp.Write(&buf, udp.ConnectionResponse{connId})
|
||||
n, err := me.SendResponse(buf.Bytes(), source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n < buf.Len() {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func randomConnectionId() udp.ConnectionId {
|
||||
var b [8]byte
|
||||
_, err := rand.Read(b[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(binary.BigEndian.Uint64(b[:]))
|
||||
}
|
||||
|
||||
func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
for {
|
||||
var b [1500]byte
|
||||
n, addr, err := pc.ReadFrom(b[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
err := s.HandleRequest(ctx, family, addr, b[:n])
|
||||
if err != nil {
|
||||
log.Printf("error handling %v byte request from %v: %v", n, addr, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
var trackers = []string{
|
||||
"udp://tracker.opentrackr.org:1337/announce",
|
||||
"udp://tracker.openbittorrent.com:6969/announce",
|
||||
"udp://localhost:42069",
|
||||
}
|
||||
|
||||
func TestAnnounceLocalhost(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue