Set the connection.cryptoMethod
It was unwittingly dropped from received connections, and may never have been set for initiated connections.
This commit is contained in:
parent
b92e8b7814
commit
fc03dcb859
|
@ -748,7 +748,7 @@ func (cl *Client) incomingPeerPort() int {
|
||||||
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
|
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
|
||||||
if c.headerEncrypted {
|
if c.headerEncrypted {
|
||||||
var rw io.ReadWriter
|
var rw io.ReadWriter
|
||||||
rw, err = mse.InitiateHandshake(
|
rw, c.cryptoMethod, err = mse.InitiateHandshake(
|
||||||
struct {
|
struct {
|
||||||
io.Reader
|
io.Reader
|
||||||
io.Writer
|
io.Writer
|
||||||
|
|
|
@ -187,7 +187,7 @@ func handleEncryption(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
headerEncrypted = true
|
headerEncrypted = true
|
||||||
ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
|
ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
|
||||||
switch {
|
switch {
|
||||||
case policy.ForceEncryption:
|
case policy.ForceEncryption:
|
||||||
return mse.CryptoMethodRC4
|
return mse.CryptoMethodRC4
|
||||||
|
|
19
mse/mse.go
19
mse/mse.go
|
@ -367,7 +367,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
|
||||||
return newEncrypt(initer, h.s[:], h.skey)
|
return newEncrypt(initer, h.s[:], h.skey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
|
||||||
h.postWrite(hash(req1, h.s[:]))
|
h.postWrite(hash(req1, h.s[:]))
|
||||||
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
|
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
@ -409,7 +409,8 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch method & h.cryptoProvides {
|
selected = method & h.cryptoProvides
|
||||||
|
switch selected {
|
||||||
case CryptoMethodRC4:
|
case CryptoMethodRC4:
|
||||||
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
|
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
|
||||||
case CryptoMethodPlaintext:
|
case CryptoMethodPlaintext:
|
||||||
|
@ -422,7 +423,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
||||||
|
|
||||||
var ErrNoSecretKeyMatch = errors.New("no skey matched")
|
var ErrNoSecretKeyMatch = errors.New("no skey matched")
|
||||||
|
|
||||||
func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
|
||||||
// There is up to 512 bytes of padding, then the 20 byte hash.
|
// There is up to 512 bytes of padding, then the 20 byte hash.
|
||||||
err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
|
err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -460,7 +461,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
|
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
|
||||||
chosen := h.chooseMethod(provides)
|
chosen = h.chooseMethod(provides)
|
||||||
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
|
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -499,7 +500,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) Do() (ret io.ReadWriter, err error) {
|
func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
|
||||||
h.writeCond.L = &h.writeMu
|
h.writeCond.L = &h.writeMu
|
||||||
h.writerCond.L = &h.writerMu
|
h.writerCond.L = &h.writerMu
|
||||||
go h.writer()
|
go h.writer()
|
||||||
|
@ -521,14 +522,14 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.initer {
|
if h.initer {
|
||||||
ret, err = h.initerSteps()
|
ret, method, err = h.initerSteps()
|
||||||
} else {
|
} else {
|
||||||
ret, err = h.receiverSteps()
|
ret, method, err = h.receiverSteps()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) {
|
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
|
||||||
h := handshake{
|
h := handshake{
|
||||||
conn: rw,
|
conn: rw,
|
||||||
initer: true,
|
initer: true,
|
||||||
|
@ -539,7 +540,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
|
||||||
return h.Do()
|
return h.Do()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) {
|
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
|
||||||
h := handshake{
|
h := handshake{
|
||||||
conn: rw,
|
conn: rw,
|
||||||
initer: false,
|
initer: false,
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
|
|
||||||
_ "github.com/anacrolix/envpprof"
|
_ "github.com/anacrolix/envpprof"
|
||||||
"github.com/bradfitz/iter"
|
"github.com/bradfitz/iter"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -64,11 +65,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
|
a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
assert.Equal(t, cryptoSelect(cryptoProvides), cm)
|
||||||
return
|
|
||||||
}
|
|
||||||
go a.Write([]byte(aData))
|
go a.Write([]byte(aData))
|
||||||
|
|
||||||
var msg [20]byte
|
var msg [20]byte
|
||||||
|
@ -80,11 +79,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
|
b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
assert.Equal(t, cryptoSelect(cryptoProvides), cm)
|
||||||
return
|
|
||||||
}
|
|
||||||
go b.Write([]byte(bData))
|
go b.Write([]byte(bData))
|
||||||
// Need to be exact here, as there are several reads, and net.Pipe is
|
// Need to be exact here, as there are several reads, and net.Pipe is
|
||||||
// most synchronous.
|
// most synchronous.
|
||||||
|
@ -134,7 +131,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
|
||||||
|
|
||||||
func TestReceiveRandomData(t *testing.T) {
|
func TestReceiveRandomData(t *testing.T) {
|
||||||
tr := trackReader{rand.Reader, 0}
|
tr := trackReader{rand.Reader, 0}
|
||||||
_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
|
_, _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
|
||||||
// No skey matches
|
// No skey matches
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
// Establishing S, and then reading the maximum padding for giving up on
|
// Establishing S, and then reading the maximum padding for giving up on
|
||||||
|
@ -183,13 +180,13 @@ func benchmarkStream(t *testing.B, crypto CryptoMethod) {
|
||||||
go func() {
|
go func() {
|
||||||
defer ac.Close()
|
defer ac.Close()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
|
rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, readAndWrite(rw, ar, a))
|
require.NoError(t, readAndWrite(rw, ar, a))
|
||||||
}()
|
}()
|
||||||
func() {
|
func() {
|
||||||
defer bc.Close()
|
defer bc.Close()
|
||||||
rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
|
rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, readAndWrite(rw, br, b))
|
require.NoError(t, readAndWrite(rw, br, b))
|
||||||
}()
|
}()
|
||||||
|
|
Loading…
Reference in New Issue