Support initial payload, and improve tests

This commit is contained in:
Matt Joiner 2015-03-13 14:30:48 +11:00
parent 203da0aab0
commit 8e8d75dda1
2 changed files with 54 additions and 21 deletions

View File

@ -9,11 +9,12 @@ import (
"crypto/sha1"
"encoding/binary"
"errors"
"expvar"
"fmt"
"io"
"io/ioutil"
"log"
"math/big"
"strconv"
"sync"
"github.com/bradfitz/iter"
@ -35,6 +36,8 @@ var (
req1 = []byte("req1")
req2 = []byte("req2")
req3 = []byte("req3")
cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
)
func init() {
@ -176,6 +179,7 @@ type handshake struct {
initer bool
skeys [][]byte
skey []byte
ia []byte // Initial payload. Only used by the initiator.
writeMu sync.Mutex
writes [][]byte
@ -288,7 +292,6 @@ type cryptoNegotiation struct {
func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
err = binary.Read(r, binary.BigEndian, me.VC[:])
// _, err = io.ReadFull(r, me.VC[:])
if err != nil {
return
}
@ -300,7 +303,6 @@ func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
if err != nil {
return
}
log.Print(me.PadLen)
_, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
return
}
@ -344,7 +346,6 @@ func suffixMatchLen(a, b []byte) int {
}
func readUntil(r io.Reader, b []byte) error {
log.Println("read until", b)
b1 := make([]byte, len(b))
i := 0
for {
@ -379,7 +380,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
if err != nil {
return
}
err = marshal(buf, uint16(0))
err = marshal(buf, uint16(len(h.ia)), h.ia)
if err != nil {
return
}
@ -390,7 +391,6 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
bC := newEncrypt(false, h.s.Bytes(), h.skey)
var eVC [8]byte
bC.XORKeyStream(eVC[:], make([]byte, 8))
log.Print(eVC)
// Read until the all zero VC.
err = readUntil(h.conn, eVC[:])
if err != nil {
@ -400,7 +400,6 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
var cn cryptoNegotiation
r := &cipherReader{bC, h.conn}
err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
log.Printf("initer got %v", cn)
if err != nil {
err = fmt.Errorf("error reading crypto negotiation: %s", err)
return
@ -436,12 +435,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
if err != nil {
return
}
log.Printf("receiver got %v", cn)
cryptoProvidesCount.Add(strconv.FormatUint(uint64(cn.Method), 16), 1)
if cn.Method&cryptoMethodRC4 == 0 {
err = errors.New("no supported crypto methods were provided")
return
}
unmarshal(r, new(uint16))
var lenIA uint16
unmarshal(r, &lenIA)
if lenIA != 0 {
h.ia = make([]byte, lenIA)
unmarshal(r, h.ia)
}
buf := &bytes.Buffer{}
w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
err = (&cryptoNegotiation{
@ -455,7 +459,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
if err != nil {
return
}
ret = readWriter{r, &cipherWriter{w.c, h.conn}}
ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
return
}
@ -483,15 +487,15 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
if err != nil {
return
}
log.Print("ermahgerd, finished MSE handshake")
return
}
func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) {
func InitiateHandshake(rw io.ReadWriteCloser, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: true,
skey: skey,
ia: initialPayload,
}
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu

View File

@ -3,10 +3,11 @@ package mse
import (
"bytes"
"io"
"log"
"net"
"sync"
"github.com/bradfitz/iter"
"testing"
)
@ -43,21 +44,25 @@ func TestSuffixMatchLen(t *testing.T) {
test("sup", "person", 1)
}
func TestHandshake(t *testing.T) {
func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
a, b := net.Pipe()
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
a, err := InitiateHandshake(a, []byte("yep"))
a, err := InitiateHandshake(a, []byte("yep"), ia)
if err != nil {
t.Fatal(err)
return
}
a.Write([]byte("hello world"))
go a.Write([]byte(aData))
var msg [20]byte
n, _ := a.Read(msg[:])
log.Print(string(msg[:n]))
if n != len(bData) {
t.FailNow()
}
// t.Log(string(msg[:n]))
}()
go func() {
defer wg.Done()
@ -66,10 +71,34 @@ func TestHandshake(t *testing.T) {
t.Fatal(err)
return
}
var msg [20]byte
n, _ := b.Read(msg[:])
log.Print(string(msg[:n]))
b.Write([]byte("yo dawg"))
go b.Write([]byte(bData))
// Need to be exact here, as there are several reads, and net.Pipe is
// most synchronous.
msg := make([]byte, len(ia)+len(aData))
n, _ := io.ReadFull(b, msg[:])
if n != len(msg) {
t.FailNow()
}
// t.Log(string(msg[:n]))
}()
wg.Wait()
a.Close()
b.Close()
}
func allHandshakeTests(t testing.TB) {
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg")
handshakeTest(t, nil, "hello world", "yo dawg")
handshakeTest(t, []byte{}, "hello world", "yo dawg")
}
func TestHandshake(t *testing.T) {
allHandshakeTests(t)
t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
}
func BenchmarkHandshake(b *testing.B) {
for range iter.N(b.N) {
allHandshakeTests(b)
}
}