Set the connection.cryptoMethod

It was unwittingly dropped from received connections, and may never have been set for initiated connections.
This commit is contained in:
Matt Joiner 2018-02-16 10:59:56 +11:00
parent b92e8b7814
commit fc03dcb859
4 changed files with 22 additions and 24 deletions

View File

@ -748,7 +748,7 @@ func (cl *Client) incomingPeerPort() int {
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) { func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
if c.headerEncrypted { if c.headerEncrypted {
var rw io.ReadWriter var rw io.ReadWriter
rw, err = mse.InitiateHandshake( rw, c.cryptoMethod, err = mse.InitiateHandshake(
struct { struct {
io.Reader io.Reader
io.Writer io.Writer

View File

@ -187,7 +187,7 @@ func handleEncryption(
} }
} }
headerEncrypted = true headerEncrypted = true
ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod { ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
switch { switch {
case policy.ForceEncryption: case policy.ForceEncryption:
return mse.CryptoMethodRC4 return mse.CryptoMethodRC4

View File

@ -367,7 +367,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
return newEncrypt(initer, h.s[:], h.skey) return newEncrypt(initer, h.s[:], h.skey)
} }
func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
h.postWrite(hash(req1, h.s[:])) h.postWrite(hash(req1, h.s[:]))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:]))) h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
@ -409,7 +409,8 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
if err != nil { if err != nil {
return return
} }
switch method & h.cryptoProvides { selected = method & h.cryptoProvides
switch selected {
case CryptoMethodRC4: case CryptoMethodRC4:
ret = readWriter{r, &cipherWriter{e, h.conn, nil}} ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
case CryptoMethodPlaintext: case CryptoMethodPlaintext:
@ -422,7 +423,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
var ErrNoSecretKeyMatch = errors.New("no skey matched") var ErrNoSecretKeyMatch = errors.New("no skey matched")
func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
// There is up to 512 bytes of padding, then the 20 byte hash. // There is up to 512 bytes of padding, then the 20 byte hash.
err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:])) err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
if err != nil { if err != nil {
@ -460,7 +461,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
return return
} }
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1) cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
chosen := h.chooseMethod(provides) chosen = h.chooseMethod(provides)
_, err = io.CopyN(ioutil.Discard, r, int64(padLen)) _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
if err != nil { if err != nil {
return return
@ -499,7 +500,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
return return
} }
func (h *handshake) Do() (ret io.ReadWriter, err error) { func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
h.writeCond.L = &h.writeMu h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu h.writerCond.L = &h.writerMu
go h.writer() go h.writer()
@ -521,14 +522,14 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
return return
} }
if h.initer { if h.initer {
ret, err = h.initerSteps() ret, method, err = h.initerSteps()
} else { } else {
ret, err = h.receiverSteps() ret, method, err = h.receiverSteps()
} }
return return
} }
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) { func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
h := handshake{ h := handshake{
conn: rw, conn: rw,
initer: true, initer: true,
@ -539,7 +540,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
return h.Do() return h.Do()
} }
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) { func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
h := handshake{ h := handshake{
conn: rw, conn: rw,
initer: false, initer: false,

View File

@ -12,6 +12,7 @@ import (
_ "github.com/anacrolix/envpprof" _ "github.com/anacrolix/envpprof"
"github.com/bradfitz/iter" "github.com/bradfitz/iter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -64,11 +65,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides) a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
if err != nil { require.NoError(t, err)
t.Fatal(err) assert.Equal(t, cryptoSelect(cryptoProvides), cm)
return
}
go a.Write([]byte(aData)) go a.Write([]byte(aData))
var msg [20]byte var msg [20]byte
@ -80,11 +79,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect) b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
if err != nil { require.NoError(t, err)
t.Fatal(err) assert.Equal(t, cryptoSelect(cryptoProvides), cm)
return
}
go b.Write([]byte(bData)) go b.Write([]byte(bData))
// Need to be exact here, as there are several reads, and net.Pipe is // Need to be exact here, as there are several reads, and net.Pipe is
// most synchronous. // most synchronous.
@ -134,7 +131,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
func TestReceiveRandomData(t *testing.T) { func TestReceiveRandomData(t *testing.T) {
tr := trackReader{rand.Reader, 0} tr := trackReader{rand.Reader, 0}
_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector) _, _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
// No skey matches // No skey matches
require.Error(t, err) require.Error(t, err)
// Establishing S, and then reading the maximum padding for giving up on // Establishing S, and then reading the maximum padding for giving up on
@ -183,13 +180,13 @@ func benchmarkStream(t *testing.B, crypto CryptoMethod) {
go func() { go func() {
defer ac.Close() defer ac.Close()
defer wg.Done() defer wg.Done()
rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto) rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, readAndWrite(rw, ar, a)) require.NoError(t, readAndWrite(rw, ar, a))
}() }()
func() { func() {
defer bc.Close() defer bc.Close()
rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto }) rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, readAndWrite(rw, br, b)) require.NoError(t, readAndWrite(rw, br, b))
}() }()