mse: Got basic message stream encryption working with some tests
Not complete yet.
This commit is contained in:
parent
5b2a91e24a
commit
954e03952a
|
@ -0,0 +1,444 @@
|
|||
// https://wiki.vuze.com/w/Message_Stream_Encryption
|
||||
|
||||
package mse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rc4"
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/big"
|
||||
"sync"
|
||||
|
||||
"github.com/bradfitz/iter"
|
||||
)
|
||||
|
||||
const (
|
||||
maxPadLen = 512
|
||||
|
||||
cryptoMethodPlaintext = 1
|
||||
cryptoMethodRC4 = 2
|
||||
)
|
||||
|
||||
var (
|
||||
// Prime P according to the spec, and G, the generator.
|
||||
p, g big.Int
|
||||
// The rand.Int max arg for use in newPadLen()
|
||||
newPadLenMax big.Int
|
||||
// For use in initer's hashes
|
||||
req1 = []byte("req1")
|
||||
req2 = []byte("req2")
|
||||
req3 = []byte("req3")
|
||||
)
|
||||
|
||||
func init() {
|
||||
p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
|
||||
g.SetInt64(2)
|
||||
newPadLenMax.SetInt64(maxPadLen + 1)
|
||||
}
|
||||
|
||||
func hash(parts ...[]byte) []byte {
|
||||
h := sha1.New()
|
||||
for _, p := range parts {
|
||||
n, err := h.Write(p)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if n != len(p) {
|
||||
panic(n)
|
||||
}
|
||||
}
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
|
||||
c, err := rc4.NewCipher(hash([]byte(func() string {
|
||||
if initer {
|
||||
return "keyA"
|
||||
} else {
|
||||
return "keyB"
|
||||
}
|
||||
}()), s, skey))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var burnSrc, burnDst [1024]byte
|
||||
c.XORKeyStream(burnDst[:], burnSrc[:])
|
||||
return
|
||||
}
|
||||
|
||||
type cipherReader struct {
|
||||
c *rc4.Cipher
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (me *cipherReader) Read(b []byte) (n int, err error) {
|
||||
be := make([]byte, len(b))
|
||||
n, err = me.r.Read(be)
|
||||
me.c.XORKeyStream(b[:n], be[:n])
|
||||
return
|
||||
}
|
||||
|
||||
func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
|
||||
return &cipherReader{c, r}
|
||||
}
|
||||
|
||||
type cipherWriter struct {
|
||||
c *rc4.Cipher
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (me *cipherWriter) Write(b []byte) (n int, err error) {
|
||||
be := make([]byte, len(b))
|
||||
me.c.XORKeyStream(be, b)
|
||||
n, err = me.w.Write(be)
|
||||
if n != len(be) {
|
||||
// The cipher will have advanced beyond the callers stream position.
|
||||
// We can't use the cipher anymore.
|
||||
me.c = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer {
|
||||
return &cipherWriter{c, w}
|
||||
}
|
||||
|
||||
func readY(r io.Reader) (y big.Int, err error) {
|
||||
var b [96]byte
|
||||
_, err = io.ReadFull(r, b[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
y.SetBytes(b[:])
|
||||
return
|
||||
}
|
||||
|
||||
func newX() big.Int {
|
||||
var X big.Int
|
||||
X.SetBytes(func() []byte {
|
||||
var b [20]byte
|
||||
_, err := rand.Read(b[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b[:]
|
||||
}())
|
||||
return X
|
||||
}
|
||||
|
||||
func (h *handshake) postY(x *big.Int) error {
|
||||
var y big.Int
|
||||
y.Exp(&g, x, &p)
|
||||
b := y.Bytes()
|
||||
if len(b) != 96 {
|
||||
panic(len(b))
|
||||
}
|
||||
return h.postWrite(b)
|
||||
}
|
||||
|
||||
func (h *handshake) establishS() (err error) {
|
||||
x := newX()
|
||||
h.postY(&x)
|
||||
var b [96]byte
|
||||
_, err = io.ReadFull(h.conn, b[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var Y big.Int
|
||||
Y.SetBytes(b[:])
|
||||
h.s.Exp(&Y, &x, &p)
|
||||
return
|
||||
}
|
||||
|
||||
func newPadLen() int64 {
|
||||
i, err := rand.Int(rand.Reader, &newPadLenMax)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret := i.Int64()
|
||||
if ret < 0 || ret > maxPadLen {
|
||||
panic(ret)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
type handshake struct {
|
||||
conn io.ReadWriteCloser
|
||||
s big.Int
|
||||
initer bool
|
||||
skey []byte
|
||||
|
||||
writeMu sync.Mutex
|
||||
writes [][]byte
|
||||
writeErr error
|
||||
writeCond sync.Cond
|
||||
writeClose bool
|
||||
|
||||
writerMu sync.Mutex
|
||||
writerCond sync.Cond
|
||||
writerDone bool
|
||||
}
|
||||
|
||||
func (h *handshake) finishWriting() (err error) {
|
||||
h.writeMu.Lock()
|
||||
h.writeClose = true
|
||||
h.writeCond.Broadcast()
|
||||
err = h.writeErr
|
||||
h.writeMu.Unlock()
|
||||
|
||||
h.writerMu.Lock()
|
||||
for !h.writerDone {
|
||||
h.writerCond.Wait()
|
||||
}
|
||||
h.writerMu.Unlock()
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (h *handshake) writer() {
|
||||
defer func() {
|
||||
h.writerMu.Lock()
|
||||
h.writerDone = true
|
||||
h.writerCond.Broadcast()
|
||||
h.writerMu.Unlock()
|
||||
}()
|
||||
for {
|
||||
h.writeMu.Lock()
|
||||
for {
|
||||
if len(h.writes) != 0 {
|
||||
break
|
||||
}
|
||||
if h.writeClose {
|
||||
h.writeMu.Unlock()
|
||||
return
|
||||
}
|
||||
h.writeCond.Wait()
|
||||
}
|
||||
b := h.writes[0]
|
||||
h.writes = h.writes[1:]
|
||||
h.writeMu.Unlock()
|
||||
_, err := h.conn.Write(b)
|
||||
if err != nil {
|
||||
h.writeMu.Lock()
|
||||
h.writeErr = err
|
||||
h.writeMu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handshake) postWrite(b []byte) error {
|
||||
h.writeMu.Lock()
|
||||
defer h.writeMu.Unlock()
|
||||
if h.writeErr != nil {
|
||||
return h.writeErr
|
||||
}
|
||||
h.writes = append(h.writes, b)
|
||||
h.writeCond.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
func xor(dst, src []byte) (ret []byte) {
|
||||
max := len(dst)
|
||||
if max > len(src) {
|
||||
max = len(src)
|
||||
}
|
||||
ret = make([]byte, 0, max)
|
||||
for i := range iter.N(max) {
|
||||
ret = append(ret, dst[i]^src[i])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type cryptoNegotiation struct {
|
||||
VC [8]byte
|
||||
Method uint32
|
||||
PadLen uint16
|
||||
IA []byte
|
||||
}
|
||||
|
||||
func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
|
||||
_, err = io.ReadFull(r, me.VC[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Read(r, binary.BigEndian, &me.Method)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Read(r, binary.BigEndian, &me.PadLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Print(me.PadLen)
|
||||
_, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
|
||||
return
|
||||
}
|
||||
|
||||
func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
|
||||
_, err = w.Write(me.VC[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Write(w, binary.BigEndian, me.Method)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Write(w, binary.BigEndian, me.PadLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = w.Write(make([]byte, me.PadLen))
|
||||
return
|
||||
}
|
||||
|
||||
// Looking for b at the end of a.
|
||||
func suffixMatchLen(a, b []byte) int {
|
||||
if len(b) > len(a) {
|
||||
b = b[:len(a)]
|
||||
}
|
||||
// i is how much of b to try to match
|
||||
for i := len(b); i > 0; i-- {
|
||||
// j is how many chars we've compared
|
||||
j := 0
|
||||
for ; j < i; j++ {
|
||||
if b[i-1-j] != a[len(a)-1-j] {
|
||||
goto shorter
|
||||
}
|
||||
}
|
||||
return j
|
||||
shorter:
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func readUntil(r io.Reader, b []byte) error {
|
||||
log.Println("read until", b)
|
||||
b1 := make([]byte, len(b))
|
||||
i := 0
|
||||
for {
|
||||
_, err := io.ReadFull(r, b1[i:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i = suffixMatchLen(b1, b)
|
||||
if i == len(b) {
|
||||
break
|
||||
}
|
||||
if copy(b1, b1[len(b1)-i:]) != i {
|
||||
panic("wat")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handshake) Do() (ret io.ReadWriteCloser, err error) {
|
||||
err = h.establishS()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pad := make([]byte, newPadLen())
|
||||
io.ReadFull(rand.Reader, pad)
|
||||
err = h.postWrite(pad)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if h.initer {
|
||||
h.postWrite(hash(req1, h.s.Bytes()))
|
||||
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
|
||||
buf := &bytes.Buffer{}
|
||||
err = (&cryptoNegotiation{
|
||||
Method: cryptoMethodRC4,
|
||||
PadLen: uint16(newPadLen()),
|
||||
}).MarshalWriter(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e := newEncrypt(true, h.s.Bytes(), h.skey)
|
||||
be := make([]byte, buf.Len())
|
||||
e.XORKeyStream(be, buf.Bytes())
|
||||
h.postWrite(be)
|
||||
bC := newEncrypt(false, h.s.Bytes(), h.skey)
|
||||
var eVC [8]byte
|
||||
bC.XORKeyStream(eVC[:], make([]byte, 8))
|
||||
log.Print(eVC)
|
||||
// Read until the all zero VC.
|
||||
err = readUntil(h.conn, eVC[:])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading until VC: %s", err)
|
||||
return
|
||||
}
|
||||
var cn cryptoNegotiation
|
||||
r := &cipherReader{bC, h.conn}
|
||||
err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
|
||||
log.Printf("initer got %v", cn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading crypto negotiation: %s", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
err = readUntil(h.conn, hash(req1, h.s.Bytes()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var b [20]byte
|
||||
_, err = io.ReadFull(h.conn, b[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) {
|
||||
err = errors.New("skey doesn't match")
|
||||
return
|
||||
}
|
||||
var cn cryptoNegotiation
|
||||
r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
|
||||
err = cn.UnmarshalReader(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Printf("receiver got %v", cn)
|
||||
if cn.Method&cryptoMethodRC4 == 0 {
|
||||
err = errors.New("no supported crypto methods were provided")
|
||||
return
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf)
|
||||
err = (&cryptoNegotiation{
|
||||
Method: cryptoMethodRC4,
|
||||
PadLen: uint16(newPadLen()),
|
||||
}).MarshalWriter(w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Println("encrypted VC", buf.Bytes()[:8])
|
||||
err = h.postWrite(buf.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = h.finishWriting()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret = h.conn
|
||||
return
|
||||
}
|
||||
|
||||
func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriteCloser, err error) {
|
||||
h := handshake{
|
||||
conn: rw,
|
||||
initer: initer,
|
||||
skey: skey,
|
||||
}
|
||||
h.writeCond.L = &h.writeMu
|
||||
h.writerCond.L = &h.writerMu
|
||||
go h.writer()
|
||||
return h.Do()
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package mse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadUntil(t *testing.T) {
|
||||
test := func(data, until string, leftover int, expectedErr error) {
|
||||
r := bytes.NewReader([]byte(data))
|
||||
err := readUntil(r, []byte(until))
|
||||
if err != expectedErr {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.Len() != leftover {
|
||||
t.Fatal(r.Len())
|
||||
}
|
||||
}
|
||||
test("feakjfeafeafegbaabc00", "abc", 2, nil)
|
||||
test("feakjfeafeafegbaadc00", "abc", 0, io.EOF)
|
||||
}
|
||||
|
||||
func TestSuffixMatchLen(t *testing.T) {
|
||||
test := func(a, b string, expected int) {
|
||||
actual := suffixMatchLen([]byte(a), []byte(b))
|
||||
if actual != expected {
|
||||
t.Fatalf("expected %d, got %d for %q and %q", expected, actual, a, b)
|
||||
}
|
||||
}
|
||||
test("hello", "world", 0)
|
||||
test("hello", "lo", 2)
|
||||
test("hello", "llo", 3)
|
||||
test("hello", "hell", 0)
|
||||
test("hello", "helloooo!", 5)
|
||||
test("hello", "lol!", 2)
|
||||
test("hello", "mondo", 0)
|
||||
test("mongo", "webscale", 0)
|
||||
test("sup", "person", 1)
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a, err := Handshake(a, true, []byte("yep"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
a.Close()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b, err := Handshake(b, false, []byte("yep"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
b.Close()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
Loading…
Reference in New Issue