Set the connection.cryptoMethod
It was unwittingly dropped from received connections, and may never have been set for initiated connections.
This commit is contained in:
parent
b92e8b7814
commit
fc03dcb859
|
@ -748,7 +748,7 @@ func (cl *Client) incomingPeerPort() int {
|
|||
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
|
||||
if c.headerEncrypted {
|
||||
var rw io.ReadWriter
|
||||
rw, err = mse.InitiateHandshake(
|
||||
rw, c.cryptoMethod, err = mse.InitiateHandshake(
|
||||
struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
|
|
|
@ -187,7 +187,7 @@ func handleEncryption(
|
|||
}
|
||||
}
|
||||
headerEncrypted = true
|
||||
ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
|
||||
ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
|
||||
switch {
|
||||
case policy.ForceEncryption:
|
||||
return mse.CryptoMethodRC4
|
||||
|
|
19
mse/mse.go
19
mse/mse.go
|
@ -367,7 +367,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
|
|||
return newEncrypt(initer, h.s[:], h.skey)
|
||||
}
|
||||
|
||||
func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
||||
func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
|
||||
h.postWrite(hash(req1, h.s[:]))
|
||||
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
|
||||
buf := &bytes.Buffer{}
|
||||
|
@ -409,7 +409,8 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch method & h.cryptoProvides {
|
||||
selected = method & h.cryptoProvides
|
||||
switch selected {
|
||||
case CryptoMethodRC4:
|
||||
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
|
||||
case CryptoMethodPlaintext:
|
||||
|
@ -422,7 +423,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
|||
|
||||
var ErrNoSecretKeyMatch = errors.New("no skey matched")
|
||||
|
||||
func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
||||
func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
|
||||
// There is up to 512 bytes of padding, then the 20 byte hash.
|
||||
err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
|
||||
if err != nil {
|
||||
|
@ -460,7 +461,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
|||
return
|
||||
}
|
||||
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
|
||||
chosen := h.chooseMethod(provides)
|
||||
chosen = h.chooseMethod(provides)
|
||||
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -499,7 +500,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (h *handshake) Do() (ret io.ReadWriter, err error) {
|
||||
func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
|
||||
h.writeCond.L = &h.writeMu
|
||||
h.writerCond.L = &h.writerMu
|
||||
go h.writer()
|
||||
|
@ -521,14 +522,14 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
|
|||
return
|
||||
}
|
||||
if h.initer {
|
||||
ret, err = h.initerSteps()
|
||||
ret, method, err = h.initerSteps()
|
||||
} else {
|
||||
ret, err = h.receiverSteps()
|
||||
ret, method, err = h.receiverSteps()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) {
|
||||
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
|
||||
h := handshake{
|
||||
conn: rw,
|
||||
initer: true,
|
||||
|
@ -539,7 +540,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
|
|||
return h.Do()
|
||||
}
|
||||
|
||||
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) {
|
||||
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
|
||||
h := handshake{
|
||||
conn: rw,
|
||||
initer: false,
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
_ "github.com/anacrolix/envpprof"
|
||||
"github.com/bradfitz/iter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -64,11 +65,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
|
|||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cryptoSelect(cryptoProvides), cm)
|
||||
go a.Write([]byte(aData))
|
||||
|
||||
var msg [20]byte
|
||||
|
@ -80,11 +79,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
|
|||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cryptoSelect(cryptoProvides), cm)
|
||||
go b.Write([]byte(bData))
|
||||
// Need to be exact here, as there are several reads, and net.Pipe is
|
||||
// most synchronous.
|
||||
|
@ -134,7 +131,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
|
|||
|
||||
func TestReceiveRandomData(t *testing.T) {
|
||||
tr := trackReader{rand.Reader, 0}
|
||||
_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
|
||||
_, _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
|
||||
// No skey matches
|
||||
require.Error(t, err)
|
||||
// Establishing S, and then reading the maximum padding for giving up on
|
||||
|
@ -183,13 +180,13 @@ func benchmarkStream(t *testing.B, crypto CryptoMethod) {
|
|||
go func() {
|
||||
defer ac.Close()
|
||||
defer wg.Done()
|
||||
rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
|
||||
rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, readAndWrite(rw, ar, a))
|
||||
}()
|
||||
func() {
|
||||
defer bc.Close()
|
||||
rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
|
||||
rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, readAndWrite(rw, br, b))
|
||||
}()
|
||||
|
|
Loading…
Reference in New Issue