diff --git a/webtorrent/tracker-client.go b/webtorrent/tracker-client.go index d65dcab4..311b296a 100644 --- a/webtorrent/tracker-client.go +++ b/webtorrent/tracker-client.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/json" "fmt" + "github.com/anacrolix/generics" "go.opentelemetry.io/otel/trace" "sync" "time" @@ -34,7 +35,7 @@ type TrackerClient struct { mu sync.Mutex cond sync.Cond - outboundOffers map[string]outboundOffer // OfferID to outboundOffer + outboundOffers map[string]outboundOfferValue // OfferID to outboundOfferValue wsConn *websocket.Conn closed bool stats TrackerClientStats @@ -51,8 +52,13 @@ func (me *TrackerClient) peerIdBinary() string { return binaryToJsonString(me.PeerId[:]) } -// outboundOffer represents an outstanding offer. type outboundOffer struct { + offerId string + outboundOfferValue +} + +// outboundOfferValue represents an outstanding offer. +type outboundOfferValue struct { originalOffer webrtc.SessionDescription peerConnection *wrappedPeerConnection infoHash [20]byte @@ -202,6 +208,9 @@ func (tc *TrackerClient) CloseOffersForInfohash(infoHash [20]byte) { func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error { metrics.Add("outbound announces", 1) + if event == tracker.Stopped { + return tc.announce(event, infoHash, nil) + } var randOfferId [20]byte _, err := rand.Read(randOfferId[:]) if err != nil { @@ -214,14 +223,29 @@ func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte return fmt.Errorf("creating offer: %w", err) } - request, err := tc.GetAnnounceRequest(event, infoHash) + err = tc.announce(event, infoHash, []outboundOffer{{ + offerId: offerIDBinary, + outboundOfferValue: outboundOfferValue{ + originalOffer: offer, + peerConnection: pc, + infoHash: infoHash, + dataChannel: dc, + }}, + }) if err != nil { pc.Close() + } + return err +} + +func (tc *TrackerClient) announce(event tracker.AnnounceEvent, infoHash [20]byte, offers []outboundOffer) error { + request, err := tc.GetAnnounceRequest(event, infoHash) + if err != nil { return fmt.Errorf("getting announce parameters: %w", err) } req := AnnounceRequest{ - Numwant: 1, // If higher we need to create equal amount of offers. + Numwant: len(offers), Uploaded: request.Uploaded, Downloaded: request.Downloaded, Left: request.Left, @@ -229,15 +253,16 @@ func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte Action: "announce", InfoHash: binaryToJsonString(infoHash[:]), PeerID: tc.peerIdBinary(), - Offers: []Offer{{ - OfferID: offerIDBinary, - Offer: offer, - }}, + } + for _, offer := range offers { + req.Offers = append(req.Offers, Offer{ + OfferID: offer.offerId, + Offer: offer.originalOffer, + }) } data, err := json.Marshal(req) if err != nil { - pc.Close() return fmt.Errorf("marshalling request: %w", err) } @@ -245,17 +270,10 @@ func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte defer tc.mu.Unlock() err = tc.writeMessage(data) if err != nil { - pc.Close() return fmt.Errorf("write AnnounceRequest: %w", err) } - if tc.outboundOffers == nil { - tc.outboundOffers = make(map[string]outboundOffer) - } - tc.outboundOffers[offerIDBinary] = outboundOffer{ - peerConnection: pc, - originalOffer: offer, - infoHash: infoHash, - dataChannel: dc, + for _, offer := range offers { + generics.MakeMapIfNilAndSet(&tc.outboundOffers, offer.offerId, offer.outboundOfferValue) } return nil }