Add mse.CryptoMethod type

This commit is contained in:
Matt Joiner 2018-02-16 10:36:29 +11:00
parent 3f7eab00de
commit 066cdd520b
5 changed files with 22 additions and 20 deletions

View File

@ -726,7 +726,7 @@ func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err er
}{c.r, c.w},
t.infoHash[:],
nil,
func() uint32 {
func() mse.CryptoMethod {
switch {
case cl.config.ForceEncryption:
return mse.CryptoMethodRC4

View File

@ -46,7 +46,7 @@ type connection struct {
r io.Reader
// True if the connection is operating over MSE obfuscation.
headerEncrypted bool
cryptoMethod uint32
cryptoMethod mse.CryptoMethod
Discovery peerSource
uTP bool
closed missinggo.Event

View File

@ -165,7 +165,7 @@ func handleEncryption(
) (
ret io.ReadWriter,
headerEncrypted bool,
cryptoMethod uint32,
cryptoMethod mse.CryptoMethod,
err error,
) {
if !policy.ForceEncryption {
@ -187,7 +187,7 @@ func handleEncryption(
}
}
headerEncrypted = true
ret, err = mse.ReceiveHandshake(rw, skeys, func(provides uint32) uint32 {
ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
switch {
case policy.ForceEncryption:
return mse.CryptoMethodRC4

View File

@ -24,11 +24,13 @@ import (
const (
maxPadLen = 512
CryptoMethodPlaintext = 1
CryptoMethodRC4 = 2
CryptoMethodPlaintext CryptoMethod = 1
CryptoMethodRC4 CryptoMethod = 2
AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4
)
type CryptoMethod uint32
var (
// Prime P according to the spec, and G, the generator.
p, g big.Int
@ -212,9 +214,9 @@ type handshake struct {
skey []byte // Skey we're initiating with.
ia []byte // Initial payload. Only used by the initiator.
// Return the bit for the crypto method the receiver wants to use.
chooseMethod func(supported uint32) uint32
chooseMethod CryptoSelector
// Sent to the receiver.
cryptoProvides uint32
cryptoProvides CryptoMethod
writeMu sync.Mutex
writes [][]byte
@ -398,7 +400,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
return
}
r := newCipherReader(bC, h.conn)
var method uint32
var method CryptoMethod
err = unmarshal(r, &method, &padLen)
if err != nil {
return
@ -449,7 +451,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
var (
vc [8]byte
provides uint32
provides CryptoMethod
padLen uint16
)
@ -526,7 +528,7 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
return
}
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) {
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: true,
@ -537,7 +539,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
return h.Do()
}
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: false,
@ -551,11 +553,11 @@ func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(u
// returns false or exhausted.
type SecretKeyIter func(callback func(skey []byte) (more bool))
func DefaultCryptoSelector(provided uint32) uint32 {
func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
if provided&CryptoMethodPlaintext != 0 {
return CryptoMethodPlaintext
}
return CryptoMethodRC4
}
type CryptoSelector func(uint32) uint32
type CryptoSelector func(CryptoMethod) CryptoMethod

View File

@ -58,7 +58,7 @@ func TestSuffixMatchLen(t *testing.T) {
test("sup", "person", 1)
}
func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) {
func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides CryptoMethod, cryptoSelect CryptoSelector) {
a, b := net.Pipe()
wg := sync.WaitGroup{}
wg.Add(2)
@ -100,7 +100,7 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
b.Close()
}
func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) {
func allHandshakeTests(t testing.TB, provides CryptoMethod, selector CryptoSelector) {
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
@ -112,7 +112,7 @@ func TestHandshakeDefault(t *testing.T) {
}
func TestHandshakeSelectPlaintext(t *testing.T) {
allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return CryptoMethodPlaintext })
allHandshakeTests(t, AllSupportedCrypto, func(CryptoMethod) CryptoMethod { return CryptoMethodPlaintext })
}
func BenchmarkHandshakeDefault(b *testing.B) {
@ -165,7 +165,7 @@ func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
return wErr
}
func benchmarkStream(t *testing.B, crypto uint32) {
func benchmarkStream(t *testing.B, crypto CryptoMethod) {
ia := make([]byte, 0x1000)
a := make([]byte, 1<<20)
b := make([]byte, 1<<20)
@ -189,7 +189,7 @@ func benchmarkStream(t *testing.B, crypto uint32) {
}()
func() {
defer bc.Close()
rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(uint32) uint32 { return crypto })
rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
require.NoError(t, err)
require.NoError(t, readAndWrite(rw, br, b))
}()