Add customer headers when dialling WS connection to tracker (#789)

* expose WebtorrentTrackerHttpHeader field
This commit is contained in:
Marco Vidonis 2022-12-07 22:17:33 +00:00 committed by GitHub
parent 682c77fcb9
commit 3909c6c125
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 12 deletions

View File

@ -297,8 +297,9 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
} }
return t.announceRequest(event), nil return t.announceRequest(event), nil
}, },
Proxy: cl.config.HTTPProxy, Proxy: cl.config.HTTPProxy,
DialContext: cl.config.TrackerDialContext, WebsocketTrackerHttpHeader: cl.config.WebsocketTrackerHttpHeader,
DialContext: cl.config.TrackerDialContext,
OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) { OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) {
cl.lock() cl.lock()
defer cl.unlock() defer cl.unlock()

View File

@ -117,6 +117,9 @@ type ClientConfig struct {
// HttpRequestDirector modifies the request before it's sent. // HttpRequestDirector modifies the request before it's sent.
// Useful for adding authentication headers, for example // Useful for adding authentication headers, for example
HttpRequestDirector func(*http.Request) error HttpRequestDirector func(*http.Request) error
// WebsocketTrackerHttpHeader returns a custom header to be used when dialing a websocket connection
// to the tracker. Useful for adding authentication headers
WebsocketTrackerHttpHeader func() http.Header
// Updated occasionally to when there's been some changes to client // Updated occasionally to when there's been some changes to client
// behaviour in case other clients are assuming anything of us. See also // behaviour in case other clients are assuming anything of us. See also
// `bep20`. // `bep20`.

View File

@ -5,6 +5,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"sync" "sync"
"time" "time"
@ -40,6 +41,8 @@ type TrackerClient struct {
closed bool closed bool
stats TrackerClientStats stats TrackerClientStats
pingTicker *time.Ticker pingTicker *time.Ticker
WebsocketTrackerHttpHeader func() http.Header
} }
func (me *TrackerClient) Stats() TrackerClientStats { func (me *TrackerClient) Stats() TrackerClientStats {
@ -86,7 +89,13 @@ func (tc *TrackerClient) doWebsocket() error {
tc.mu.Lock() tc.mu.Lock()
tc.stats.Dials++ tc.stats.Dials++
tc.mu.Unlock() tc.mu.Unlock()
c, _, err := tc.Dialer.Dial(tc.Url, nil)
var header http.Header
if tc.WebsocketTrackerHttpHeader != nil {
header = tc.WebsocketTrackerHttpHeader()
}
c, _, err := tc.Dialer.Dial(tc.Url, header)
if err != nil { if err != nil {
return fmt.Errorf("dialing tracker: %w", err) return fmt.Errorf("dialing tracker: %w", err)
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
netHttp "net/http"
"net/url" "net/url"
"sync" "sync"
@ -12,7 +13,7 @@ import (
"github.com/pion/datachannel" "github.com/pion/datachannel"
"github.com/anacrolix/torrent/tracker" "github.com/anacrolix/torrent/tracker"
"github.com/anacrolix/torrent/tracker/http" httpTracker "github.com/anacrolix/torrent/tracker/http"
"github.com/anacrolix/torrent/webtorrent" "github.com/anacrolix/torrent/webtorrent"
) )
@ -35,14 +36,15 @@ type refCountedWebtorrentTrackerClient struct {
} }
type websocketTrackers struct { type websocketTrackers struct {
PeerId [20]byte PeerId [20]byte
Logger log.Logger Logger log.Logger
GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error) GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext) OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
mu sync.Mutex mu sync.Mutex
clients map[string]*refCountedWebtorrentTrackerClient clients map[string]*refCountedWebtorrentTrackerClient
Proxy httpTracker.ProxyFunc Proxy httpTracker.ProxyFunc
DialContext func(ctx context.Context, network, addr string) (net.Conn, error) DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
WebsocketTrackerHttpHeader func() netHttp.Header
} }
func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) { func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
@ -61,6 +63,7 @@ func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.Tra
Logger: me.Logger.WithText(func(m log.Msg) string { Logger: me.Logger.WithText(func(m log.Msg) string {
return fmt.Sprintf("tracker client for %q: %v", url, m) return fmt.Sprintf("tracker client for %q: %v", url, m)
}), }),
WebsocketTrackerHttpHeader: me.WebsocketTrackerHttpHeader,
}, },
} }
value.TrackerClient.Start(func(err error) { value.TrackerClient.Start(func(err error) {