From 0f495ce97dcfa6c47abb8c83431ab9f5215a6b92 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Mon, 29 Nov 2021 10:07:48 +1100 Subject: [PATCH] Optimize the use of worstConnSlice again --- torrent.go | 49 ++++++++++++++++++++++--------------- worse-conns.go | 59 +++++++++++++++++++++++++++++++++++---------- worse-conns_test.go | 32 +++++++++++++++++------- 3 files changed, 98 insertions(+), 42 deletions(-) diff --git a/torrent.go b/torrent.go index d047be4e..2550f539 100644 --- a/torrent.go +++ b/torrent.go @@ -257,8 +257,14 @@ func (t *Torrent) addrActive(addr string) bool { } func (t *Torrent) appendUnclosedConns(ret []*PeerConn) []*PeerConn { + return t.appendConns(ret, func(conn *PeerConn) bool { + return !conn.closed.IsSet() + }) +} + +func (t *Torrent) appendConns(ret []*PeerConn, f func(*PeerConn) bool) []*PeerConn { for c := range t.conns { - if !c.closed.IsSet() { + if f(c) { ret = append(ret, c) } } @@ -969,21 +975,23 @@ func (t *Torrent) wantPieceIndex(index pieceIndex) bool { // conns (which is a map). var peerConnSlices sync.Pool +func getPeerConnSlice(cap int) []*PeerConn { + getInterface := peerConnSlices.Get() + if getInterface == nil { + return make([]*PeerConn, 0, cap) + } else { + return getInterface.([]*PeerConn)[:0] + } +} + // The worst connection is one that hasn't been sent, or sent anything useful for the longest. A bad // connection is one that usually sends us unwanted pieces, or has been in the worse half of the // established connections for more than a minute. This is O(n log n). If there was a way to not // consider the position of a conn relative to the total number, it could be reduced to O(n). func (t *Torrent) worstBadConn() (ret *PeerConn) { - var sl []*PeerConn - getInterface := peerConnSlices.Get() - if getInterface == nil { - sl = make([]*PeerConn, 0, len(t.conns)) - } else { - sl = getInterface.([]*PeerConn)[:0] - } - sl = t.appendUnclosedConns(sl) - defer peerConnSlices.Put(sl) - wcs := worseConnSlice{sl} + wcs := worseConnSlice{conns: t.appendUnclosedConns(getPeerConnSlice(len(t.conns)))} + defer peerConnSlices.Put(wcs.conns) + wcs.initKeys() heap.Init(&wcs) for wcs.Len() != 0 { c := heap.Pop(&wcs).(*PeerConn) @@ -1816,13 +1824,10 @@ func (t *Torrent) wantConns() bool { if t.closed.IsSet() { return false } - if len(t.conns) >= t.maxEstablishedConns && t.worstBadConn() == nil { + if !t.needData() && (!t.seeding() || !t.haveAnyPieces()) { return false } - if t.seeding() && t.haveAnyPieces() { - return true - } - return t.needData() + return len(t.conns) < t.maxEstablishedConns || t.worstBadConn() != nil } func (t *Torrent) SetMaxEstablishedConns(max int) (oldMax int) { @@ -1830,11 +1835,15 @@ func (t *Torrent) SetMaxEstablishedConns(max int) (oldMax int) { defer t.cl.unlock() oldMax = t.maxEstablishedConns t.maxEstablishedConns = max - wcs := slices.HeapInterface(slices.FromMapKeys(t.conns), func(l, r *PeerConn) bool { - return worseConn(&l.Peer, &r.Peer) - }) + wcs := worseConnSlice{ + conns: t.appendConns(nil, func(*PeerConn) bool { + return true + }), + } + wcs.initKeys() + heap.Init(&wcs) for len(t.conns) > t.maxEstablishedConns && wcs.Len() > 0 { - t.dropConnection(wcs.Pop().(*PeerConn)) + t.dropConnection(heap.Pop(&wcs).(*PeerConn)) } t.openNewConns() return oldMax diff --git a/worse-conns.go b/worse-conns.go index 2c117412..b0e0b4f2 100644 --- a/worse-conns.go +++ b/worse-conns.go @@ -7,15 +7,26 @@ import ( "unsafe" "github.com/anacrolix/multiless" + "github.com/anacrolix/sync" ) type worseConnInput struct { - Useful bool - LastHelpful time.Time - CompletedHandshake time.Time - PeerPriority peerPriority - PeerPriorityErr error - Pointer uintptr + Useful bool + LastHelpful time.Time + CompletedHandshake time.Time + GetPeerPriority func() (peerPriority, error) + getPeerPriorityOnce sync.Once + peerPriority peerPriority + peerPriorityErr error + Pointer uintptr +} + +func (me *worseConnInput) doGetPeerPriority() { + me.peerPriority, me.peerPriorityErr = me.GetPeerPriority() +} + +func (me *worseConnInput) doGetPeerPriorityOnce() { + me.getPeerPriorityOnce.Do(me.doGetPeerPriority) } func worseConnInputFromPeer(p *Peer) worseConnInput { @@ -24,23 +35,36 @@ func worseConnInputFromPeer(p *Peer) worseConnInput { LastHelpful: p.lastHelpful(), CompletedHandshake: p.completedHandshake, Pointer: uintptr(unsafe.Pointer(p)), + GetPeerPriority: p.peerPriority, } - ret.PeerPriority, ret.PeerPriorityErr = p.peerPriority() return ret } func worseConn(_l, _r *Peer) bool { - return worseConnInputFromPeer(_l).Less(worseConnInputFromPeer(_r)) + // TODO: Use generics for ptr to + l := worseConnInputFromPeer(_l) + r := worseConnInputFromPeer(_r) + return l.Less(&r) } -func (l worseConnInput) Less(r worseConnInput) bool { +func (l *worseConnInput) Less(r *worseConnInput) bool { less, ok := multiless.New().Bool( l.Useful, r.Useful).CmpInt64( l.LastHelpful.Sub(r.LastHelpful).Nanoseconds()).CmpInt64( l.CompletedHandshake.Sub(r.CompletedHandshake).Nanoseconds()).LazySameLess( func() (same, less bool) { - same = l.PeerPriorityErr != nil || r.PeerPriorityErr != nil || l.PeerPriority == r.PeerPriority - less = l.PeerPriority < r.PeerPriority + l.doGetPeerPriorityOnce() + if l.peerPriorityErr != nil { + same = true + return + } + r.doGetPeerPriorityOnce() + if r.peerPriorityErr != nil { + same = true + return + } + same = l.peerPriority == r.peerPriority + less = l.peerPriority < r.peerPriority return }).Uintptr( l.Pointer, r.Pointer, @@ -53,6 +77,14 @@ func (l worseConnInput) Less(r worseConnInput) bool { type worseConnSlice struct { conns []*PeerConn + keys []worseConnInput +} + +func (me *worseConnSlice) initKeys() { + me.keys = make([]worseConnInput, len(me.conns)) + for i, c := range me.conns { + me.keys[i] = worseConnInputFromPeer(&c.Peer) + } } var _ heap.Interface = &worseConnSlice{} @@ -62,7 +94,7 @@ func (me worseConnSlice) Len() int { } func (me worseConnSlice) Less(i, j int) bool { - return worseConn(&me.conns[i].Peer, &me.conns[j].Peer) + return me.keys[i].Less(&me.keys[j]) } func (me *worseConnSlice) Pop() interface{} { @@ -73,9 +105,10 @@ func (me *worseConnSlice) Pop() interface{} { } func (me *worseConnSlice) Push(x interface{}) { - me.conns = append(me.conns, x.(*PeerConn)) + panic("not implemented") } func (me worseConnSlice) Swap(i, j int) { me.conns[i], me.conns[j] = me.conns[j], me.conns[i] + me.keys[i], me.keys[j] = me.keys[j], me.keys[i] } diff --git a/worse-conns_test.go b/worse-conns_test.go index 39eecf78..3865b648 100644 --- a/worse-conns_test.go +++ b/worse-conns_test.go @@ -9,22 +9,36 @@ import ( func TestWorseConnLastHelpful(t *testing.T) { c := qt.New(t) - c.Check(worseConnInput{}.Less(worseConnInput{LastHelpful: time.Now()}), qt.IsTrue) - c.Check(worseConnInput{}.Less(worseConnInput{CompletedHandshake: time.Now()}), qt.IsTrue) - c.Check(worseConnInput{LastHelpful: time.Now()}.Less(worseConnInput{CompletedHandshake: time.Now()}), qt.IsFalse) - c.Check(worseConnInput{ + c.Check((&worseConnInput{}).Less(&worseConnInput{LastHelpful: time.Now()}), qt.IsTrue) + c.Check((&worseConnInput{}).Less(&worseConnInput{CompletedHandshake: time.Now()}), qt.IsTrue) + c.Check((&worseConnInput{LastHelpful: time.Now()}).Less(&worseConnInput{CompletedHandshake: time.Now()}), qt.IsFalse) + c.Check((&worseConnInput{ LastHelpful: time.Now(), - }.Less(worseConnInput{ + }).Less(&worseConnInput{ LastHelpful: time.Now(), CompletedHandshake: time.Now(), }), qt.IsTrue) now := time.Now() - c.Check(worseConnInput{ + c.Check((&worseConnInput{ LastHelpful: now, - }.Less(worseConnInput{ + }).Less(&worseConnInput{ LastHelpful: now.Add(-time.Nanosecond), CompletedHandshake: now, }), qt.IsFalse) - c.Check(worseConnInput{}.Less(worseConnInput{Pointer: 1}), qt.IsTrue) - c.Check(worseConnInput{Pointer: 2}.Less(worseConnInput{Pointer: 1}), qt.IsFalse) + readyPeerPriority := func() (peerPriority, error) { + return 42, nil + } + c.Check((&worseConnInput{ + GetPeerPriority: readyPeerPriority, + }).Less(&worseConnInput{ + GetPeerPriority: readyPeerPriority, + Pointer: 1, + }), qt.IsTrue) + c.Check((&worseConnInput{ + GetPeerPriority: readyPeerPriority, + Pointer: 2, + }).Less(&worseConnInput{ + GetPeerPriority: readyPeerPriority, + Pointer: 1, + }), qt.IsFalse) }