Add http and udp tracker server implementations
This commit is contained in:
commit
48b3e66c76
|
@ -3,16 +3,24 @@ package httpTracker
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/anacrolix/dht/v2/krpc"
|
||||
)
|
||||
|
||||
// TODO: Use netip.Addr and Option[[20]byte].
|
||||
type Peer struct {
|
||||
IP net.IP `bencode:"ip"`
|
||||
Port int `bencode:"port"`
|
||||
ID []byte `bencode:"peer id"`
|
||||
}
|
||||
|
||||
func (p Peer) ToNetipAddrPort() (addrPort netip.AddrPort, ok bool) {
|
||||
addr, ok := netip.AddrFromSlice(p.IP)
|
||||
addrPort = netip.AddrPortFrom(addr, uint16(p.Port))
|
||||
return
|
||||
}
|
||||
|
||||
func (p Peer) String() string {
|
||||
loc := net.JoinHostPort(p.IP.String(), fmt.Sprintf("%d", p.Port))
|
||||
if len(p.ID) != 0 {
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
package httpTrackerServer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/anacrolix/dht/v2/krpc"
|
||||
"github.com/anacrolix/generics"
|
||||
"github.com/anacrolix/log"
|
||||
trackerServer "github.com/anacrolix/torrent/tracker/server"
|
||||
|
||||
"github.com/anacrolix/torrent/bencode"
|
||||
"github.com/anacrolix/torrent/tracker"
|
||||
httpTracker "github.com/anacrolix/torrent/tracker/http"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Announce *trackerServer.AnnounceHandler
|
||||
// Called to derive an announcer's IP if non-nil. If not specified, the Request.RemoteAddr is
|
||||
// used. Necessary for instances running behind reverse proxies for example.
|
||||
RequestHost func(r *http.Request) (netip.Addr, error)
|
||||
}
|
||||
|
||||
func unmarshalQueryKeyToArray(w http.ResponseWriter, key string, query url.Values) (ret [20]byte, ok bool) {
|
||||
str := query.Get(key)
|
||||
if len(str) != len(ret) {
|
||||
http.Error(w, fmt.Sprintf("%v has wrong length", key), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
copy(ret[:], str)
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
// Returns false if there was an error and it was served.
|
||||
func (me Handler) requestHostAddr(r *http.Request) (_ netip.Addr, err error) {
|
||||
if me.RequestHost != nil {
|
||||
return me.RequestHost(r)
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return netip.ParseAddr(host)
|
||||
}
|
||||
|
||||
var requestHeadersLogger = log.Default.WithNames("request", "headers")
|
||||
|
||||
func (me Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
vs := r.URL.Query()
|
||||
var event tracker.AnnounceEvent
|
||||
err := event.UnmarshalText([]byte(vs.Get("event")))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
infoHash, ok := unmarshalQueryKeyToArray(w, "info_hash", vs)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
peerId, ok := unmarshalQueryKeyToArray(w, "peer_id", vs)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
requestHeadersLogger.Levelf(log.Debug, "request RemoteAddr=%q, header=%q", r.RemoteAddr, r.Header)
|
||||
addr, err := me.requestHostAddr(r)
|
||||
if err != nil {
|
||||
log.Printf("error getting requester IP: %v", err)
|
||||
http.Error(w, "error determining your IP", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
portU64, _ := strconv.ParseUint(vs.Get("port"), 0, 16)
|
||||
addrPort := netip.AddrPortFrom(addr, uint16(portU64))
|
||||
left, err := strconv.ParseInt(vs.Get("left"), 0, 64)
|
||||
if err != nil {
|
||||
left = -1
|
||||
}
|
||||
res := me.Announce.Serve(
|
||||
r.Context(),
|
||||
tracker.AnnounceRequest{
|
||||
InfoHash: infoHash,
|
||||
PeerId: peerId,
|
||||
Event: event,
|
||||
Port: addrPort.Port(),
|
||||
NumWant: -1,
|
||||
Left: left,
|
||||
},
|
||||
addrPort,
|
||||
trackerServer.GetPeersOpts{
|
||||
MaxCount: generics.Some[uint](200),
|
||||
},
|
||||
)
|
||||
err = res.Err
|
||||
if err != nil {
|
||||
log.Printf("error serving announce: %v", err)
|
||||
http.Error(w, "error handling announce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var resp httpTracker.HttpResponse
|
||||
resp.Incomplete = res.Leechers.Value
|
||||
resp.Complete = res.Seeders.Value
|
||||
resp.Interval = res.Interval.UnwrapOr(5 * 60)
|
||||
resp.Peers.Compact = true
|
||||
for _, peer := range res.Peers {
|
||||
if peer.Addr().Is4() {
|
||||
resp.Peers.List = append(resp.Peers.List, tracker.Peer{
|
||||
IP: peer.Addr().AsSlice(),
|
||||
Port: int(peer.Port()),
|
||||
})
|
||||
} else if peer.Addr().Is6() {
|
||||
resp.Peers6 = append(resp.Peers6, krpc.NodeAddr{
|
||||
IP: peer.Addr().AsSlice(),
|
||||
Port: int(peer.Port()),
|
||||
})
|
||||
}
|
||||
}
|
||||
err = bencode.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
log.Printf("error encoding and writing response body: %v", err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,324 @@
|
|||
package trackerServer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/anacrolix/generics"
|
||||
"github.com/anacrolix/log"
|
||||
"github.com/anacrolix/torrent/tracker"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/anacrolix/torrent/tracker/udp"
|
||||
)
|
||||
|
||||
// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
|
||||
// limiting return count, etc.
|
||||
type GetPeersOpts struct {
|
||||
// Negative numbers are not allowed.
|
||||
MaxCount generics.Option[uint]
|
||||
}
|
||||
|
||||
type InfoHash = [20]byte
|
||||
|
||||
type PeerInfo struct {
|
||||
AnnounceAddr
|
||||
}
|
||||
|
||||
type AnnounceAddr = netip.AddrPort
|
||||
|
||||
type AnnounceTracker interface {
|
||||
TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr AnnounceAddr) error
|
||||
Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
|
||||
GetPeers(
|
||||
ctx context.Context,
|
||||
infoHash InfoHash,
|
||||
opts GetPeersOpts,
|
||||
remote AnnounceAddr,
|
||||
) ServerAnnounceResult
|
||||
}
|
||||
|
||||
type ServerAnnounceResult struct {
|
||||
Err error
|
||||
Peers []PeerInfo
|
||||
Interval generics.Option[int32]
|
||||
Leechers generics.Option[int32]
|
||||
Seeders generics.Option[int32]
|
||||
}
|
||||
|
||||
type AnnounceHandler struct {
|
||||
AnnounceTracker AnnounceTracker
|
||||
|
||||
UpstreamTrackers []Client
|
||||
UpstreamTrackerUrls []string
|
||||
UpstreamAnnouncePeerId [20]byte
|
||||
UpstreamAnnounceGate UpstreamAnnounceGater
|
||||
|
||||
mu sync.Mutex
|
||||
// Operations are only removed when all the upstream peers have been tracked.
|
||||
ongoingUpstreamAugmentations map[InfoHash]augmentationOperation
|
||||
}
|
||||
|
||||
type peerSet = map[PeerInfo]struct{}
|
||||
|
||||
type augmentationOperation struct {
|
||||
// Closed when no more announce responses are pending. finalPeers will contain all the peers
|
||||
// seen.
|
||||
doneAnnouncing chan struct{}
|
||||
// This receives the latest peerSet until doneAnnouncing is closed.
|
||||
curPeers chan peerSet
|
||||
// This contains the final peerSet after doneAnnouncing is closed.
|
||||
finalPeers peerSet
|
||||
}
|
||||
|
||||
func (me augmentationOperation) getCurPeers() (ret peerSet) {
|
||||
ret, _ = me.getCurPeersAndDone()
|
||||
return
|
||||
}
|
||||
|
||||
func (me augmentationOperation) getCurPeersAndDone() (ret peerSet, done bool) {
|
||||
select {
|
||||
case ret = <-me.curPeers:
|
||||
case <-me.doneAnnouncing:
|
||||
ret = copyPeerSet(me.finalPeers)
|
||||
done = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Adds peers from new that aren't in orig. Modifies both arguments.
|
||||
func addMissing(orig []PeerInfo, new peerSet) {
|
||||
for _, peer := range orig {
|
||||
delete(new, peer)
|
||||
}
|
||||
for peer := range new {
|
||||
orig = append(orig, peer)
|
||||
}
|
||||
}
|
||||
|
||||
var tracer = otel.Tracer("torrent.tracker.udp")
|
||||
|
||||
func (me *AnnounceHandler) Serve(
|
||||
ctx context.Context, req AnnounceRequest, addr AnnounceAddr, opts GetPeersOpts,
|
||||
) (ret ServerAnnounceResult) {
|
||||
ctx, span := tracer.Start(
|
||||
ctx,
|
||||
"AnnounceHandler.Serve",
|
||||
trace.WithAttributes(
|
||||
attribute.Int64("announce.request.num_want", int64(req.NumWant)),
|
||||
attribute.Int("announce.request.port", int(req.Port)),
|
||||
attribute.String("announce.request.info_hash", hex.EncodeToString(req.InfoHash[:])),
|
||||
attribute.String("announce.request.event", req.Event.String()),
|
||||
attribute.Int64("announce.get_peers.opts.max_count_value", int64(opts.MaxCount.Value)),
|
||||
attribute.Bool("announce.get_peers.opts.max_count_ok", opts.MaxCount.Ok),
|
||||
attribute.String("announce.source.addr.ip", addr.Addr().String()),
|
||||
attribute.Int("announce.source.addr.port", int(addr.Port())),
|
||||
),
|
||||
)
|
||||
defer span.End()
|
||||
defer func() {
|
||||
span.SetAttributes(attribute.Int("announce.get_peers.len", len(ret.Peers)))
|
||||
if ret.Err != nil {
|
||||
span.SetStatus(codes.Error, ret.Err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Port != 0 {
|
||||
addr = netip.AddrPortFrom(addr.Addr(), req.Port)
|
||||
}
|
||||
ret.Err = me.AnnounceTracker.TrackAnnounce(ctx, req, addr)
|
||||
if ret.Err != nil {
|
||||
ret.Err = fmt.Errorf("tracking announce: %w", ret.Err)
|
||||
return
|
||||
}
|
||||
infoHash := req.InfoHash
|
||||
var op generics.Option[augmentationOperation]
|
||||
// Grab a handle to any augmentations that are already running.
|
||||
me.mu.Lock()
|
||||
op.Value, op.Ok = me.ongoingUpstreamAugmentations[infoHash]
|
||||
me.mu.Unlock()
|
||||
// Apply num_want limit to max count. I really can't tell if this is the right place to do it,
|
||||
// but it seems the most flexible.
|
||||
if req.NumWant != -1 {
|
||||
newCount := uint(req.NumWant)
|
||||
if opts.MaxCount.Ok {
|
||||
if newCount < opts.MaxCount.Value {
|
||||
opts.MaxCount.Value = newCount
|
||||
}
|
||||
} else {
|
||||
opts.MaxCount = generics.Some(newCount)
|
||||
}
|
||||
}
|
||||
ret = me.AnnounceTracker.GetPeers(ctx, infoHash, opts, addr)
|
||||
if ret.Err != nil {
|
||||
return
|
||||
}
|
||||
// Take whatever peers it has ready. If it's finished, it doesn't matter if we do this inside
|
||||
// the mutex or not.
|
||||
if op.Ok {
|
||||
curPeers, done := op.Value.getCurPeersAndDone()
|
||||
addMissing(ret.Peers, curPeers)
|
||||
if done {
|
||||
// It doesn't get any better with this operation. Forget it.
|
||||
op.Ok = false
|
||||
}
|
||||
}
|
||||
me.mu.Lock()
|
||||
// If we didn't have an operation, and don't have enough peers, start one. Allowing 1 is
|
||||
// assuming the announcing peer might be that one. Really we should record a value to prevent
|
||||
// duplicate announces. Also don't announce upstream if we got no peers because the caller asked
|
||||
// for none.
|
||||
if !op.Ok && len(ret.Peers) <= 1 && opts.MaxCount.UnwrapOr(1) > 0 {
|
||||
op.Value, op.Ok = me.ongoingUpstreamAugmentations[infoHash]
|
||||
if !op.Ok {
|
||||
op.Set(me.augmentPeersFromUpstream(req.InfoHash))
|
||||
generics.MakeMapIfNilAndSet(&me.ongoingUpstreamAugmentations, infoHash, op.Value)
|
||||
}
|
||||
}
|
||||
me.mu.Unlock()
|
||||
// Wait a while for the current operation.
|
||||
if op.Ok {
|
||||
// Force the augmentation to return with whatever it has if it hasn't completed in a
|
||||
// reasonable time.
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-op.Value.doneAnnouncing:
|
||||
}
|
||||
cancel()
|
||||
addMissing(ret.Peers, op.Value.getCurPeers())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (me *AnnounceHandler) augmentPeersFromUpstream(infoHash [20]byte) augmentationOperation {
|
||||
const announceTimeout = time.Minute
|
||||
announceCtx, cancel := context.WithTimeout(context.Background(), announceTimeout)
|
||||
subReq := AnnounceRequest{
|
||||
InfoHash: infoHash,
|
||||
PeerId: me.UpstreamAnnouncePeerId,
|
||||
Event: tracker.None,
|
||||
Key: 0,
|
||||
NumWant: -1,
|
||||
Port: 0,
|
||||
}
|
||||
peersChan := make(chan []Peer)
|
||||
var pendingUpstreams sync.WaitGroup
|
||||
for i := range me.UpstreamTrackers {
|
||||
client := me.UpstreamTrackers[i]
|
||||
url := me.UpstreamTrackerUrls[i]
|
||||
pendingUpstreams.Add(1)
|
||||
go func() {
|
||||
started, err := me.UpstreamAnnounceGate.Start(announceCtx, url, infoHash, announceTimeout)
|
||||
if err != nil {
|
||||
log.Printf("error reserving announce for %x to %v: %v", infoHash, url, err)
|
||||
}
|
||||
if err != nil || !started {
|
||||
peersChan <- nil
|
||||
return
|
||||
}
|
||||
log.Printf("announcing %x upstream to %v", infoHash, url)
|
||||
resp, err := client.Announce(announceCtx, subReq, tracker.AnnounceOpt{
|
||||
UserAgent: "aragorn",
|
||||
})
|
||||
interval := resp.Interval
|
||||
go func() {
|
||||
if interval < 5*60 {
|
||||
// This is as much to reduce load on upstream trackers in the event of errors,
|
||||
// as it is to reduce load on our peer store.
|
||||
interval = 5 * 60
|
||||
}
|
||||
err := me.UpstreamAnnounceGate.Completed(context.Background(), url, infoHash, interval)
|
||||
if err != nil {
|
||||
log.Printf("error recording completed announce for %x to %v: %v", infoHash, url, err)
|
||||
}
|
||||
}()
|
||||
peersChan <- resp.Peers
|
||||
if err != nil {
|
||||
log.Levelf(log.Warning, "error announcing to upstream %q: %v", url, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
peersToTrack := make(map[string]Peer)
|
||||
go func() {
|
||||
pendingUpstreams.Wait()
|
||||
cancel()
|
||||
close(peersChan)
|
||||
log.Levelf(log.Debug, "adding %v distinct peers from upstream trackers", len(peersToTrack))
|
||||
for _, peer := range peersToTrack {
|
||||
addrPort, ok := peer.ToNetipAddrPort()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
trackReq := AnnounceRequest{
|
||||
InfoHash: infoHash,
|
||||
Event: tracker.Started,
|
||||
Port: uint16(peer.Port),
|
||||
// Let's assume upstream peers are leechers without knowing better.
|
||||
Left: -1,
|
||||
}
|
||||
copy(trackReq.PeerId[:], peer.ID)
|
||||
// TODO: How do we know if these peers are leechers or seeders?
|
||||
err := me.AnnounceTracker.TrackAnnounce(context.TODO(), trackReq, addrPort)
|
||||
if err != nil {
|
||||
log.Levelf(log.Error, "error tracking upstream peer: %v", err)
|
||||
}
|
||||
}
|
||||
me.mu.Lock()
|
||||
delete(me.ongoingUpstreamAugmentations, infoHash)
|
||||
me.mu.Unlock()
|
||||
}()
|
||||
curPeersChan := make(chan map[PeerInfo]struct{})
|
||||
doneChan := make(chan struct{})
|
||||
retPeers := make(map[PeerInfo]struct{})
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
for {
|
||||
select {
|
||||
case peers, ok := <-peersChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
voldemort(peers, peersToTrack, retPeers)
|
||||
pendingUpstreams.Done()
|
||||
case curPeersChan <- copyPeerSet(retPeers):
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Take return references.
|
||||
return augmentationOperation{
|
||||
curPeers: curPeersChan,
|
||||
finalPeers: retPeers,
|
||||
doneAnnouncing: doneChan,
|
||||
}
|
||||
}
|
||||
|
||||
func copyPeerSet(orig peerSet) (ret peerSet) {
|
||||
ret = make(peerSet, len(orig))
|
||||
for k, v := range orig {
|
||||
ret[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Adds peers to trailing containers.
|
||||
func voldemort(peers []Peer, toTrack map[string]Peer, sets ...map[PeerInfo]struct{}) {
|
||||
for _, protoPeer := range peers {
|
||||
toTrack[protoPeer.String()] = protoPeer
|
||||
addr, ok := netip.AddrFromSlice(protoPeer.IP)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
handlerPeer := PeerInfo{netip.AddrPortFrom(addr, uint16(protoPeer.Port))}
|
||||
for _, set := range sets {
|
||||
set[handlerPeer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package trackerServer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UpstreamAnnounceGater interface {
|
||||
Start(ctx context.Context, tracker string, infoHash InfoHash,
|
||||
// How long the announce block remains before discarding it.
|
||||
timeout time.Duration,
|
||||
) (bool, error)
|
||||
Completed(
|
||||
ctx context.Context, tracker string, infoHash InfoHash,
|
||||
// Num of seconds reported by tracker, or some suitable value the caller has chosen.
|
||||
interval int32,
|
||||
) error
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package trackerServer
|
||||
|
||||
import "github.com/anacrolix/torrent/tracker"
|
||||
|
||||
type (
|
||||
AnnounceRequest = tracker.AnnounceRequest
|
||||
Client = tracker.Client
|
||||
Peer = tracker.Peer
|
||||
)
|
|
@ -21,7 +21,7 @@ type torrent struct {
|
|||
|
||||
type server struct {
|
||||
pc net.PacketConn
|
||||
conns map[int64]struct{}
|
||||
conns map[udp.ConnectionId]struct{}
|
||||
t map[[20]byte]torrent
|
||||
}
|
||||
|
||||
|
@ -46,10 +46,10 @@ func (s *server) respond(addr net.Addr, rh udp.ResponseHeader, parts ...interfac
|
|||
return
|
||||
}
|
||||
|
||||
func (s *server) newConn() (ret int64) {
|
||||
ret = rand.Int63()
|
||||
func (s *server) newConn() (ret udp.ConnectionId) {
|
||||
ret = rand.Uint64()
|
||||
if s.conns == nil {
|
||||
s.conns = make(map[int64]struct{})
|
||||
s.conns = make(map[udp.ConnectionId]struct{})
|
||||
}
|
||||
s.conns[ret] = struct{}{}
|
||||
return
|
||||
|
|
|
@ -1 +1,26 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
|
||||
"github.com/anacrolix/dht/v2/krpc"
|
||||
)
|
||||
|
||||
// Discriminates behaviours based on address family in use.
|
||||
type AddrFamily int
|
||||
|
||||
const (
|
||||
AddrFamilyIpv4 = iota + 1
|
||||
AddrFamilyIpv6
|
||||
)
|
||||
|
||||
// Returns a marshaler for the given node addrs for the specified family.
|
||||
func GetNodeAddrsCompactMarshaler(nas []krpc.NodeAddr, family AddrFamily) encoding.BinaryMarshaler {
|
||||
switch family {
|
||||
case AddrFamilyIpv4:
|
||||
return krpc.CompactIPv4NodeAddrs(nas)
|
||||
case AddrFamilyIpv6:
|
||||
return krpc.CompactIPv6NodeAddrs(nas)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -38,7 +38,12 @@ func (me *AnnounceEvent) UnmarshalText(text []byte) error {
|
|||
var announceEventStrings = []string{"", "completed", "started", "stopped"}
|
||||
|
||||
func (e AnnounceEvent) String() string {
|
||||
// See BEP 3, "event", and https://github.com/anacrolix/torrent/issues/416#issuecomment-751427001.
|
||||
// See BEP 3, "event", and
|
||||
// https://github.com/anacrolix/torrent/issues/416#issuecomment-751427001. Return a safe default
|
||||
// in case event values are not sanitized.
|
||||
if e < 0 || int(e) >= len(announceEventStrings) {
|
||||
return ""
|
||||
}
|
||||
return announceEventStrings[e]
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ const (
|
|||
|
||||
type TransactionId = int32
|
||||
|
||||
type ConnectionId = int64
|
||||
type ConnectionId = uint64
|
||||
|
||||
type ConnectionRequest struct {
|
||||
ConnectionId ConnectionId
|
||||
|
|
|
@ -1 +1,241 @@
|
|||
package server
|
||||
package udpTrackerServer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/anacrolix/dht/v2/krpc"
|
||||
"github.com/anacrolix/generics"
|
||||
"github.com/anacrolix/log"
|
||||
trackerServer "github.com/anacrolix/torrent/tracker/server"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/anacrolix/torrent/tracker/udp"
|
||||
)
|
||||
|
||||
type ConnectionTrackerAddr = string
|
||||
|
||||
type ConnectionTracker interface {
|
||||
Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
|
||||
Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
|
||||
}
|
||||
|
||||
type InfoHash = [20]byte
|
||||
|
||||
type AnnounceTracker = trackerServer.AnnounceTracker
|
||||
|
||||
type Server struct {
|
||||
ConnTracker ConnectionTracker
|
||||
SendResponse func(ctx context.Context, data []byte, addr net.Addr) (int, error)
|
||||
Announce *trackerServer.AnnounceHandler
|
||||
}
|
||||
|
||||
type RequestSourceAddr = net.Addr
|
||||
|
||||
var tracer = otel.Tracer("torrent.tracker.udp")
|
||||
|
||||
func (me *Server) HandleRequest(
|
||||
ctx context.Context,
|
||||
family udp.AddrFamily,
|
||||
source RequestSourceAddr,
|
||||
body []byte,
|
||||
) (err error) {
|
||||
ctx, span := tracer.Start(ctx, "Server.HandleRequest",
|
||||
trace.WithAttributes(attribute.Int("payload.len", len(body))))
|
||||
defer span.End()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
}
|
||||
}()
|
||||
var h udp.RequestHeader
|
||||
var r bytes.Reader
|
||||
r.Reset(body)
|
||||
err = udp.Read(&r, &h)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("reading request header: %w", err)
|
||||
return err
|
||||
}
|
||||
switch h.Action {
|
||||
case udp.ActionConnect:
|
||||
err = me.handleConnect(ctx, source, h.TransactionId)
|
||||
case udp.ActionAnnounce:
|
||||
err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
|
||||
default:
|
||||
err = fmt.Errorf("unimplemented")
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("handling action %v: %w", h.Action, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (me *Server) handleAnnounce(
|
||||
ctx context.Context,
|
||||
addrFamily udp.AddrFamily,
|
||||
source RequestSourceAddr,
|
||||
connId udp.ConnectionId,
|
||||
tid udp.TransactionId,
|
||||
r *bytes.Reader,
|
||||
) error {
|
||||
// Should we set a timeout of 10s or something for the entire response, so that we give up if a
|
||||
// retry is imminent?
|
||||
|
||||
ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("checking conn id: %w", err)
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("incorrect connection id: %x", connId)
|
||||
}
|
||||
var req udp.AnnounceRequest
|
||||
err = udp.Read(r, &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: This should be done asynchronously to responding to the announce.
|
||||
announceAddr, err := netip.ParseAddrPort(source.String())
|
||||
if err != nil {
|
||||
err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
|
||||
return err
|
||||
}
|
||||
opts := trackerServer.GetPeersOpts{MaxCount: generics.Some[uint](50)}
|
||||
if addrFamily == udp.AddrFamilyIpv4 {
|
||||
opts.MaxCount = generics.Some[uint](150)
|
||||
}
|
||||
res := me.Announce.Serve(ctx, req, announceAddr, opts)
|
||||
if res.Err != nil {
|
||||
return res.Err
|
||||
}
|
||||
nodeAddrs := make([]krpc.NodeAddr, 0, len(res.Peers))
|
||||
for _, p := range res.Peers {
|
||||
var ip net.IP
|
||||
switch addrFamily {
|
||||
default:
|
||||
continue
|
||||
case udp.AddrFamilyIpv4:
|
||||
if !p.Addr().Unmap().Is4() {
|
||||
continue
|
||||
}
|
||||
ipBuf := p.Addr().As4()
|
||||
ip = ipBuf[:]
|
||||
case udp.AddrFamilyIpv6:
|
||||
ipBuf := p.Addr().As16()
|
||||
ip = ipBuf[:]
|
||||
}
|
||||
nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
|
||||
IP: ip[:],
|
||||
Port: int(p.Port()),
|
||||
})
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
err = udp.Write(&buf, udp.ResponseHeader{
|
||||
Action: udp.ActionAnnounce,
|
||||
TransactionId: tid,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = udp.Write(&buf, udp.AnnounceResponseHeader{
|
||||
Interval: res.Interval.UnwrapOr(5 * 60),
|
||||
Seeders: res.Seeders.Value,
|
||||
Leechers: res.Leechers.Value,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("marshalling compact node addrs: %w", err)
|
||||
return err
|
||||
}
|
||||
buf.Write(b)
|
||||
n, err := me.SendResponse(ctx, buf.Bytes(), source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n < buf.Len() {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
|
||||
connId := randomConnectionId()
|
||||
err := me.ConnTracker.Add(ctx, source.String(), connId)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("recording conn id: %w", err)
|
||||
return err
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
udp.Write(&buf, udp.ResponseHeader{
|
||||
Action: udp.ActionConnect,
|
||||
TransactionId: tid,
|
||||
})
|
||||
udp.Write(&buf, udp.ConnectionResponse{connId})
|
||||
n, err := me.SendResponse(ctx, buf.Bytes(), source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n < buf.Len() {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func randomConnectionId() udp.ConnectionId {
|
||||
var b [8]byte
|
||||
_, err := rand.Read(b[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return binary.BigEndian.Uint64(b[:])
|
||||
}
|
||||
|
||||
func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
var b [1500]byte
|
||||
// Limit concurrent handled requests.
|
||||
sem := make(chan struct{}, 1000)
|
||||
for {
|
||||
n, addr, err := pc.ReadFrom(b[:])
|
||||
ctx, span := tracer.Start(ctx, "handle udp packet")
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
span.End()
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
span.End()
|
||||
return ctx.Err()
|
||||
default:
|
||||
span.SetStatus(codes.Error, "concurrency limit reached")
|
||||
span.End()
|
||||
log.Levelf(log.Debug, "dropping request from %v: concurrency limit reached", addr)
|
||||
continue
|
||||
case sem <- struct{}{}:
|
||||
}
|
||||
b := append([]byte(nil), b[:n]...)
|
||||
go func() {
|
||||
defer span.End()
|
||||
defer func() { <-sem }()
|
||||
err := s.HandleRequest(ctx, family, addr, b)
|
||||
if err != nil {
|
||||
log.Printf("error handling %v byte request from %v: %v", n, addr, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
var trackers = []string{
|
||||
"udp://tracker.opentrackr.org:1337/announce",
|
||||
"udp://tracker.openbittorrent.com:6969/announce",
|
||||
"udp://localhost:42069",
|
||||
}
|
||||
|
||||
func TestAnnounceLocalhost(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue