Implement smart ban using generics
This commit is contained in:
parent
245c728762
commit
53cf508061
19
client.go
19
client.go
|
@ -4,6 +4,7 @@ import (
|
|||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"expvar"
|
||||
|
@ -27,6 +28,7 @@ import (
|
|||
"github.com/anacrolix/missinggo/v2/bitmap"
|
||||
"github.com/anacrolix/missinggo/v2/pproffd"
|
||||
"github.com/anacrolix/sync"
|
||||
"github.com/anacrolix/torrent/option"
|
||||
request_strategy "github.com/anacrolix/torrent/request-strategy"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/dustin/go-humanize"
|
||||
|
@ -74,6 +76,7 @@ type Client struct {
|
|||
// through legitimate channels.
|
||||
dopplegangerAddrs map[string]struct{}
|
||||
badPeerIPs map[string]struct{}
|
||||
bannedPrefixes map[string]struct{}
|
||||
torrents map[InfoHash]*Torrent
|
||||
pieceRequestOrder map[interface{}]*request_strategy.PieceRequestOrder
|
||||
|
||||
|
@ -210,6 +213,7 @@ func (cl *Client) init(cfg *ClientConfig) {
|
|||
MaxConnsPerHost: 10,
|
||||
},
|
||||
}
|
||||
cl.bannedPrefixes = make(map[banPrefix]struct{})
|
||||
}
|
||||
|
||||
func NewClient(cfg *ClientConfig) (cl *Client, err error) {
|
||||
|
@ -1126,6 +1130,12 @@ func (cl *Client) badPeerAddr(addr PeerRemoteAddr) bool {
|
|||
if ipa, ok := tryIpPortFromNetAddr(addr); ok {
|
||||
return cl.badPeerIPPort(ipa.IP, ipa.Port)
|
||||
}
|
||||
addrStr := addr.String()
|
||||
for prefix := range cl.bannedPrefixes {
|
||||
if strings.HasPrefix(addrStr, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -1185,6 +1195,8 @@ func (cl *Client) newTorrentOpt(opts AddTorrentOpts) (t *Torrent) {
|
|||
webSeeds: make(map[string]*Peer),
|
||||
gotMetainfoC: make(chan struct{}),
|
||||
}
|
||||
t.smartBanCache.Hash = sha1.Sum
|
||||
t.smartBanCache.Init()
|
||||
t.networkingEnabled.Set()
|
||||
t.logger = cl.logger.WithContextValue(t)
|
||||
if opts.ChunkSize == 0 {
|
||||
|
@ -1508,6 +1520,9 @@ func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr PeerRemot
|
|||
connString: connString,
|
||||
conn: nc,
|
||||
}
|
||||
if remoteAddr != nil {
|
||||
c.banPrefix = option.Some(remoteAddr.String())
|
||||
}
|
||||
c.peerImpl = c
|
||||
c.logger = cl.logger.WithDefaultLevel(log.Warning).WithContextValue(c)
|
||||
c.setRW(connStatsReadWriter{nc, c})
|
||||
|
@ -1676,3 +1691,7 @@ func (cl *Client) String() string {
|
|||
func (cl *Client) ConnStats() ConnStats {
|
||||
return cl.stats.Copy()
|
||||
}
|
||||
|
||||
func (cl *Client) banPrefix(prefix banPrefix) {
|
||||
cl.bannedPrefixes[prefix] = struct{}{}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package option
|
||||
|
||||
type T[V any] struct {
|
||||
ok bool
|
||||
value V
|
||||
}
|
||||
|
||||
func (me *T[V]) Ok() bool {
|
||||
return me.ok
|
||||
}
|
||||
|
||||
func (me *T[V]) Value() V {
|
||||
if !me.ok {
|
||||
panic("not set")
|
||||
}
|
||||
return me.value
|
||||
}
|
||||
|
||||
func Some[V any](value V) T[V] {
|
||||
return T[V]{ok: true, value: value}
|
||||
}
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/anacrolix/missinggo/iter"
|
||||
"github.com/anacrolix/missinggo/v2/bitmap"
|
||||
"github.com/anacrolix/multiless"
|
||||
"github.com/anacrolix/torrent/option"
|
||||
|
||||
"github.com/anacrolix/chansync"
|
||||
"github.com/anacrolix/torrent/bencode"
|
||||
|
@ -67,6 +68,7 @@ type Peer struct {
|
|||
outgoing bool
|
||||
Network string
|
||||
RemoteAddr PeerRemoteAddr
|
||||
banPrefix option.T[string]
|
||||
// True if the connection is operating over MSE obfuscation.
|
||||
headerEncrypted bool
|
||||
cryptoMethod mse.CryptoMethod
|
||||
|
@ -1386,6 +1388,11 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
|
|||
|
||||
ppReq := newRequestFromMessage(msg)
|
||||
req := c.t.requestIndexFromRequest(ppReq)
|
||||
t := c.t
|
||||
|
||||
if c.banPrefix.Ok() {
|
||||
t.smartBanCache.RecordBlock(c.banPrefix.Value(), req, msg.Piece)
|
||||
}
|
||||
|
||||
if c.peerChoking {
|
||||
chunksReceived.Add("while choked", 1)
|
||||
|
@ -1425,7 +1432,6 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
|
|||
}
|
||||
}
|
||||
|
||||
t := c.t
|
||||
cl := t.cl
|
||||
|
||||
// Do we actually want this chunk?
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
package torrent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
|
||||
"github.com/anacrolix/torrent/smartban"
|
||||
)
|
||||
|
||||
type banPrefix = string
|
||||
|
||||
type smartBanCache = smartban.Cache[banPrefix, RequestIndex, [sha1.Size]byte]
|
||||
|
||||
type blockCheckingWriter struct {
|
||||
cache *smartBanCache
|
||||
requestIndex RequestIndex
|
||||
// Peers that didn't match blocks written now.
|
||||
badPeers map[banPrefix]struct{}
|
||||
blockBuffer bytes.Buffer
|
||||
chunkSize int
|
||||
}
|
||||
|
||||
func (me *blockCheckingWriter) checkBlock() {
|
||||
b := me.blockBuffer.Next(me.chunkSize)
|
||||
for _, peer := range me.cache.CheckBlock(me.requestIndex, b) {
|
||||
me.badPeers[peer] = struct{}{}
|
||||
}
|
||||
me.requestIndex++
|
||||
}
|
||||
|
||||
func (me *blockCheckingWriter) checkFullBlocks() {
|
||||
for me.blockBuffer.Len() >= me.chunkSize {
|
||||
me.checkBlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (me *blockCheckingWriter) Write(b []byte) (int, error) {
|
||||
n, err := me.blockBuffer.Write(b)
|
||||
if err != nil {
|
||||
// bytes.Buffer.Write should never fail.
|
||||
panic(err)
|
||||
}
|
||||
me.checkFullBlocks()
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Check any remaining block data. Terminal pieces or piece sizes that don't divide into the chunk
|
||||
// size cleanly may leave fragments that should be checked.
|
||||
func (me *blockCheckingWriter) Flush() {
|
||||
for me.blockBuffer.Len() != 0 {
|
||||
me.checkBlock()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package smartban
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Cache[Peer, BlockKey, Hash comparable] struct {
|
||||
Hash func([]byte) Hash
|
||||
|
||||
lock sync.RWMutex
|
||||
blocks map[BlockKey]map[Peer]Hash
|
||||
}
|
||||
|
||||
type Block[Key any] struct {
|
||||
Key Key
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (me *Cache[Peer, BlockKey, Hash]) Init() {
|
||||
me.blocks = make(map[BlockKey]map[Peer]Hash)
|
||||
}
|
||||
|
||||
func (me *Cache[Peer, BlockKey, Hash]) RecordBlock(peer Peer, key BlockKey, data []byte) {
|
||||
hash := me.Hash(data)
|
||||
me.lock.Lock()
|
||||
defer me.lock.Unlock()
|
||||
peers := me.blocks[key]
|
||||
if peers == nil {
|
||||
peers = make(map[Peer]Hash)
|
||||
me.blocks[key] = peers
|
||||
}
|
||||
peers[peer] = hash
|
||||
}
|
||||
|
||||
func (me *Cache[Peer, BlockKey, Hash]) CheckBlock(key BlockKey, data []byte) (bad []Peer) {
|
||||
correct := me.Hash(data)
|
||||
me.lock.RLock()
|
||||
defer me.lock.RUnlock()
|
||||
for peer, hash := range me.blocks[key] {
|
||||
if hash != correct {
|
||||
bad = append(bad, peer)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (me *Cache[Peer, BlockKey, Hash]) ForgetBlock(key BlockKey) {
|
||||
me.lock.Lock()
|
||||
defer me.lock.Unlock()
|
||||
delete(me.blocks, key)
|
||||
}
|
45
torrent.go
45
torrent.go
|
@ -147,6 +147,8 @@ type Torrent struct {
|
|||
|
||||
// Is On when all pieces are complete.
|
||||
Complete chansync.Flag
|
||||
|
||||
smartBanCache smartBanCache
|
||||
}
|
||||
|
||||
func (t *Torrent) selectivePieceAvailabilityFromPeers(i pieceIndex) (count int) {
|
||||
|
@ -939,7 +941,20 @@ func (t *Torrent) pieceLength(piece pieceIndex) pp.Integer {
|
|||
return pp.Integer(t.info.PieceLength)
|
||||
}
|
||||
|
||||
func (t *Torrent) hashPiece(piece pieceIndex) (ret metainfo.Hash, err error) {
|
||||
func (t *Torrent) smartBanBlockCheckingWriter(piece pieceIndex) *blockCheckingWriter {
|
||||
return &blockCheckingWriter{
|
||||
cache: &t.smartBanCache,
|
||||
requestIndex: t.pieceRequestIndexOffset(piece),
|
||||
chunkSize: t.chunkSize.Int(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Torrent) hashPiece(piece pieceIndex) (
|
||||
ret metainfo.Hash,
|
||||
// These are peers that sent us blocks that differ from what we hash here.
|
||||
differingPeers map[banPrefix]struct{},
|
||||
err error,
|
||||
) {
|
||||
p := t.piece(piece)
|
||||
p.waitNoPendingWrites()
|
||||
storagePiece := t.pieces[piece].Storage()
|
||||
|
@ -955,13 +970,18 @@ func (t *Torrent) hashPiece(piece pieceIndex) (ret metainfo.Hash, err error) {
|
|||
|
||||
hash := pieceHash.New()
|
||||
const logPieceContents = false
|
||||
smartBanWriter := t.smartBanBlockCheckingWriter(piece)
|
||||
writers := []io.Writer{hash, smartBanWriter}
|
||||
var examineBuf bytes.Buffer
|
||||
if logPieceContents {
|
||||
var examineBuf bytes.Buffer
|
||||
_, err = storagePiece.WriteTo(io.MultiWriter(hash, &examineBuf))
|
||||
log.Printf("hashed %q with copy err %v", examineBuf.Bytes(), err)
|
||||
} else {
|
||||
_, err = storagePiece.WriteTo(hash)
|
||||
writers = append(writers, &examineBuf)
|
||||
}
|
||||
_, err = storagePiece.WriteTo(io.MultiWriter(writers...))
|
||||
if logPieceContents {
|
||||
log.Printf("hashed %q with copy err %v", examineBuf.Bytes(), err)
|
||||
}
|
||||
smartBanWriter.Flush()
|
||||
differingPeers = smartBanWriter.badPeers
|
||||
missinggo.CopyExact(&ret, hash.Sum(nil))
|
||||
return
|
||||
}
|
||||
|
@ -2106,8 +2126,14 @@ func (t *Torrent) getPieceToHash() (ret pieceIndex, ok bool) {
|
|||
|
||||
func (t *Torrent) pieceHasher(index pieceIndex) {
|
||||
p := t.piece(index)
|
||||
sum, copyErr := t.hashPiece(index)
|
||||
sum, failedPeers, copyErr := t.hashPiece(index)
|
||||
correct := sum == *p.hash
|
||||
if correct {
|
||||
for peer := range failedPeers {
|
||||
log.Printf("would smart ban %q for %v here", peer, p)
|
||||
t.cl.banPrefix(peer)
|
||||
}
|
||||
}
|
||||
switch copyErr {
|
||||
case nil, io.EOF:
|
||||
default:
|
||||
|
@ -2299,8 +2325,9 @@ func (t *Torrent) addWebSeed(url string) {
|
|||
// requests mark more often, so recomputation is probably sooner than with regular peer
|
||||
// conns. ~4x maxRequests would be about right.
|
||||
PeerMaxRequests: 128,
|
||||
RemoteAddr: remoteAddrFromUrl(url),
|
||||
callbacks: t.callbacks(),
|
||||
// TODO: Set ban prefix?
|
||||
RemoteAddr: remoteAddrFromUrl(url),
|
||||
callbacks: t.callbacks(),
|
||||
},
|
||||
client: webseed.Client{
|
||||
HttpClient: t.cl.webseedHttpClient,
|
||||
|
|
Loading…
Reference in New Issue