diff --git a/client.go b/client.go index 8a94e021..43a4c46f 100644 --- a/client.go +++ b/client.go @@ -621,8 +621,8 @@ func (cl *Client) noLongerHalfOpen(t *Torrent, addr string) { t.openNewConns() } -// Performs initiator handshakes and returns a connection. Returns nil -// *connection if no connection for valid reasons. +// Performs initiator handshakes and returns a connection. Returns nil *connection if no connection +// for valid reasons. func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader bool, remoteAddr net.Addr, network string) (c *PeerConn, err error) { c = cl.newConnection(nc, true, remoteAddr, network) c.headerEncrypted = encryptHeader @@ -850,6 +850,7 @@ func (cl *Client) runReceivedConn(c *PeerConn) { cl.runHandshookConn(c, t) } +// Client lock must be held before entering this. func (cl *Client) runHandshookConn(c *PeerConn, t *Torrent) { c.setTorrent(t) if c.PeerID == cl.peerID { diff --git a/torrent.go b/torrent.go index ca8b06f5..de508e61 100644 --- a/torrent.go +++ b/torrent.go @@ -2,6 +2,7 @@ package torrent import ( "container/heap" + "context" "crypto/sha1" "errors" "fmt" @@ -14,6 +15,7 @@ import ( "unsafe" "github.com/davecgh/go-spew/spew" + "github.com/pion/datachannel" "github.com/anacrolix/dht/v2" "github.com/anacrolix/log" @@ -1262,6 +1264,25 @@ func (t *Torrent) seeding() bool { return true } +func (t *Torrent) onWebRtcConn( + c datachannel.ReadWriteCloser, + initiatedLocally bool, // Whether we offered first, or they did. +) { + defer c.Close() + pc, err := t.cl.handshakesConnection(context.Background(), webrtcNetConn{c}, t, false, nil, "webrtc") + if err != nil { + t.logger.Printf("error in handshaking webrtc connection: %v", err) + } + if initiatedLocally { + pc.Discovery = PeerSourceTracker + } else { + pc.Discovery = PeerSourceIncoming + } + t.cl.lock() + defer t.cl.unlock() + t.cl.runHandshookConn(pc, t) +} + func (t *Torrent) startScrapingTracker(_url string) { if _url == "" { return @@ -1288,7 +1309,7 @@ func (t *Torrent) startScrapingTracker(_url string) { sl := func() torrentTrackerAnnouncer { switch u.Scheme { case "ws", "wss": - wst := websocketTracker{*u, webtorrent.NewClient(t.cl.peerID, t.infoHash)} + wst := websocketTracker{*u, webtorrent.NewClient(t.cl.peerID, t.infoHash, t.onWebRtcConn)} go func() { err := wst.Client.Run(t.announceRequest(tracker.Started)) if err != nil { diff --git a/webrtc.go b/webrtc.go new file mode 100644 index 00000000..d805b54b --- /dev/null +++ b/webrtc.go @@ -0,0 +1,43 @@ +package torrent + +import ( + "net" + "time" + + "github.com/pion/datachannel" +) + +type webrtcNetConn struct { + datachannel.ReadWriteCloser +} + +type webrtcNetAddr struct { +} + +func (webrtcNetAddr) Network() string { + return "webrtc" +} + +func (webrtcNetAddr) String() string { + return "" +} + +func (w webrtcNetConn) LocalAddr() net.Addr { + return webrtcNetAddr{} +} + +func (w webrtcNetConn) RemoteAddr() net.Addr { + return webrtcNetAddr{} +} + +func (w webrtcNetConn) SetDeadline(t time.Time) error { + return nil +} + +func (w webrtcNetConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (w webrtcNetConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/webtorrent/client.go b/webtorrent/client.go index 9b49e15d..d4dd6049 100644 --- a/webtorrent/client.go +++ b/webtorrent/client.go @@ -3,7 +3,6 @@ package webtorrent import ( "encoding/json" "fmt" - "io" "sync" "github.com/anacrolix/log" @@ -26,6 +25,7 @@ type Client struct { infoHashBinary string offeredPeers map[string]Peer // OfferID to Peer tracker *websocket.Conn + onConn func(_ datachannel.ReadWriteCloser, initiatedLocally bool) } // Peer represents a remote peer @@ -42,11 +42,14 @@ func binaryToJsonString(b []byte) string { return string(seq) } -func NewClient(peerId, infoHash [20]byte) *Client { +type onDataChannelOpen func(_ datachannel.ReadWriteCloser, initiatedLocally bool) + +func NewClient(peerId, infoHash [20]byte, onConn onDataChannelOpen) *Client { return &Client{ offeredPeers: make(map[string]Peer), peerIDBinary: binaryToJsonString(peerId[:]), infoHashBinary: binaryToJsonString(infoHash[:]), + onConn: onConn, } } @@ -134,7 +137,9 @@ func (c *Client) trackerReadLoop() error { } switch { case ar.Offer != nil: - t, answer, err := NewTransportFromOffer(*ar.Offer, c.handleDataChannel) + t, answer, err := NewTransportFromOffer(*ar.Offer, func(dc datachannel.ReadWriteCloser) { + c.onConn(dc, false) + }) if err != nil { return fmt.Errorf("write AnnounceResponse: %w", err) } @@ -170,8 +175,10 @@ func (c *Client) trackerReadLoop() error { log.Printf("could not find peer for offer %q", ar.OfferID) continue } - log.Printf("offer %q got answer %q", ar.OfferID, ar.Answer) - err = peer.transport.SetAnswer(*ar.Answer, c.handleDataChannel) + log.Printf("offer %q got answer %v", ar.OfferID, *ar.Answer) + err = peer.transport.SetAnswer(*ar.Answer, func(dc datachannel.ReadWriteCloser) { + c.onConn(dc, true) + }) if err != nil { return fmt.Errorf("failed to sent answer: %v", err) } @@ -179,23 +186,6 @@ func (c *Client) trackerReadLoop() error { } } -func (c *Client) handleDataChannel(dc datachannel.ReadWriteCloser) { - go c.dcReadLoop(dc) - //go c.dcWriteLoop(dc) -} - -func (c *Client) dcReadLoop(d io.Reader) { - for { - buffer := make([]byte, 1024) - n, err := d.Read(buffer) - if err != nil { - log.Printf("Datachannel closed; Exit the readloop: %v", err) - } - - fmt.Printf("Message from DataChannel: %s\n", string(buffer[:n])) - } -} - type AnnounceRequest struct { Numwant int `json:"numwant"` Uploaded int `json:"uploaded"`