diff --git a/client.go b/client.go index e68e80b1..4adf28b7 100644 --- a/client.go +++ b/client.go @@ -297,8 +297,9 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) { } return t.announceRequest(event), nil }, - Proxy: cl.config.HTTPProxy, - DialContext: cl.config.TrackerDialContext, + Proxy: cl.config.HTTPProxy, + WebsocketTrackerHttpHeader: cl.config.WebsocketTrackerHttpHeader, + DialContext: cl.config.TrackerDialContext, OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) { cl.lock() defer cl.unlock() diff --git a/config.go b/config.go index 09f9bc1e..e1e6452a 100644 --- a/config.go +++ b/config.go @@ -117,6 +117,9 @@ type ClientConfig struct { // HttpRequestDirector modifies the request before it's sent. // Useful for adding authentication headers, for example HttpRequestDirector func(*http.Request) error + // WebsocketTrackerHttpHeader returns a custom header to be used when dialing a websocket connection + // to the tracker. Useful for adding authentication headers + WebsocketTrackerHttpHeader func() http.Header // Updated occasionally to when there's been some changes to client // behaviour in case other clients are assuming anything of us. See also // `bep20`. diff --git a/webtorrent/tracker-client.go b/webtorrent/tracker-client.go index 64885bf4..60cd8527 100644 --- a/webtorrent/tracker-client.go +++ b/webtorrent/tracker-client.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/json" "fmt" + "net/http" "sync" "time" @@ -40,6 +41,8 @@ type TrackerClient struct { closed bool stats TrackerClientStats pingTicker *time.Ticker + + WebsocketTrackerHttpHeader func() http.Header } func (me *TrackerClient) Stats() TrackerClientStats { @@ -86,7 +89,13 @@ func (tc *TrackerClient) doWebsocket() error { tc.mu.Lock() tc.stats.Dials++ tc.mu.Unlock() - c, _, err := tc.Dialer.Dial(tc.Url, nil) + + var header http.Header + if tc.WebsocketTrackerHttpHeader != nil { + header = tc.WebsocketTrackerHttpHeader() + } + + c, _, err := tc.Dialer.Dial(tc.Url, header) if err != nil { return fmt.Errorf("dialing tracker: %w", err) } diff --git a/wstracker.go b/wstracker.go index 9b1a9201..c379dc31 100644 --- a/wstracker.go +++ b/wstracker.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + netHttp "net/http" "net/url" "sync" @@ -12,7 +13,7 @@ import ( "github.com/pion/datachannel" "github.com/anacrolix/torrent/tracker" - "github.com/anacrolix/torrent/tracker/http" + httpTracker "github.com/anacrolix/torrent/tracker/http" "github.com/anacrolix/torrent/webtorrent" ) @@ -35,14 +36,15 @@ type refCountedWebtorrentTrackerClient struct { } type websocketTrackers struct { - PeerId [20]byte - Logger log.Logger - GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error) - OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext) - mu sync.Mutex - clients map[string]*refCountedWebtorrentTrackerClient - Proxy httpTracker.ProxyFunc - DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + PeerId [20]byte + Logger log.Logger + GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error) + OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext) + mu sync.Mutex + clients map[string]*refCountedWebtorrentTrackerClient + Proxy httpTracker.ProxyFunc + DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + WebsocketTrackerHttpHeader func() netHttp.Header } func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) { @@ -61,6 +63,7 @@ func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.Tra Logger: me.Logger.WithText(func(m log.Msg) string { return fmt.Sprintf("tracker client for %q: %v", url, m) }), + WebsocketTrackerHttpHeader: me.WebsocketTrackerHttpHeader, }, } value.TrackerClient.Start(func(err error) {