Add mse.CryptoMethod type

This commit is contained in:
Matt Joiner 2018-02-16 10:36:29 +11:00
parent 3f7eab00de
commit 066cdd520b
5 changed files with 22 additions and 20 deletions

View File

@ -726,7 +726,7 @@ func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err er
}{c.r, c.w}, }{c.r, c.w},
t.infoHash[:], t.infoHash[:],
nil, nil,
func() uint32 { func() mse.CryptoMethod {
switch { switch {
case cl.config.ForceEncryption: case cl.config.ForceEncryption:
return mse.CryptoMethodRC4 return mse.CryptoMethodRC4

View File

@ -46,7 +46,7 @@ type connection struct {
r io.Reader r io.Reader
// True if the connection is operating over MSE obfuscation. // True if the connection is operating over MSE obfuscation.
headerEncrypted bool headerEncrypted bool
cryptoMethod uint32 cryptoMethod mse.CryptoMethod
Discovery peerSource Discovery peerSource
uTP bool uTP bool
closed missinggo.Event closed missinggo.Event

View File

@ -165,7 +165,7 @@ func handleEncryption(
) ( ) (
ret io.ReadWriter, ret io.ReadWriter,
headerEncrypted bool, headerEncrypted bool,
cryptoMethod uint32, cryptoMethod mse.CryptoMethod,
err error, err error,
) { ) {
if !policy.ForceEncryption { if !policy.ForceEncryption {
@ -187,7 +187,7 @@ func handleEncryption(
} }
} }
headerEncrypted = true headerEncrypted = true
ret, err = mse.ReceiveHandshake(rw, skeys, func(provides uint32) uint32 { ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
switch { switch {
case policy.ForceEncryption: case policy.ForceEncryption:
return mse.CryptoMethodRC4 return mse.CryptoMethodRC4

View File

@ -24,11 +24,13 @@ import (
const ( const (
maxPadLen = 512 maxPadLen = 512
CryptoMethodPlaintext = 1 CryptoMethodPlaintext CryptoMethod = 1
CryptoMethodRC4 = 2 CryptoMethodRC4 CryptoMethod = 2
AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4 AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4
) )
type CryptoMethod uint32
var ( var (
// Prime P according to the spec, and G, the generator. // Prime P according to the spec, and G, the generator.
p, g big.Int p, g big.Int
@ -212,9 +214,9 @@ type handshake struct {
skey []byte // Skey we're initiating with. skey []byte // Skey we're initiating with.
ia []byte // Initial payload. Only used by the initiator. ia []byte // Initial payload. Only used by the initiator.
// Return the bit for the crypto method the receiver wants to use. // Return the bit for the crypto method the receiver wants to use.
chooseMethod func(supported uint32) uint32 chooseMethod CryptoSelector
// Sent to the receiver. // Sent to the receiver.
cryptoProvides uint32 cryptoProvides CryptoMethod
writeMu sync.Mutex writeMu sync.Mutex
writes [][]byte writes [][]byte
@ -398,7 +400,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
return return
} }
r := newCipherReader(bC, h.conn) r := newCipherReader(bC, h.conn)
var method uint32 var method CryptoMethod
err = unmarshal(r, &method, &padLen) err = unmarshal(r, &method, &padLen)
if err != nil { if err != nil {
return return
@ -449,7 +451,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn) r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
var ( var (
vc [8]byte vc [8]byte
provides uint32 provides CryptoMethod
padLen uint16 padLen uint16
) )
@ -526,7 +528,7 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
return return
} }
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) { func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) {
h := handshake{ h := handshake{
conn: rw, conn: rw,
initer: true, initer: true,
@ -537,7 +539,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
return h.Do() return h.Do()
} }
func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) { func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) {
h := handshake{ h := handshake{
conn: rw, conn: rw,
initer: false, initer: false,
@ -551,11 +553,11 @@ func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(u
// returns false or exhausted. // returns false or exhausted.
type SecretKeyIter func(callback func(skey []byte) (more bool)) type SecretKeyIter func(callback func(skey []byte) (more bool))
func DefaultCryptoSelector(provided uint32) uint32 { func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
if provided&CryptoMethodPlaintext != 0 { if provided&CryptoMethodPlaintext != 0 {
return CryptoMethodPlaintext return CryptoMethodPlaintext
} }
return CryptoMethodRC4 return CryptoMethodRC4
} }
type CryptoSelector func(uint32) uint32 type CryptoSelector func(CryptoMethod) CryptoMethod

View File

@ -58,7 +58,7 @@ func TestSuffixMatchLen(t *testing.T) {
test("sup", "person", 1) test("sup", "person", 1)
} }
func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) { func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides CryptoMethod, cryptoSelect CryptoSelector) {
a, b := net.Pipe() a, b := net.Pipe()
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(2) wg.Add(2)
@ -100,7 +100,7 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
b.Close() b.Close()
} }
func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) { func allHandshakeTests(t testing.TB, provides CryptoMethod, selector CryptoSelector) {
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector) handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
handshakeTest(t, nil, "hello world", "yo dawg", provides, selector) handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector) handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
@ -112,7 +112,7 @@ func TestHandshakeDefault(t *testing.T) {
} }
func TestHandshakeSelectPlaintext(t *testing.T) { func TestHandshakeSelectPlaintext(t *testing.T) {
allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return CryptoMethodPlaintext }) allHandshakeTests(t, AllSupportedCrypto, func(CryptoMethod) CryptoMethod { return CryptoMethodPlaintext })
} }
func BenchmarkHandshakeDefault(b *testing.B) { func BenchmarkHandshakeDefault(b *testing.B) {
@ -165,7 +165,7 @@ func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
return wErr return wErr
} }
func benchmarkStream(t *testing.B, crypto uint32) { func benchmarkStream(t *testing.B, crypto CryptoMethod) {
ia := make([]byte, 0x1000) ia := make([]byte, 0x1000)
a := make([]byte, 1<<20) a := make([]byte, 1<<20)
b := make([]byte, 1<<20) b := make([]byte, 1<<20)
@ -189,7 +189,7 @@ func benchmarkStream(t *testing.B, crypto uint32) {
}() }()
func() { func() {
defer bc.Close() defer bc.Close()
rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(uint32) uint32 { return crypto }) rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, readAndWrite(rw, br, b)) require.NoError(t, readAndWrite(rw, br, b))
}() }()