Set the connection.cryptoMethod

It was unwittingly dropped from received connections, and may never have been set for initiated connections.
This commit is contained in:
Matt Joiner 2018-02-16 10:59:56 +11:00
parent b92e8b7814
commit fc03dcb859
4 changed files with 22 additions and 24 deletions

View File

@ -748,7 +748,7 @@ func (cl *Client) incomingPeerPort() int {
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
if c.headerEncrypted {
var rw io.ReadWriter
rw, err = mse.InitiateHandshake(
rw, c.cryptoMethod, err = mse.InitiateHandshake(
struct {
io.Reader
io.Writer

View File

@ -187,7 +187,7 @@ func handleEncryption(
}
}
headerEncrypted = true
ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
switch {
case policy.ForceEncryption:
return mse.CryptoMethodRC4

View File

@ -367,7 +367,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
return newEncrypt(initer, h.s[:], h.skey)
}
func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
h.postWrite(hash(req1, h.s[:]))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
buf := &bytes.Buffer{}
@ -409,7 +409,8 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
if err != nil {
return
}
switch method & h.cryptoProvides {
selected = method & h.cryptoProvides
switch selected {
case CryptoMethodRC4:
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
case CryptoMethodPlaintext:
@ -422,7 +423,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
var ErrNoSecretKeyMatch = errors.New("no skey matched")
func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
// There is up to 512 bytes of padding, then the 20 byte hash.
err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
if err != nil {
@ -460,7 +461,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
return
}
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
chosen := h.chooseMethod(provides)
chosen = h.chooseMethod(provides)
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
if err != nil {
return
@ -499,7 +500,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
return
}
func (h *handshake) Do() (ret io.ReadWriter, err error) {
func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu
go h.writer()
@ -521,14 +522,14 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
return
}
if h.initer {
ret, err = h.initerSteps()
ret, method, err = h.initerSteps()
} else {
ret, err = h.receiverSteps()
ret, method, err = h.receiverSteps()
}
return
}
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) {
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
h := handshake{
conn: rw,
initer: true,
@ -539,7 +540,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
return h.Do()
}
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) {
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
h := handshake{
conn: rw,
initer: false,

View File

@ -12,6 +12,7 @@ import (
_ "github.com/anacrolix/envpprof"
"github.com/bradfitz/iter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -64,11 +65,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
wg.Add(2)
go func() {
defer wg.Done()
a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
if err != nil {
t.Fatal(err)
return
}
a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
require.NoError(t, err)
assert.Equal(t, cryptoSelect(cryptoProvides), cm)
go a.Write([]byte(aData))
var msg [20]byte
@ -80,11 +79,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
}()
go func() {
defer wg.Done()
b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
if err != nil {
t.Fatal(err)
return
}
b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
require.NoError(t, err)
assert.Equal(t, cryptoSelect(cryptoProvides), cm)
go b.Write([]byte(bData))
// Need to be exact here, as there are several reads, and net.Pipe is
// most synchronous.
@ -134,7 +131,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
func TestReceiveRandomData(t *testing.T) {
tr := trackReader{rand.Reader, 0}
_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
_, _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
// No skey matches
require.Error(t, err)
// Establishing S, and then reading the maximum padding for giving up on
@ -183,13 +180,13 @@ func benchmarkStream(t *testing.B, crypto CryptoMethod) {
go func() {
defer ac.Close()
defer wg.Done()
rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
require.NoError(t, err)
require.NoError(t, readAndWrite(rw, ar, a))
}()
func() {
defer bc.Close()
rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
require.NoError(t, err)
require.NoError(t, readAndWrite(rw, br, b))
}()