Big tidy up of webtorrent code

This commit is contained in:
Matt Joiner 2020-04-07 14:30:27 +10:00
parent bcdccb1ff3
commit 6f2c65fe33
4 changed files with 80 additions and 55 deletions

View File

@ -1266,14 +1266,21 @@ func (t *Torrent) seeding() bool {
func (t *Torrent) onWebRtcConn( func (t *Torrent) onWebRtcConn(
c datachannel.ReadWriteCloser, c datachannel.ReadWriteCloser,
initiatedLocally bool, // Whether we offered first, or they did. dcc webtorrent.DataChannelContext,
) { ) {
defer c.Close() defer c.Close()
pc, err := t.cl.handshakesConnection(context.Background(), webrtcNetConn{c}, t, false, nil, "webrtc") pc, err := t.cl.handshakesConnection(
context.Background(),
webrtcNetConn{c, dcc},
t,
false,
webrtcNetAddr{dcc.Remote},
webrtcNetwork,
)
if err != nil { if err != nil {
t.logger.Printf("error in handshaking webrtc connection: %v", err) t.logger.Printf("error in handshaking webrtc connection: %v", err)
} }
if initiatedLocally { if dcc.LocalOffered {
pc.Discovery = PeerSourceTracker pc.Discovery = PeerSourceTracker
} else { } else {
pc.Discovery = PeerSourceIncoming pc.Discovery = PeerSourceIncoming
@ -1309,11 +1316,11 @@ func (t *Torrent) startScrapingTracker(_url string) {
sl := func() torrentTrackerAnnouncer { sl := func() torrentTrackerAnnouncer {
switch u.Scheme { switch u.Scheme {
case "ws", "wss": case "ws", "wss":
wst := websocketTracker{*u, webtorrent.NewClient(t.cl.peerID, t.infoHash, t.onWebRtcConn)} wst := websocketTracker{*u, webtorrent.NewClient(t.cl.peerID, t.infoHash, t.onWebRtcConn, t.logger)}
go func() { go func() {
err := wst.Client.Run(t.announceRequest(tracker.Started)) err := wst.Client.Run(t.announceRequest(tracker.Started), u.String())
if err != nil { if err != nil {
t.logger.Printf("error running websocket tracker announcer: %v", err) t.logger.WithValues(log.Error).Printf("error running websocket tracker announcer: %v", err)
} }
}() }()
return wst return wst

View File

@ -5,29 +5,37 @@ import (
"time" "time"
"github.com/pion/datachannel" "github.com/pion/datachannel"
"github.com/pion/webrtc/v2"
"github.com/anacrolix/torrent/webtorrent"
) )
const webrtcNetwork = "webrtc"
type webrtcNetConn struct { type webrtcNetConn struct {
datachannel.ReadWriteCloser datachannel.ReadWriteCloser
webtorrent.DataChannelContext
} }
type webrtcNetAddr struct { type webrtcNetAddr struct {
webrtc.SessionDescription
} }
func (webrtcNetAddr) Network() string { func (webrtcNetAddr) Network() string {
return "webrtc" return webrtcNetwork
} }
func (webrtcNetAddr) String() string { func (me webrtcNetAddr) String() string {
return "" // TODO: What can I show here that's more like other protocols?
return "<WebRTC>"
} }
func (w webrtcNetConn) LocalAddr() net.Addr { func (me webrtcNetConn) LocalAddr() net.Addr {
return webrtcNetAddr{} return webrtcNetAddr{me.Local}
} }
func (w webrtcNetConn) RemoteAddr() net.Addr { func (me webrtcNetConn) RemoteAddr() net.Addr {
return webrtcNetAddr{} return webrtcNetAddr{me.Remote}
} }
func (w webrtcNetConn) SetDeadline(t time.Time) error { func (w webrtcNetConn) SetDeadline(t time.Time) error {

View File

@ -14,23 +14,20 @@ import (
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
) )
const (
trackerURL = `wss://tracker.openwebtorrent.com/` // For simplicity
)
// Client represents the webtorrent client // Client represents the webtorrent client
type Client struct { type Client struct {
lock sync.Mutex lock sync.Mutex
peerIDBinary string peerIDBinary string
infoHashBinary string infoHashBinary string
offeredPeers map[string]Peer // OfferID to Peer outboundOffers map[string]outboundOffer // OfferID to outboundOffer
tracker *websocket.Conn tracker *websocket.Conn
onConn func(_ datachannel.ReadWriteCloser, initiatedLocally bool) onConn onDataChannelOpen
logger log.Logger
} }
// Peer represents a remote peer // outboundOffer represents an outstanding offer.
type Peer struct { type outboundOffer struct {
peerID string originalOffer webrtc.SessionDescription
transport *Transport transport *Transport
} }
@ -42,33 +39,43 @@ func binaryToJsonString(b []byte) string {
return string(seq) return string(seq)
} }
type onDataChannelOpen func(_ datachannel.ReadWriteCloser, initiatedLocally bool) type DataChannelContext struct {
Local, Remote webrtc.SessionDescription
LocalOffered bool
}
func NewClient(peerId, infoHash [20]byte, onConn onDataChannelOpen) *Client { type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
func NewClient(peerId, infoHash [20]byte, onConn onDataChannelOpen, logger log.Logger) *Client {
return &Client{ return &Client{
offeredPeers: make(map[string]Peer), outboundOffers: make(map[string]outboundOffer),
peerIDBinary: binaryToJsonString(peerId[:]), peerIDBinary: binaryToJsonString(peerId[:]),
infoHashBinary: binaryToJsonString(infoHash[:]), infoHashBinary: binaryToJsonString(infoHash[:]),
onConn: onConn, onConn: onConn,
logger: logger,
} }
} }
func (c *Client) Run(ar tracker.AnnounceRequest) error { func (c *Client) Run(ar tracker.AnnounceRequest, url string) error {
t, _, err := websocket.DefaultDialer.Dial(trackerURL, nil) t, _, err := websocket.DefaultDialer.Dial(url, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to dial tracker: %v", err) return fmt.Errorf("failed to dial tracker: %w", err)
} }
defer t.Close() defer t.Close()
c.logger.WithValues(log.Info).Printf("dialed tracker %q", url)
c.tracker = t c.tracker = t
go c.announce(ar) go func() {
c.trackerReadLoop() err := c.announce(ar)
if err != nil {
return nil c.logger.WithValues(log.Error).Printf("error announcing: %v", err)
}
}()
return c.trackerReadLoop()
} }
func (c *Client) announce(request tracker.AnnounceRequest) error { func (c *Client) announce(request tracker.AnnounceRequest) error {
transpot, offer, err := NewTransport() transport, offer, err := NewTransport()
if err != nil { if err != nil {
return fmt.Errorf("failed to create transport: %w", err) return fmt.Errorf("failed to create transport: %w", err)
} }
@ -77,11 +84,13 @@ func (c *Client) announce(request tracker.AnnounceRequest) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to generate bytes: %w", err) return fmt.Errorf("failed to generate bytes: %w", err)
} }
// OfferID := randOfferID.ToStringHex()
offerIDBinary := randOfferID.ToStringLatin1() offerIDBinary := randOfferID.ToStringLatin1()
c.lock.Lock() c.lock.Lock()
c.offeredPeers[offerIDBinary] = Peer{transport: transpot} c.outboundOffers[offerIDBinary] = outboundOffer{
transport: transport,
originalOffer: offer,
}
c.lock.Unlock() c.lock.Unlock()
req := AnnounceRequest{ req := AnnounceRequest{
@ -124,7 +133,7 @@ func (c *Client) trackerReadLoop() error {
if err != nil { if err != nil {
return fmt.Errorf("read error: %w", err) return fmt.Errorf("read error: %w", err)
} }
log.Printf("recv: %q", message) c.logger.WithValues(log.Debug).Printf("received message from tracker: %q", message)
var ar AnnounceResponse var ar AnnounceResponse
if err := json.Unmarshal(message, &ar); err != nil { if err := json.Unmarshal(message, &ar); err != nil {
@ -137,9 +146,7 @@ func (c *Client) trackerReadLoop() error {
} }
switch { switch {
case ar.Offer != nil: case ar.Offer != nil:
t, answer, err := NewTransportFromOffer(*ar.Offer, func(dc datachannel.ReadWriteCloser) { _, answer, err := NewTransportFromOffer(*ar.Offer, c.onConn)
c.onConn(dc, false)
})
if err != nil { if err != nil {
return fmt.Errorf("write AnnounceResponse: %w", err) return fmt.Errorf("write AnnounceResponse: %w", err)
} }
@ -164,20 +171,21 @@ func (c *Client) trackerReadLoop() error {
c.lock.Unlock() c.lock.Unlock()
} }
c.lock.Unlock() c.lock.Unlock()
// Do something with the peer
_ = Peer{peerID: ar.PeerID, transport: t}
case ar.Answer != nil: case ar.Answer != nil:
c.lock.Lock() c.lock.Lock()
peer, ok := c.offeredPeers[ar.OfferID] offer, ok := c.outboundOffers[ar.OfferID]
c.lock.Unlock() c.lock.Unlock()
if !ok { if !ok {
log.Printf("could not find peer for offer %q", ar.OfferID) c.logger.WithValues(log.Warning).Printf("could not find offer for id %q", ar.OfferID)
continue continue
} }
log.Printf("offer %q got answer %v", ar.OfferID, *ar.Answer) log.Printf("offer %q got answer %v", ar.OfferID, *ar.Answer)
err = peer.transport.SetAnswer(*ar.Answer, func(dc datachannel.ReadWriteCloser) { err = offer.transport.SetAnswer(*ar.Answer, func(dc datachannel.ReadWriteCloser) {
c.onConn(dc, true) c.onConn(dc, DataChannelContext{
Local: offer.originalOffer,
Remote: *ar.Answer,
LocalOffered: true,
})
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to sent answer: %v", err) return fmt.Errorf("failed to sent answer: %v", err)

View File

@ -67,7 +67,7 @@ func NewTransport() (*Transport, webrtc.SessionDescription, error) {
// NewTransportFromOffer creates a transport from a WebRTC offer and and returns a WebRTC answer to // NewTransportFromOffer creates a transport from a WebRTC offer and and returns a WebRTC answer to
// be announced. // be announced.
func NewTransportFromOffer(offer webrtc.SessionDescription, onOpen func(datachannel.ReadWriteCloser)) (*Transport, webrtc.SessionDescription, error) { func NewTransportFromOffer(offer webrtc.SessionDescription, onOpen onDataChannelOpen) (*Transport, webrtc.SessionDescription, error) {
peerConnection, err := newPeerConnection() peerConnection, err := newPeerConnection()
if err != nil { if err != nil {
return nil, webrtc.SessionDescription{}, fmt.Errorf("failed to peer connection: %v", err) return nil, webrtc.SessionDescription{}, fmt.Errorf("failed to peer connection: %v", err)
@ -77,13 +77,6 @@ func NewTransportFromOffer(offer webrtc.SessionDescription, onOpen func(datachan
}) })
t := &Transport{pc: peerConnection} t := &Transport{pc: peerConnection}
peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
fmt.Printf("New DataChannel %s %d\n", d.Label(), d.ID())
t.lock.Lock()
t.dc = d
t.lock.Unlock()
t.handleOpen(onOpen)
})
err = peerConnection.SetRemoteDescription(offer) err = peerConnection.SetRemoteDescription(offer)
if err != nil { if err != nil {
@ -93,6 +86,15 @@ func NewTransportFromOffer(offer webrtc.SessionDescription, onOpen func(datachan
if err != nil { if err != nil {
return nil, webrtc.SessionDescription{}, fmt.Errorf("%v", err) return nil, webrtc.SessionDescription{}, fmt.Errorf("%v", err)
} }
peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
fmt.Printf("New DataChannel %s %d\n", d.Label(), d.ID())
t.lock.Lock()
t.dc = d
t.lock.Unlock()
t.handleOpen(func(dc datachannel.ReadWriteCloser) {
onOpen(dc, DataChannelContext{answer, offer, false})
})
})
err = peerConnection.SetLocalDescription(answer) err = peerConnection.SetLocalDescription(answer)
if err != nil { if err != nil {
return nil, webrtc.SessionDescription{}, fmt.Errorf("%v", err) return nil, webrtc.SessionDescription{}, fmt.Errorf("%v", err)