diff --git a/ratelimitreader.go b/ratelimitreader.go index d7939a3d..7d9e6d86 100644 --- a/ratelimitreader.go +++ b/ratelimitreader.go @@ -1,32 +1,53 @@ package torrent import ( + "context" "fmt" "io" "time" - "golang.org/x/net/context" "golang.org/x/time/rate" ) type rateLimitedReader struct { l *rate.Limiter r io.Reader + + // This is the time of the last Read's reservation. + lastRead time.Time } -func (me rateLimitedReader) Read(b []byte) (n int, err error) { - // Wait until we can read at all. - if err := me.l.WaitN(context.Background(), 1); err != nil { - panic(err) - } - // Limit the read to within the burst. - if me.l.Limit() != rate.Inf && len(b) > me.l.Burst() { - b = b[:me.l.Burst()] - } - n, err = me.r.Read(b) - // Pay the piper. - if !me.l.ReserveN(time.Now(), n-1).OK() { - panic(fmt.Sprintf("burst exceeded?: %d", n-1)) +func (me *rateLimitedReader) Read(b []byte) (n int, err error) { + const oldStyle = false // Retained for future reference. + if oldStyle { + // Wait until we can read at all. + if err := me.l.WaitN(context.Background(), 1); err != nil { + panic(err) + } + // Limit the read to within the burst. + if me.l.Limit() != rate.Inf && len(b) > me.l.Burst() { + b = b[:me.l.Burst()] + } + n, err = me.r.Read(b) + // Pay the piper. + now := time.Now() + me.lastRead = now + if !me.l.ReserveN(now, n-1).OK() { + panic(fmt.Sprintf("burst exceeded?: %d", n-1)) + } + } else { + // Limit the read to within the burst. + if me.l.Limit() != rate.Inf && len(b) > me.l.Burst() { + b = b[:me.l.Burst()] + } + n, err = me.r.Read(b) + now := time.Now() + r := me.l.ReserveN(now, n) + if !r.OK() { + panic(n) + } + me.lastRead = now + time.Sleep(r.Delay()) } return } diff --git a/rlreader_test.go b/rlreader_test.go new file mode 100644 index 00000000..529cbee6 --- /dev/null +++ b/rlreader_test.go @@ -0,0 +1,130 @@ +package torrent + +import ( + "io" + "log" + "math/rand" + "sync" + "time" + + "github.com/bradfitz/iter" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" + + "testing" +) + +func writeN(ws []io.Writer, n int) error { + b := make([]byte, n) + for _, w := range ws[1:] { + n1 := rand.Intn(n) + wn, err := w.Write(b[:n1]) + if wn != n1 { + if err == nil { + panic(n1) + } + return err + } + n -= n1 + } + wn, err := ws[0].Write(b[:n]) + if wn != n { + if err == nil { + panic(n) + } + } + return err +} + +func TestRateLimitReaders(t *testing.T) { + const ( + numReaders = 2 + bytesPerSecond = 100 + burst = 5 + readSize = 6 + writeRounds = 10 + bytesPerRound = 12 + ) + control := rate.NewLimiter(bytesPerSecond, burst) + shared := rate.NewLimiter(bytesPerSecond, burst) + var ( + ws []io.Writer + cs []io.Closer + ) + wg := sync.WaitGroup{} + type read struct { + N int + // When the read was allowed. + At time.Time + } + reads := make(chan read) + done := make(chan struct{}) + for range iter.N(numReaders) { + r, w := io.Pipe() + ws = append(ws, w) + cs = append(cs, w) + wg.Add(1) + go func() { + defer wg.Done() + r := rateLimitedReader{ + l: shared, + r: r, + } + b := make([]byte, readSize) + for { + n, err := r.Read(b) + select { + case reads <- read{n, r.lastRead}: + case <-done: + return + } + if err == io.EOF { + return + } + if err != nil { + panic(err) + } + } + }() + } + closeAll := func() { + for _, c := range cs { + c.Close() + } + } + defer func() { + close(done) + closeAll() + wg.Wait() + }() + written := 0 + go func() { + for range iter.N(writeRounds) { + err := writeN(ws, bytesPerRound) + if err != nil { + log.Printf("error writing: %s", err) + break + } + written += bytesPerRound + } + closeAll() + wg.Wait() + close(reads) + }() + totalBytesRead := 0 + started := time.Now() + for r := range reads { + totalBytesRead += r.N + require.False(t, r.At.IsZero()) + // Copy what the reader should have done with its reservation. + res := control.ReserveN(r.At, r.N) + // If we don't have to wait with the control, the reader has gone too + // fast. + if res.Delay() > 0 { + log.Printf("%d bytes not allowed at %s", r.N, time.Since(started)) + t.FailNow() + } + } + assert.EqualValues(t, writeRounds*bytesPerRound, totalBytesRead) +}