Support initial payload, and improve tests
This commit is contained in:
parent
203da0aab0
commit
8e8d75dda1
28
mse/mse.go
28
mse/mse.go
|
@ -9,11 +9,12 @@ import (
|
|||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/bradfitz/iter"
|
||||
|
@ -35,6 +36,8 @@ var (
|
|||
req1 = []byte("req1")
|
||||
req2 = []byte("req2")
|
||||
req3 = []byte("req3")
|
||||
|
||||
cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -176,6 +179,7 @@ type handshake struct {
|
|||
initer bool
|
||||
skeys [][]byte
|
||||
skey []byte
|
||||
ia []byte // Initial payload. Only used by the initiator.
|
||||
|
||||
writeMu sync.Mutex
|
||||
writes [][]byte
|
||||
|
@ -288,7 +292,6 @@ type cryptoNegotiation struct {
|
|||
|
||||
func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
|
||||
err = binary.Read(r, binary.BigEndian, me.VC[:])
|
||||
// _, err = io.ReadFull(r, me.VC[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -300,7 +303,6 @@ func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Print(me.PadLen)
|
||||
_, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
|
||||
return
|
||||
}
|
||||
|
@ -344,7 +346,6 @@ func suffixMatchLen(a, b []byte) int {
|
|||
}
|
||||
|
||||
func readUntil(r io.Reader, b []byte) error {
|
||||
log.Println("read until", b)
|
||||
b1 := make([]byte, len(b))
|
||||
i := 0
|
||||
for {
|
||||
|
@ -379,7 +380,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = marshal(buf, uint16(0))
|
||||
err = marshal(buf, uint16(len(h.ia)), h.ia)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -390,7 +391,6 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
|||
bC := newEncrypt(false, h.s.Bytes(), h.skey)
|
||||
var eVC [8]byte
|
||||
bC.XORKeyStream(eVC[:], make([]byte, 8))
|
||||
log.Print(eVC)
|
||||
// Read until the all zero VC.
|
||||
err = readUntil(h.conn, eVC[:])
|
||||
if err != nil {
|
||||
|
@ -400,7 +400,6 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
|||
var cn cryptoNegotiation
|
||||
r := &cipherReader{bC, h.conn}
|
||||
err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
|
||||
log.Printf("initer got %v", cn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading crypto negotiation: %s", err)
|
||||
return
|
||||
|
@ -436,12 +435,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Printf("receiver got %v", cn)
|
||||
cryptoProvidesCount.Add(strconv.FormatUint(uint64(cn.Method), 16), 1)
|
||||
if cn.Method&cryptoMethodRC4 == 0 {
|
||||
err = errors.New("no supported crypto methods were provided")
|
||||
return
|
||||
}
|
||||
unmarshal(r, new(uint16))
|
||||
var lenIA uint16
|
||||
unmarshal(r, &lenIA)
|
||||
if lenIA != 0 {
|
||||
h.ia = make([]byte, lenIA)
|
||||
unmarshal(r, h.ia)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
|
||||
err = (&cryptoNegotiation{
|
||||
|
@ -455,7 +459,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret = readWriter{r, &cipherWriter{w.c, h.conn}}
|
||||
ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -483,15 +487,15 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Print("ermahgerd, finished MSE handshake")
|
||||
return
|
||||
}
|
||||
|
||||
func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) {
|
||||
func InitiateHandshake(rw io.ReadWriteCloser, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
|
||||
h := handshake{
|
||||
conn: rw,
|
||||
initer: true,
|
||||
skey: skey,
|
||||
ia: initialPayload,
|
||||
}
|
||||
h.writeCond.L = &h.writeMu
|
||||
h.writerCond.L = &h.writerMu
|
||||
|
|
|
@ -3,10 +3,11 @@ package mse
|
|||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/bradfitz/iter"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -43,21 +44,25 @@ func TestSuffixMatchLen(t *testing.T) {
|
|||
test("sup", "person", 1)
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
|
||||
a, b := net.Pipe()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a, err := InitiateHandshake(a, []byte("yep"))
|
||||
a, err := InitiateHandshake(a, []byte("yep"), ia)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
a.Write([]byte("hello world"))
|
||||
go a.Write([]byte(aData))
|
||||
|
||||
var msg [20]byte
|
||||
n, _ := a.Read(msg[:])
|
||||
log.Print(string(msg[:n]))
|
||||
if n != len(bData) {
|
||||
t.FailNow()
|
||||
}
|
||||
// t.Log(string(msg[:n]))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
@ -66,10 +71,34 @@ func TestHandshake(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
var msg [20]byte
|
||||
n, _ := b.Read(msg[:])
|
||||
log.Print(string(msg[:n]))
|
||||
b.Write([]byte("yo dawg"))
|
||||
go b.Write([]byte(bData))
|
||||
// Need to be exact here, as there are several reads, and net.Pipe is
|
||||
// most synchronous.
|
||||
msg := make([]byte, len(ia)+len(aData))
|
||||
n, _ := io.ReadFull(b, msg[:])
|
||||
if n != len(msg) {
|
||||
t.FailNow()
|
||||
}
|
||||
// t.Log(string(msg[:n]))
|
||||
}()
|
||||
wg.Wait()
|
||||
a.Close()
|
||||
b.Close()
|
||||
}
|
||||
|
||||
func allHandshakeTests(t testing.TB) {
|
||||
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg")
|
||||
handshakeTest(t, nil, "hello world", "yo dawg")
|
||||
handshakeTest(t, []byte{}, "hello world", "yo dawg")
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
allHandshakeTests(t)
|
||||
t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
|
||||
}
|
||||
|
||||
func BenchmarkHandshake(b *testing.B) {
|
||||
for range iter.N(b.N) {
|
||||
allHandshakeTests(b)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue