Implement smart ban using generics

This commit is contained in:
Matt Joiner 2022-01-10 16:19:16 +11:00
parent 245c728762
commit 53cf508061
6 changed files with 187 additions and 10 deletions

View File

@ -4,6 +4,7 @@ import (
"bufio"
"context"
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"errors"
"expvar"
@ -27,6 +28,7 @@ import (
"github.com/anacrolix/missinggo/v2/bitmap"
"github.com/anacrolix/missinggo/v2/pproffd"
"github.com/anacrolix/sync"
"github.com/anacrolix/torrent/option"
request_strategy "github.com/anacrolix/torrent/request-strategy"
"github.com/davecgh/go-spew/spew"
"github.com/dustin/go-humanize"
@ -74,6 +76,7 @@ type Client struct {
// through legitimate channels.
dopplegangerAddrs map[string]struct{}
badPeerIPs map[string]struct{}
bannedPrefixes map[string]struct{}
torrents map[InfoHash]*Torrent
pieceRequestOrder map[interface{}]*request_strategy.PieceRequestOrder
@ -210,6 +213,7 @@ func (cl *Client) init(cfg *ClientConfig) {
MaxConnsPerHost: 10,
},
}
cl.bannedPrefixes = make(map[banPrefix]struct{})
}
func NewClient(cfg *ClientConfig) (cl *Client, err error) {
@ -1126,6 +1130,12 @@ func (cl *Client) badPeerAddr(addr PeerRemoteAddr) bool {
if ipa, ok := tryIpPortFromNetAddr(addr); ok {
return cl.badPeerIPPort(ipa.IP, ipa.Port)
}
addrStr := addr.String()
for prefix := range cl.bannedPrefixes {
if strings.HasPrefix(addrStr, prefix) {
return true
}
}
return false
}
@ -1185,6 +1195,8 @@ func (cl *Client) newTorrentOpt(opts AddTorrentOpts) (t *Torrent) {
webSeeds: make(map[string]*Peer),
gotMetainfoC: make(chan struct{}),
}
t.smartBanCache.Hash = sha1.Sum
t.smartBanCache.Init()
t.networkingEnabled.Set()
t.logger = cl.logger.WithContextValue(t)
if opts.ChunkSize == 0 {
@ -1508,6 +1520,9 @@ func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr PeerRemot
connString: connString,
conn: nc,
}
if remoteAddr != nil {
c.banPrefix = option.Some(remoteAddr.String())
}
c.peerImpl = c
c.logger = cl.logger.WithDefaultLevel(log.Warning).WithContextValue(c)
c.setRW(connStatsReadWriter{nc, c})
@ -1676,3 +1691,7 @@ func (cl *Client) String() string {
func (cl *Client) ConnStats() ConnStats {
return cl.stats.Copy()
}
func (cl *Client) banPrefix(prefix banPrefix) {
cl.bannedPrefixes[prefix] = struct{}{}
}

21
option/option.go Normal file
View File

@ -0,0 +1,21 @@
package option
type T[V any] struct {
ok bool
value V
}
func (me *T[V]) Ok() bool {
return me.ok
}
func (me *T[V]) Value() V {
if !me.ok {
panic("not set")
}
return me.value
}
func Some[V any](value V) T[V] {
return T[V]{ok: true, value: value}
}

View File

@ -19,6 +19,7 @@ import (
"github.com/anacrolix/missinggo/iter"
"github.com/anacrolix/missinggo/v2/bitmap"
"github.com/anacrolix/multiless"
"github.com/anacrolix/torrent/option"
"github.com/anacrolix/chansync"
"github.com/anacrolix/torrent/bencode"
@ -67,6 +68,7 @@ type Peer struct {
outgoing bool
Network string
RemoteAddr PeerRemoteAddr
banPrefix option.T[string]
// True if the connection is operating over MSE obfuscation.
headerEncrypted bool
cryptoMethod mse.CryptoMethod
@ -1386,6 +1388,11 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
ppReq := newRequestFromMessage(msg)
req := c.t.requestIndexFromRequest(ppReq)
t := c.t
if c.banPrefix.Ok() {
t.smartBanCache.RecordBlock(c.banPrefix.Value(), req, msg.Piece)
}
if c.peerChoking {
chunksReceived.Add("while choked", 1)
@ -1425,7 +1432,6 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
}
}
t := c.t
cl := t.cl
// Do we actually want this chunk?

53
smartban.go Normal file
View File

@ -0,0 +1,53 @@
package torrent
import (
"bytes"
"crypto/sha1"
"github.com/anacrolix/torrent/smartban"
)
type banPrefix = string
type smartBanCache = smartban.Cache[banPrefix, RequestIndex, [sha1.Size]byte]
type blockCheckingWriter struct {
cache *smartBanCache
requestIndex RequestIndex
// Peers that didn't match blocks written now.
badPeers map[banPrefix]struct{}
blockBuffer bytes.Buffer
chunkSize int
}
func (me *blockCheckingWriter) checkBlock() {
b := me.blockBuffer.Next(me.chunkSize)
for _, peer := range me.cache.CheckBlock(me.requestIndex, b) {
me.badPeers[peer] = struct{}{}
}
me.requestIndex++
}
func (me *blockCheckingWriter) checkFullBlocks() {
for me.blockBuffer.Len() >= me.chunkSize {
me.checkBlock()
}
}
func (me *blockCheckingWriter) Write(b []byte) (int, error) {
n, err := me.blockBuffer.Write(b)
if err != nil {
// bytes.Buffer.Write should never fail.
panic(err)
}
me.checkFullBlocks()
return n, err
}
// Check any remaining block data. Terminal pieces or piece sizes that don't divide into the chunk
// size cleanly may leave fragments that should be checked.
func (me *blockCheckingWriter) Flush() {
for me.blockBuffer.Len() != 0 {
me.checkBlock()
}
}

51
smartban/smartban.go Normal file
View File

@ -0,0 +1,51 @@
package smartban
import (
"sync"
)
type Cache[Peer, BlockKey, Hash comparable] struct {
Hash func([]byte) Hash
lock sync.RWMutex
blocks map[BlockKey]map[Peer]Hash
}
type Block[Key any] struct {
Key Key
Data []byte
}
func (me *Cache[Peer, BlockKey, Hash]) Init() {
me.blocks = make(map[BlockKey]map[Peer]Hash)
}
func (me *Cache[Peer, BlockKey, Hash]) RecordBlock(peer Peer, key BlockKey, data []byte) {
hash := me.Hash(data)
me.lock.Lock()
defer me.lock.Unlock()
peers := me.blocks[key]
if peers == nil {
peers = make(map[Peer]Hash)
me.blocks[key] = peers
}
peers[peer] = hash
}
func (me *Cache[Peer, BlockKey, Hash]) CheckBlock(key BlockKey, data []byte) (bad []Peer) {
correct := me.Hash(data)
me.lock.RLock()
defer me.lock.RUnlock()
for peer, hash := range me.blocks[key] {
if hash != correct {
bad = append(bad, peer)
}
}
return
}
func (me *Cache[Peer, BlockKey, Hash]) ForgetBlock(key BlockKey) {
me.lock.Lock()
defer me.lock.Unlock()
delete(me.blocks, key)
}

View File

@ -147,6 +147,8 @@ type Torrent struct {
// Is On when all pieces are complete.
Complete chansync.Flag
smartBanCache smartBanCache
}
func (t *Torrent) selectivePieceAvailabilityFromPeers(i pieceIndex) (count int) {
@ -939,7 +941,20 @@ func (t *Torrent) pieceLength(piece pieceIndex) pp.Integer {
return pp.Integer(t.info.PieceLength)
}
func (t *Torrent) hashPiece(piece pieceIndex) (ret metainfo.Hash, err error) {
func (t *Torrent) smartBanBlockCheckingWriter(piece pieceIndex) *blockCheckingWriter {
return &blockCheckingWriter{
cache: &t.smartBanCache,
requestIndex: t.pieceRequestIndexOffset(piece),
chunkSize: t.chunkSize.Int(),
}
}
func (t *Torrent) hashPiece(piece pieceIndex) (
ret metainfo.Hash,
// These are peers that sent us blocks that differ from what we hash here.
differingPeers map[banPrefix]struct{},
err error,
) {
p := t.piece(piece)
p.waitNoPendingWrites()
storagePiece := t.pieces[piece].Storage()
@ -955,13 +970,18 @@ func (t *Torrent) hashPiece(piece pieceIndex) (ret metainfo.Hash, err error) {
hash := pieceHash.New()
const logPieceContents = false
if logPieceContents {
smartBanWriter := t.smartBanBlockCheckingWriter(piece)
writers := []io.Writer{hash, smartBanWriter}
var examineBuf bytes.Buffer
_, err = storagePiece.WriteTo(io.MultiWriter(hash, &examineBuf))
log.Printf("hashed %q with copy err %v", examineBuf.Bytes(), err)
} else {
_, err = storagePiece.WriteTo(hash)
if logPieceContents {
writers = append(writers, &examineBuf)
}
_, err = storagePiece.WriteTo(io.MultiWriter(writers...))
if logPieceContents {
log.Printf("hashed %q with copy err %v", examineBuf.Bytes(), err)
}
smartBanWriter.Flush()
differingPeers = smartBanWriter.badPeers
missinggo.CopyExact(&ret, hash.Sum(nil))
return
}
@ -2106,8 +2126,14 @@ func (t *Torrent) getPieceToHash() (ret pieceIndex, ok bool) {
func (t *Torrent) pieceHasher(index pieceIndex) {
p := t.piece(index)
sum, copyErr := t.hashPiece(index)
sum, failedPeers, copyErr := t.hashPiece(index)
correct := sum == *p.hash
if correct {
for peer := range failedPeers {
log.Printf("would smart ban %q for %v here", peer, p)
t.cl.banPrefix(peer)
}
}
switch copyErr {
case nil, io.EOF:
default:
@ -2299,6 +2325,7 @@ func (t *Torrent) addWebSeed(url string) {
// requests mark more often, so recomputation is probably sooner than with regular peer
// conns. ~4x maxRequests would be about right.
PeerMaxRequests: 128,
// TODO: Set ban prefix?
RemoteAddr: remoteAddrFromUrl(url),
callbacks: t.callbacks(),
},