Add customer headers when dialling WS connection to tracker (#789)
* expose WebtorrentTrackerHttpHeader field
This commit is contained in:
parent
682c77fcb9
commit
3909c6c125
|
@ -297,8 +297,9 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
|
||||||
}
|
}
|
||||||
return t.announceRequest(event), nil
|
return t.announceRequest(event), nil
|
||||||
},
|
},
|
||||||
Proxy: cl.config.HTTPProxy,
|
Proxy: cl.config.HTTPProxy,
|
||||||
DialContext: cl.config.TrackerDialContext,
|
WebsocketTrackerHttpHeader: cl.config.WebsocketTrackerHttpHeader,
|
||||||
|
DialContext: cl.config.TrackerDialContext,
|
||||||
OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) {
|
OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) {
|
||||||
cl.lock()
|
cl.lock()
|
||||||
defer cl.unlock()
|
defer cl.unlock()
|
||||||
|
|
|
@ -117,6 +117,9 @@ type ClientConfig struct {
|
||||||
// HttpRequestDirector modifies the request before it's sent.
|
// HttpRequestDirector modifies the request before it's sent.
|
||||||
// Useful for adding authentication headers, for example
|
// Useful for adding authentication headers, for example
|
||||||
HttpRequestDirector func(*http.Request) error
|
HttpRequestDirector func(*http.Request) error
|
||||||
|
// WebsocketTrackerHttpHeader returns a custom header to be used when dialing a websocket connection
|
||||||
|
// to the tracker. Useful for adding authentication headers
|
||||||
|
WebsocketTrackerHttpHeader func() http.Header
|
||||||
// Updated occasionally to when there's been some changes to client
|
// Updated occasionally to when there's been some changes to client
|
||||||
// behaviour in case other clients are assuming anything of us. See also
|
// behaviour in case other clients are assuming anything of us. See also
|
||||||
// `bep20`.
|
// `bep20`.
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -40,6 +41,8 @@ type TrackerClient struct {
|
||||||
closed bool
|
closed bool
|
||||||
stats TrackerClientStats
|
stats TrackerClientStats
|
||||||
pingTicker *time.Ticker
|
pingTicker *time.Ticker
|
||||||
|
|
||||||
|
WebsocketTrackerHttpHeader func() http.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
func (me *TrackerClient) Stats() TrackerClientStats {
|
func (me *TrackerClient) Stats() TrackerClientStats {
|
||||||
|
@ -86,7 +89,13 @@ func (tc *TrackerClient) doWebsocket() error {
|
||||||
tc.mu.Lock()
|
tc.mu.Lock()
|
||||||
tc.stats.Dials++
|
tc.stats.Dials++
|
||||||
tc.mu.Unlock()
|
tc.mu.Unlock()
|
||||||
c, _, err := tc.Dialer.Dial(tc.Url, nil)
|
|
||||||
|
var header http.Header
|
||||||
|
if tc.WebsocketTrackerHttpHeader != nil {
|
||||||
|
header = tc.WebsocketTrackerHttpHeader()
|
||||||
|
}
|
||||||
|
|
||||||
|
c, _, err := tc.Dialer.Dial(tc.Url, header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dialing tracker: %w", err)
|
return fmt.Errorf("dialing tracker: %w", err)
|
||||||
}
|
}
|
||||||
|
|
21
wstracker.go
21
wstracker.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
netHttp "net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -12,7 +13,7 @@ import (
|
||||||
"github.com/pion/datachannel"
|
"github.com/pion/datachannel"
|
||||||
|
|
||||||
"github.com/anacrolix/torrent/tracker"
|
"github.com/anacrolix/torrent/tracker"
|
||||||
"github.com/anacrolix/torrent/tracker/http"
|
httpTracker "github.com/anacrolix/torrent/tracker/http"
|
||||||
"github.com/anacrolix/torrent/webtorrent"
|
"github.com/anacrolix/torrent/webtorrent"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,14 +36,15 @@ type refCountedWebtorrentTrackerClient struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type websocketTrackers struct {
|
type websocketTrackers struct {
|
||||||
PeerId [20]byte
|
PeerId [20]byte
|
||||||
Logger log.Logger
|
Logger log.Logger
|
||||||
GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
|
GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
|
||||||
OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
|
OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
clients map[string]*refCountedWebtorrentTrackerClient
|
clients map[string]*refCountedWebtorrentTrackerClient
|
||||||
Proxy httpTracker.ProxyFunc
|
Proxy httpTracker.ProxyFunc
|
||||||
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
WebsocketTrackerHttpHeader func() netHttp.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
|
func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
|
||||||
|
@ -61,6 +63,7 @@ func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.Tra
|
||||||
Logger: me.Logger.WithText(func(m log.Msg) string {
|
Logger: me.Logger.WithText(func(m log.Msg) string {
|
||||||
return fmt.Sprintf("tracker client for %q: %v", url, m)
|
return fmt.Sprintf("tracker client for %q: %v", url, m)
|
||||||
}),
|
}),
|
||||||
|
WebsocketTrackerHttpHeader: me.WebsocketTrackerHttpHeader,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
value.TrackerClient.Start(func(err error) {
|
value.TrackerClient.Start(func(err error) {
|
||||||
|
|
Loading…
Reference in New Issue