diff --git a/client.go b/client.go
index 7586fa80..2a208fc9 100644
--- a/client.go
+++ b/client.go
@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"context"
 	"crypto/rand"
+	"crypto/sha1"
 	"encoding/binary"
 	"errors"
 	"expvar"
@@ -27,6 +28,7 @@ import (
 	"github.com/anacrolix/missinggo/v2/bitmap"
 	"github.com/anacrolix/missinggo/v2/pproffd"
 	"github.com/anacrolix/sync"
+	"github.com/anacrolix/torrent/option"
 	request_strategy "github.com/anacrolix/torrent/request-strategy"
 	"github.com/davecgh/go-spew/spew"
 	"github.com/dustin/go-humanize"
@@ -74,6 +76,7 @@ type Client struct {
 	// through legitimate channels.
 	dopplegangerAddrs map[string]struct{}
 	badPeerIPs        map[string]struct{}
+	bannedPrefixes    map[banPrefix]struct{}
 	torrents          map[InfoHash]*Torrent
 	pieceRequestOrder map[interface{}]*request_strategy.PieceRequestOrder
 
@@ -210,6 +213,7 @@ func (cl *Client) init(cfg *ClientConfig) {
 			MaxConnsPerHost: 10,
 		},
 	}
+	cl.bannedPrefixes = make(map[banPrefix]struct{})
 }
 
 func NewClient(cfg *ClientConfig) (cl *Client, err error) {
@@ -1126,6 +1130,14 @@ func (cl *Client) badPeerAddr(addr PeerRemoteAddr) bool {
+	// Consult the banned prefixes before the IP:port path below. That path returns
+	// unconditionally, so checking afterwards would never apply bans to ordinary IP peers.
+	addrStr := addr.String()
+	for prefix := range cl.bannedPrefixes {
+		if strings.HasPrefix(addrStr, prefix) {
+			return true
+		}
+	}
 	if ipa, ok := tryIpPortFromNetAddr(addr); ok {
 		return cl.badPeerIPPort(ipa.IP, ipa.Port)
 	}
 	return false
 }
 
@@ -1185,6 +1197,8 @@ func (cl *Client) newTorrentOpt(opts AddTorrentOpts) (t *Torrent) {
 		webSeeds:     make(map[string]*Peer),
 		gotMetainfoC: make(chan struct{}),
 	}
+	t.smartBanCache.Hash = sha1.Sum
+	t.smartBanCache.Init()
 	t.networkingEnabled.Set()
 	t.logger = cl.logger.WithContextValue(t)
 	if opts.ChunkSize == 0 {
@@ -1508,6 +1522,9 @@ func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr PeerRemot
 		connString: connString,
 		conn:       nc,
 	}
+	if remoteAddr != nil {
+		c.banPrefix = option.Some(remoteAddr.String())
+	}
 	c.peerImpl = c
 	c.logger = cl.logger.WithDefaultLevel(log.Warning).WithContextValue(c)
 	c.setRW(connStatsReadWriter{nc, c})
@@ -1676,3 +1693,10 @@ func (cl *Client) String() string {
 func (cl *Client) ConnStats() ConnStats {
 	return cl.stats.Copy()
 }
+
+// banPrefix records prefix so that badPeerAddr rejects peer addresses whose string form starts
+// with it. It writes to a plain map without locking; presumably the caller must hold the client
+// lock. TODO confirm: the call in Torrent.pieceHasher does not visibly take it.
+func (cl *Client) banPrefix(prefix banPrefix) {
+	cl.bannedPrefixes[prefix] = struct{}{}
+}
diff --git a/option/option.go b/option/option.go
new file mode 100644
index 00000000..f53d5bb5
--- /dev/null
+++ b/option/option.go
@@ -0,0 +1,26 @@
+// Package option provides a minimal optional-value type.
+package option
+
+// T holds a value of type V that may or may not be set. The zero T is unset.
+type T[V any] struct {
+	ok    bool
+	value V
+}
+
+// Ok reports whether a value is set.
+func (me T[V]) Ok() bool {
+	return me.ok
+}
+
+// Value returns the held value. It panics if no value is set, so check Ok first.
+func (me T[V]) Value() V {
+	if !me.ok {
+		panic("not set")
+	}
+	return me.value
+}
+
+// Some returns a T holding value.
+func Some[V any](value V) T[V] {
+	return T[V]{ok: true, value: value}
+}
diff --git a/peerconn.go b/peerconn.go
index 164cdc9c..566553f9 100644
--- a/peerconn.go
+++ b/peerconn.go
@@ -19,6 +19,7 @@ import (
 	"github.com/anacrolix/missinggo/iter"
 	"github.com/anacrolix/missinggo/v2/bitmap"
 	"github.com/anacrolix/multiless"
+	"github.com/anacrolix/torrent/option"
 
 	"github.com/anacrolix/chansync"
 	"github.com/anacrolix/torrent/bencode"
@@ -67,6 +68,7 @@ type Peer struct {
 	outgoing   bool
 	Network    string
 	RemoteAddr PeerRemoteAddr
+	banPrefix  option.T[string]
 	// True if the connection is operating over MSE obfuscation.
 	headerEncrypted bool
 	cryptoMethod    mse.CryptoMethod
@@ -1386,6 +1388,11 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
 
 	ppReq := newRequestFromMessage(msg)
 	req := c.t.requestIndexFromRequest(ppReq)
+	t := c.t
+
+	if c.banPrefix.Ok() {
+		t.smartBanCache.RecordBlock(c.banPrefix.Value(), req, msg.Piece)
+	}
 
 	if c.peerChoking {
 		chunksReceived.Add("while choked", 1)
@@ -1425,7 +1432,6 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
 		}
 	}
 
-	t := c.t
 	cl := t.cl
 
 	// Do we actually want this chunk?
diff --git a/smartban.go b/smartban.go
new file mode 100644
index 00000000..9f43104f
--- /dev/null
+++ b/smartban.go
@@ -0,0 +1,66 @@
+package torrent
+
+import (
+	"bytes"
+	"crypto/sha1"
+
+	"github.com/anacrolix/torrent/smartban"
+)
+
+// banPrefix is the string under which a peer's blocks are recorded and by which it is banned.
+type banPrefix = string
+
+// smartBanCache stores, for each block keyed by request index, a [sha1.Size]byte hash per peer.
+type smartBanCache = smartban.Cache[banPrefix, RequestIndex, [sha1.Size]byte]
+
+// blockCheckingWriter cuts the bytes written to it into chunkSize blocks and asks cache which
+// peers recorded a different hash for each block.
+type blockCheckingWriter struct {
+	cache *smartBanCache
+	// Key of the next block to check; incremented after every check.
+	requestIndex RequestIndex
+	// Peers that didn't match blocks written now. Allocated lazily, so nil means none differed.
+	badPeers    map[banPrefix]struct{}
+	blockBuffer bytes.Buffer
+	chunkSize   int
+}
+
+// checkBlock takes up to chunkSize buffered bytes as the block at requestIndex, adds every peer
+// the cache reports as differing to badPeers, and advances requestIndex.
+func (me *blockCheckingWriter) checkBlock() {
+	b := me.blockBuffer.Next(me.chunkSize)
+	for _, peer := range me.cache.CheckBlock(me.requestIndex, b) {
+		// Writers are built without a badPeers map, and assigning into a nil map panics.
+		if me.badPeers == nil {
+			me.badPeers = make(map[banPrefix]struct{})
+		}
+		me.badPeers[peer] = struct{}{}
+	}
+	me.requestIndex++
+}
+
+// checkFullBlocks checks buffered data for as long as at least one whole block is available.
+func (me *blockCheckingWriter) checkFullBlocks() {
+	for me.blockBuffer.Len() >= me.chunkSize {
+		me.checkBlock()
+	}
+}
+
+// Write buffers b and checks every complete block accumulated so far. It never returns an error.
+func (me *blockCheckingWriter) Write(b []byte) (int, error) {
+	n, err := me.blockBuffer.Write(b)
+	if err != nil {
+		// bytes.Buffer.Write should never fail.
+		panic(err)
+	}
+	me.checkFullBlocks()
+	return n, err
+}
+
+// Check any remaining block data. Terminal pieces or piece sizes that don't divide into the chunk
+// size cleanly may leave fragments that should be checked.
+func (me *blockCheckingWriter) Flush() {
+	for me.blockBuffer.Len() != 0 {
+		me.checkBlock()
+	}
+}
diff --git a/smartban/smartban.go b/smartban/smartban.go
new file mode 100644
index 00000000..96e9b759
--- /dev/null
+++ b/smartban/smartban.go
@@ -0,0 +1,70 @@
+// Package smartban records a hash of each block received from each peer, so that peers whose
+// data differs from the block that was eventually verified can be identified.
+package smartban
+
+import (
+	"sync"
+)
+
+// Cache maps each block key to the hash every peer supplied for that block. RecordBlock,
+// CheckBlock and ForgetBlock guard the stored hashes with an internal lock. Hash must be set
+// before RecordBlock or CheckBlock is called.
+type Cache[Peer, BlockKey, Hash comparable] struct {
+	// Hash computes the value stored and compared for a block's data.
+	Hash func([]byte) Hash
+
+	lock   sync.RWMutex
+	blocks map[BlockKey]map[Peer]Hash
+}
+
+// Block pairs a block key with the block's data.
+type Block[Key any] struct {
+	Key  Key
+	Data []byte
+}
+
+// Init allocates the cache's storage, dropping anything recorded so far. It does not take the
+// lock.
+func (me *Cache[Peer, BlockKey, Hash]) Init() {
+	me.blocks = make(map[BlockKey]map[Peer]Hash)
+}
+
+// RecordBlock stores the hash of data as what peer sent for key, replacing any hash already
+// recorded for that peer and key.
+func (me *Cache[Peer, BlockKey, Hash]) RecordBlock(peer Peer, key BlockKey, data []byte) {
+	hash := me.Hash(data)
+	me.lock.Lock()
+	defer me.lock.Unlock()
+	// Allocate on demand so that a Cache on which Init was never called doesn't panic on a nil
+	// map write.
+	if me.blocks == nil {
+		me.blocks = make(map[BlockKey]map[Peer]Hash)
+	}
+	peers := me.blocks[key]
+	if peers == nil {
+		peers = make(map[Peer]Hash)
+		me.blocks[key] = peers
+	}
+	peers[peer] = hash
+}
+
+// CheckBlock hashes data, taken to be the correct contents of the block at key, and returns the
+// peers whose recorded hash for that block differs. The order of the result is unspecified.
+func (me *Cache[Peer, BlockKey, Hash]) CheckBlock(key BlockKey, data []byte) (bad []Peer) {
+	correct := me.Hash(data)
+	me.lock.RLock()
+	defer me.lock.RUnlock()
+	for peer, hash := range me.blocks[key] {
+		if hash != correct {
+			bad = append(bad, peer)
+		}
+	}
+	return
+}
+
+// ForgetBlock discards everything recorded for key.
+func (me *Cache[Peer, BlockKey, Hash]) ForgetBlock(key BlockKey) {
+	me.lock.Lock()
+	defer me.lock.Unlock()
+	delete(me.blocks, key)
+}
diff --git a/torrent.go b/torrent.go
index 3e74df63..02e7c399 100644
--- a/torrent.go
+++ b/torrent.go
@@ -147,6 +147,8 @@ type Torrent struct {
 
 	// Is On when all pieces are complete.
 	Complete chansync.Flag
+
+	smartBanCache smartBanCache
 }
 
 func (t *Torrent) selectivePieceAvailabilityFromPeers(i pieceIndex) (count int) {
@@ -939,7 +941,24 @@ func (t *Torrent) pieceLength(piece pieceIndex) pp.Integer {
 	return pp.Integer(t.info.PieceLength)
 }
 
-func (t *Torrent) hashPiece(piece pieceIndex) (ret metainfo.Hash, err error) {
+// smartBanBlockCheckingWriter returns a writer that checks the data of the given piece, block by
+// block, against the hashes recorded in the torrent's smart ban cache.
+func (t *Torrent) smartBanBlockCheckingWriter(piece pieceIndex) *blockCheckingWriter {
+	return &blockCheckingWriter{
+		cache:        &t.smartBanCache,
+		requestIndex: t.pieceRequestIndexOffset(piece),
+		chunkSize:    t.chunkSize.Int(),
+	}
+}
+
+// hashPiece streams the piece from storage through the piece hasher and the smart ban checker,
+// returning the computed hash along with the peers whose recorded blocks differ from storage.
+func (t *Torrent) hashPiece(piece pieceIndex) (
+	ret metainfo.Hash,
+	// These are peers that sent us blocks that differ from what we hash here.
+	differingPeers map[banPrefix]struct{},
+	err error,
+) {
 	p := t.piece(piece)
 	p.waitNoPendingWrites()
 	storagePiece := t.pieces[piece].Storage()
@@ -955,13 +974,18 @@ func (t *Torrent) hashPiece(piece pieceIndex) (ret metainfo.Hash, err error) {
 
 	hash := pieceHash.New()
 	const logPieceContents = false
+	smartBanWriter := t.smartBanBlockCheckingWriter(piece)
+	writers := []io.Writer{hash, smartBanWriter}
+	var examineBuf bytes.Buffer
 	if logPieceContents {
-		var examineBuf bytes.Buffer
-		_, err = storagePiece.WriteTo(io.MultiWriter(hash, &examineBuf))
-		log.Printf("hashed %q with copy err %v", examineBuf.Bytes(), err)
-	} else {
-		_, err = storagePiece.WriteTo(hash)
+		writers = append(writers, &examineBuf)
 	}
+	_, err = storagePiece.WriteTo(io.MultiWriter(writers...))
+	if logPieceContents {
+		log.Printf("hashed %q with copy err %v", examineBuf.Bytes(), err)
+	}
+	smartBanWriter.Flush()
+	differingPeers = smartBanWriter.badPeers
 	missinggo.CopyExact(&ret, hash.Sum(nil))
 	return
 }
@@ -2106,8 +2130,22 @@ func (t *Torrent) getPieceToHash() (ret pieceIndex, ok bool) {
 
 func (t *Torrent) pieceHasher(index pieceIndex) {
 	p := t.piece(index)
-	sum, copyErr := t.hashPiece(index)
+	sum, failedPeers, copyErr := t.hashPiece(index)
 	correct := sum == *p.hash
+	if correct {
+		for peer := range failedPeers {
+			log.Printf("would smart ban %q for %v here", peer, p)
+			t.cl.banPrefix(peer)
+		}
+		// The piece is verified, so the per-peer block hashes recorded for it are no longer
+		// needed. Without this the cache keeps an entry for every block ever received.
+		firstBlock := t.pieceRequestIndexOffset(index)
+		chunkSize := t.chunkSize.Int()
+		numBlocks := (int(t.pieceLength(index)) + chunkSize - 1) / chunkSize
+		for i := 0; i < numBlocks; i++ {
+			t.smartBanCache.ForgetBlock(firstBlock + RequestIndex(i))
+		}
+	}
 	switch copyErr {
 	case nil, io.EOF:
 	default:
@@ -2299,8 +2337,9 @@ func (t *Torrent) addWebSeed(url string) {
 			// requests mark more often, so recomputation is probably sooner than with regular peer
 			// conns. ~4x maxRequests would be about right.
 			PeerMaxRequests: 128,
-			RemoteAddr:      remoteAddrFromUrl(url),
-			callbacks:       t.callbacks(),
+			// TODO: Set ban prefix?
+			RemoteAddr: remoteAddrFromUrl(url),
+			callbacks:  t.callbacks(),
 		},
 		client: webseed.Client{
 			HttpClient: t.cl.webseedHttpClient,