From 092dc74458abadf8ecb4af0e87e8b2c4a0cc63ec Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 18 Aug 2021 16:51:30 +1000 Subject: [PATCH] Rewrite peerConnMsgWriter.run --- peer-conn-msg-writer.go | 32 +++++--------------------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/peer-conn-msg-writer.go b/peer-conn-msg-writer.go index dff4eb9e..4dbf00e4 100644 --- a/peer-conn-msg-writer.go +++ b/peer-conn-msg-writer.go @@ -59,35 +59,15 @@ type peerConnMsgWriter struct { // activity elsewhere in the Client, and some is determined locally when the // connection is writable. func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) { - var ( - lastWrite time.Time = time.Now() - keepAliveTimer *time.Timer - ) - cn.mu.Lock() - defer cn.mu.Unlock() - keepAliveTimer = time.AfterFunc(keepAliveTimeout, func() { - cn.mu.Lock() - defer cn.mu.Unlock() - if time.Since(lastWrite) >= keepAliveTimeout { - cn.writeCond.Broadcast() - } - keepAliveTimer.Reset(keepAliveTimeout) - }) - defer keepAliveTimer.Stop() + lastWrite := time.Now() frontBuf := new(bytes.Buffer) for { if cn.closed.IsSet() { return } - keepAlive := false - if cn.writeBuffer.Len() == 0 { - func() { - cn.mu.Unlock() - defer cn.mu.Lock() - cn.fillWriteBuffer() - keepAlive = cn.keepAlive() - }() - } + cn.fillWriteBuffer() + keepAlive := cn.keepAlive() + cn.mu.Lock() if cn.writeBuffer.Len() == 0 && time.Since(lastWrite) >= keepAliveTimeout && keepAlive { cn.writeBuffer.Write(pp.Message{Keepalive: true}.MustMarshalBinary()) torrent.Add("written keepalives", 1) @@ -98,18 +78,16 @@ func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) { select { case <-cn.closed.Done(): case <-writeCond: + case <-time.After(time.Until(lastWrite.Add(keepAliveTimeout))): } - cn.mu.Lock() continue } // Flip the buffers. frontBuf, cn.writeBuffer = cn.writeBuffer, frontBuf cn.mu.Unlock() n, err := cn.w.Write(frontBuf.Bytes()) - cn.mu.Lock() if n != 0 { lastWrite = time.Now() - keepAliveTimer.Reset(keepAliveTimeout) } if err != nil { cn.logger.WithDefaultLevel(log.Debug).Printf("error writing: %v", err)