Add the ReceiveEncryptedHandshakeSkeys callback

This commit is contained in:
Matt Joiner 2020-11-05 13:28:45 +11:00
parent 15c8846153
commit 131037dd9f
3 changed files with 27 additions and 8 deletions

View File

@ -1,14 +1,21 @@
package torrent package torrent
import ( import (
"github.com/anacrolix/torrent/mse"
pp "github.com/anacrolix/torrent/peer_protocol" pp "github.com/anacrolix/torrent/peer_protocol"
) )
// These are called synchronously, and do not pass ownership. The Client and other locks may still // These are called synchronously, and do not pass ownership of arguments (do not expect to retain
// be held. nil functions are not called. // data after returning from the callback). The Client and other locks may still be held. nil
// functions are not called.
type Callbacks struct { type Callbacks struct {
CompletedHandshake func(_ *PeerConn, infoHash InfoHash) // Called after a peer connection completes the BitTorrent handshake. The Client lock is not
// held.
CompletedHandshake func(*PeerConn, InfoHash)
ReadMessage func(*PeerConn, *pp.Message) ReadMessage func(*PeerConn, *pp.Message)
ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage) ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage)
PeerConnClosed func(*PeerConn) PeerConnClosed func(*PeerConn)
// Provides secret keys to be tried against incoming encrypted connections.
ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter
} }

View File

@ -798,10 +798,11 @@ func (cl *Client) initiateHandshakes(c *PeerConn, t *Torrent) error {
return nil return nil
} }
// Calls f with any secret keys. // Calls f with any secret keys. Note that it takes the Client lock, and so must be used from code
// that won't also try to take the lock. This saves us copying all the infohashes everytime.
func (cl *Client) forSkeys(f func([]byte) bool) { func (cl *Client) forSkeys(f func([]byte) bool) {
cl.lock() cl.rLock()
defer cl.unlock() defer cl.rUnlock()
if false { // Emulate the bug from #114 if false { // Emulate the bug from #114
var firstIh InfoHash var firstIh InfoHash
for ih := range cl.torrents { for ih := range cl.torrents {
@ -822,11 +823,18 @@ func (cl *Client) forSkeys(f func([]byte) bool) {
} }
} }
func (cl *Client) handshakeReceiverSecretKeys() mse.SecretKeyIter {
if ret := cl.config.Callbacks.ReceiveEncryptedHandshakeSkeys; ret != nil {
return ret
}
return cl.forSkeys
}
// Do encryption and bittorrent handshakes as receiver. // Do encryption and bittorrent handshakes as receiver.
func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) { func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
defer perf.ScopeTimerErr(&err)() defer perf.ScopeTimerErr(&err)()
var rw io.ReadWriter var rw io.ReadWriter
rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.forSkeys, cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector) rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.handshakeReceiverSecretKeys(), cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector)
c.setRW(rw) c.setRW(rw)
if err == nil || err == mse.ErrNoSecretKeyMatch { if err == nil || err == mse.ErrNoSecretKeyMatch {
if c.headerEncrypted { if c.headerEncrypted {
@ -844,7 +852,7 @@ func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
return return
} }
if cl.config.HeaderObfuscationPolicy.RequirePreferred && c.headerEncrypted != cl.config.HeaderObfuscationPolicy.Preferred { if cl.config.HeaderObfuscationPolicy.RequirePreferred && c.headerEncrypted != cl.config.HeaderObfuscationPolicy.Preferred {
err = errors.New("connection not have required header obfuscation") err = errors.New("connection does not have required header obfuscation")
return return
} }
ih, err := cl.connBtHandshake(c, nil) ih, err := cl.connBtHandshake(c, nil)

View File

@ -27,6 +27,7 @@ func (r deadlineReader) Read(b []byte) (int, error) {
return r.r.Read(b) return r.r.Read(b)
} }
// Handles stream encryption for inbound connections.
func handleEncryption( func handleEncryption(
rw io.ReadWriter, rw io.ReadWriter,
skeys mse.SecretKeyIter, skeys mse.SecretKeyIter,
@ -38,12 +39,14 @@ func handleEncryption(
cryptoMethod mse.CryptoMethod, cryptoMethod mse.CryptoMethod,
err error, err error,
) { ) {
// Tries to start an unencrypted stream.
if !policy.RequirePreferred || !policy.Preferred { if !policy.RequirePreferred || !policy.Preferred {
var protocol [len(pp.Protocol)]byte var protocol [len(pp.Protocol)]byte
_, err = io.ReadFull(rw, protocol[:]) _, err = io.ReadFull(rw, protocol[:])
if err != nil { if err != nil {
return return
} }
// Put the protocol back into the stream.
rw = struct { rw = struct {
io.Reader io.Reader
io.Writer io.Writer
@ -56,6 +59,7 @@ func handleEncryption(
return return
} }
if policy.RequirePreferred { if policy.RequirePreferred {
// We are here because we require unencrypted connections.
err = fmt.Errorf("unexpected protocol string %q and header obfuscation disabled", protocol) err = fmt.Errorf("unexpected protocol string %q and header obfuscation disabled", protocol)
return return
} }