Rewrite peerConnMsgWriter.run

This commit is contained in:
Matt Joiner 2021-08-18 16:51:30 +10:00
parent 8f187411cc
commit 092dc74458
1 changed files with 5 additions and 27 deletions

View File

@ -59,35 +59,15 @@ type peerConnMsgWriter struct {
// activity elsewhere in the Client, and some is determined locally when the // activity elsewhere in the Client, and some is determined locally when the
// connection is writable. // connection is writable.
func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) { func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) {
var ( lastWrite := time.Now()
lastWrite time.Time = time.Now()
keepAliveTimer *time.Timer
)
cn.mu.Lock()
defer cn.mu.Unlock()
keepAliveTimer = time.AfterFunc(keepAliveTimeout, func() {
cn.mu.Lock()
defer cn.mu.Unlock()
if time.Since(lastWrite) >= keepAliveTimeout {
cn.writeCond.Broadcast()
}
keepAliveTimer.Reset(keepAliveTimeout)
})
defer keepAliveTimer.Stop()
frontBuf := new(bytes.Buffer) frontBuf := new(bytes.Buffer)
for { for {
if cn.closed.IsSet() { if cn.closed.IsSet() {
return return
} }
keepAlive := false cn.fillWriteBuffer()
if cn.writeBuffer.Len() == 0 { keepAlive := cn.keepAlive()
func() { cn.mu.Lock()
cn.mu.Unlock()
defer cn.mu.Lock()
cn.fillWriteBuffer()
keepAlive = cn.keepAlive()
}()
}
if cn.writeBuffer.Len() == 0 && time.Since(lastWrite) >= keepAliveTimeout && keepAlive { if cn.writeBuffer.Len() == 0 && time.Since(lastWrite) >= keepAliveTimeout && keepAlive {
cn.writeBuffer.Write(pp.Message{Keepalive: true}.MustMarshalBinary()) cn.writeBuffer.Write(pp.Message{Keepalive: true}.MustMarshalBinary())
torrent.Add("written keepalives", 1) torrent.Add("written keepalives", 1)
@ -98,18 +78,16 @@ func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) {
select { select {
case <-cn.closed.Done(): case <-cn.closed.Done():
case <-writeCond: case <-writeCond:
case <-time.After(time.Until(lastWrite.Add(keepAliveTimeout))):
} }
cn.mu.Lock()
continue continue
} }
// Flip the buffers. // Flip the buffers.
frontBuf, cn.writeBuffer = cn.writeBuffer, frontBuf frontBuf, cn.writeBuffer = cn.writeBuffer, frontBuf
cn.mu.Unlock() cn.mu.Unlock()
n, err := cn.w.Write(frontBuf.Bytes()) n, err := cn.w.Write(frontBuf.Bytes())
cn.mu.Lock()
if n != 0 { if n != 0 {
lastWrite = time.Now() lastWrite = time.Now()
keepAliveTimer.Reset(keepAliveTimeout)
} }
if err != nil { if err != nil {
cn.logger.WithDefaultLevel(log.Debug).Printf("error writing: %v", err) cn.logger.WithDefaultLevel(log.Debug).Printf("error writing: %v", err)