Add the ReceiveEncryptedHandshakeSkeys callback
This commit is contained in:
parent
15c8846153
commit
131037dd9f
13
callbacks.go
13
callbacks.go
|
@ -1,14 +1,21 @@
|
|||
package torrent
|
||||
|
||||
import (
|
||||
"github.com/anacrolix/torrent/mse"
|
||||
pp "github.com/anacrolix/torrent/peer_protocol"
|
||||
)
|
||||
|
||||
// These are called synchronously, and do not pass ownership. The Client and other locks may still
|
||||
// be held. nil functions are not called.
|
||||
// These are called synchronously, and do not pass ownership of arguments (do not expect to retain
|
||||
// data after returning from the callback). The Client and other locks may still be held. nil
|
||||
// functions are not called.
|
||||
type Callbacks struct {
|
||||
CompletedHandshake func(_ *PeerConn, infoHash InfoHash)
|
||||
// Called after a peer connection completes the BitTorrent handshake. The Client lock is not
|
||||
// held.
|
||||
CompletedHandshake func(*PeerConn, InfoHash)
|
||||
ReadMessage func(*PeerConn, *pp.Message)
|
||||
ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage)
|
||||
PeerConnClosed func(*PeerConn)
|
||||
|
||||
// Provides secret keys to be tried against incoming encrypted connections.
|
||||
ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter
|
||||
}
|
||||
|
|
18
client.go
18
client.go
|
@ -798,10 +798,11 @@ func (cl *Client) initiateHandshakes(c *PeerConn, t *Torrent) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Calls f with any secret keys.
|
||||
// Calls f with any secret keys. Note that it takes the Client lock, and so must be used from code
|
||||
// that won't also try to take the lock. This saves us copying all the infohashes everytime.
|
||||
func (cl *Client) forSkeys(f func([]byte) bool) {
|
||||
cl.lock()
|
||||
defer cl.unlock()
|
||||
cl.rLock()
|
||||
defer cl.rUnlock()
|
||||
if false { // Emulate the bug from #114
|
||||
var firstIh InfoHash
|
||||
for ih := range cl.torrents {
|
||||
|
@ -822,11 +823,18 @@ func (cl *Client) forSkeys(f func([]byte) bool) {
|
|||
}
|
||||
}
|
||||
|
||||
func (cl *Client) handshakeReceiverSecretKeys() mse.SecretKeyIter {
|
||||
if ret := cl.config.Callbacks.ReceiveEncryptedHandshakeSkeys; ret != nil {
|
||||
return ret
|
||||
}
|
||||
return cl.forSkeys
|
||||
}
|
||||
|
||||
// Do encryption and bittorrent handshakes as receiver.
|
||||
func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
|
||||
defer perf.ScopeTimerErr(&err)()
|
||||
var rw io.ReadWriter
|
||||
rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.forSkeys, cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector)
|
||||
rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.handshakeReceiverSecretKeys(), cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector)
|
||||
c.setRW(rw)
|
||||
if err == nil || err == mse.ErrNoSecretKeyMatch {
|
||||
if c.headerEncrypted {
|
||||
|
@ -844,7 +852,7 @@ func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
|
|||
return
|
||||
}
|
||||
if cl.config.HeaderObfuscationPolicy.RequirePreferred && c.headerEncrypted != cl.config.HeaderObfuscationPolicy.Preferred {
|
||||
err = errors.New("connection not have required header obfuscation")
|
||||
err = errors.New("connection does not have required header obfuscation")
|
||||
return
|
||||
}
|
||||
ih, err := cl.connBtHandshake(c, nil)
|
||||
|
|
|
@ -27,6 +27,7 @@ func (r deadlineReader) Read(b []byte) (int, error) {
|
|||
return r.r.Read(b)
|
||||
}
|
||||
|
||||
// Handles stream encryption for inbound connections.
|
||||
func handleEncryption(
|
||||
rw io.ReadWriter,
|
||||
skeys mse.SecretKeyIter,
|
||||
|
@ -38,12 +39,14 @@ func handleEncryption(
|
|||
cryptoMethod mse.CryptoMethod,
|
||||
err error,
|
||||
) {
|
||||
// Tries to start an unencrypted stream.
|
||||
if !policy.RequirePreferred || !policy.Preferred {
|
||||
var protocol [len(pp.Protocol)]byte
|
||||
_, err = io.ReadFull(rw, protocol[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Put the protocol back into the stream.
|
||||
rw = struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
|
@ -56,6 +59,7 @@ func handleEncryption(
|
|||
return
|
||||
}
|
||||
if policy.RequirePreferred {
|
||||
// We are here because we require unencrypted connections.
|
||||
err = fmt.Errorf("unexpected protocol string %q and header obfuscation disabled", protocol)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue