Add customer headers when dialling WS connection to tracker (#789)
* expose WebtorrentTrackerHttpHeader field
This commit is contained in:
parent
682c77fcb9
commit
3909c6c125
|
@ -298,6 +298,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
|
|||
return t.announceRequest(event), nil
|
||||
},
|
||||
Proxy: cl.config.HTTPProxy,
|
||||
WebsocketTrackerHttpHeader: cl.config.WebsocketTrackerHttpHeader,
|
||||
DialContext: cl.config.TrackerDialContext,
|
||||
OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) {
|
||||
cl.lock()
|
||||
|
|
|
@ -117,6 +117,9 @@ type ClientConfig struct {
|
|||
// HttpRequestDirector modifies the request before it's sent.
|
||||
// Useful for adding authentication headers, for example
|
||||
HttpRequestDirector func(*http.Request) error
|
||||
// WebsocketTrackerHttpHeader returns a custom header to be used when dialing a websocket connection
|
||||
// to the tracker. Useful for adding authentication headers
|
||||
WebsocketTrackerHttpHeader func() http.Header
|
||||
// Updated occasionally to when there's been some changes to client
|
||||
// behaviour in case other clients are assuming anything of us. See also
|
||||
// `bep20`.
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -40,6 +41,8 @@ type TrackerClient struct {
|
|||
closed bool
|
||||
stats TrackerClientStats
|
||||
pingTicker *time.Ticker
|
||||
|
||||
WebsocketTrackerHttpHeader func() http.Header
|
||||
}
|
||||
|
||||
func (me *TrackerClient) Stats() TrackerClientStats {
|
||||
|
@ -86,7 +89,13 @@ func (tc *TrackerClient) doWebsocket() error {
|
|||
tc.mu.Lock()
|
||||
tc.stats.Dials++
|
||||
tc.mu.Unlock()
|
||||
c, _, err := tc.Dialer.Dial(tc.Url, nil)
|
||||
|
||||
var header http.Header
|
||||
if tc.WebsocketTrackerHttpHeader != nil {
|
||||
header = tc.WebsocketTrackerHttpHeader()
|
||||
}
|
||||
|
||||
c, _, err := tc.Dialer.Dial(tc.Url, header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing tracker: %w", err)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
netHttp "net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
|
@ -12,7 +13,7 @@ import (
|
|||
"github.com/pion/datachannel"
|
||||
|
||||
"github.com/anacrolix/torrent/tracker"
|
||||
"github.com/anacrolix/torrent/tracker/http"
|
||||
httpTracker "github.com/anacrolix/torrent/tracker/http"
|
||||
"github.com/anacrolix/torrent/webtorrent"
|
||||
)
|
||||
|
||||
|
@ -43,6 +44,7 @@ type websocketTrackers struct {
|
|||
clients map[string]*refCountedWebtorrentTrackerClient
|
||||
Proxy httpTracker.ProxyFunc
|
||||
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
WebsocketTrackerHttpHeader func() netHttp.Header
|
||||
}
|
||||
|
||||
func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
|
||||
|
@ -61,6 +63,7 @@ func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.Tra
|
|||
Logger: me.Logger.WithText(func(m log.Msg) string {
|
||||
return fmt.Sprintf("tracker client for %q: %v", url, m)
|
||||
}),
|
||||
WebsocketTrackerHttpHeader: me.WebsocketTrackerHttpHeader,
|
||||
},
|
||||
}
|
||||
value.TrackerClient.Start(func(err error) {
|
||||
|
|
Loading…
Reference in New Issue