// FedP2P/mse/mse.go

// https://wiki.vuze.com/w/Message_Stream_Encryption
package mse
import (
"bytes"
"crypto/rand"
"crypto/rc4"
"crypto/sha1"
"encoding/binary"
"errors"
"expvar"
"fmt"
"io"
"io/ioutil"
"math/big"
"strconv"
"sync"
"github.com/anacrolix/torrent/util"
"github.com/bradfitz/iter"
)
// maxPadLen is the largest padding length, in bytes, that newPadLen will
// produce.
const maxPadLen = 512

// Crypto method identifiers exchanged in the handshake. They are tested with
// a bitwise AND when received, so each method occupies its own bit.
const (
	cryptoMethodPlaintext = 1 << iota // 0x01
	cryptoMethodRC4                   // 0x02
)
var (
	// Prime P according to the spec, and G, the generator. Both are
	// populated in init().
	p, g big.Int
	// The rand.Int max arg for use in newPadLen(). Populated in init().
	newPadLenMax big.Int
	// Labels mixed into the initiator's hashes.
	req1 = []byte("req1")
	req2 = []byte("req2")
	req3 = []byte("req3")
	// Verification constant "VC" which is all zeroes in the bittorrent
	// implementation.
	vc [8]byte
	// Zero padding; sliced to the chosen pad length when building messages.
	zeroPad [512]byte
	// Tracks counts of received crypto_provides, keyed by the hex value of
	// the received field.
	cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
)
// init fills in the package-level big.Int values: the prime p, the generator
// g, and the exclusive upper bound handed to rand.Int by newPadLen.
func init() {
	// Base 0 lets SetString pick the base from the "0x" prefix.
	const primeHex = "0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563"
	p.SetString(primeHex, 0)
	g.SetInt64(2)
	// rand.Int draws from [0, max), so add one to make maxPadLen reachable.
	newPadLenMax.SetInt64(maxPadLen + 1)
}
// hash returns the SHA-1 digest of all parts concatenated in order. It panics
// if the hasher reports an error or a short write.
func hash(parts ...[]byte) []byte {
	digest := sha1.New()
	for i := range parts {
		written, err := digest.Write(parts[i])
		switch {
		case err != nil:
			panic(err)
		case written != len(parts[i]):
			panic(written)
		}
	}
	return digest.Sum(nil)
}
// newEncrypt builds an RC4 cipher keyed with hash(label, s, skey), where the
// label is "keyA" when initer is true and "keyB" otherwise. The first 1024
// bytes of keystream are consumed before the cipher is returned. It panics if
// the cipher cannot be constructed.
func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
	label := "keyB"
	if initer {
		label = "keyA"
	}
	c, err := rc4.NewCipher(hash([]byte(label), s, skey))
	if err != nil {
		panic(err)
	}
	// Advance the keystream by 1024 bytes; the output is thrown away.
	var discard [1024]byte
	c.XORKeyStream(discard[:], discard[:])
	return
}
// cipherReader is an io.Reader that passes everything read from r through the
// RC4 keystream of c.
type cipherReader struct {
	c *rc4.Cipher // keystream applied to incoming bytes
	r io.Reader   // underlying source
}
// Read reads from the underlying reader into b and then XORs the n bytes that
// were read with the RC4 keystream, in place. It returns the count and error
// reported by the underlying reader.
//
// Decrypting in place avoids allocating a scratch buffer on every call;
// rc4's XORKeyStream permits dst and src to overlap entirely.
func (me *cipherReader) Read(b []byte) (n int, err error) {
	n, err = me.r.Read(b)
	// Only the bytes actually read consume keystream, so the cipher stays in
	// step with the stream position even on short reads or errors.
	me.c.XORKeyStream(b[:n], b[:n])
	return
}
// newCipherReader wraps r so that everything read from it is XORed with the
// keystream of c.
func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
	wrapped := &cipherReader{c: c, r: r}
	return wrapped
}
// cipherWriter is an io.Writer that XORs data with the RC4 keystream of c
// before passing it on to w.
type cipherWriter struct {
	c *rc4.Cipher // keystream applied to outgoing bytes
	w io.Writer   // underlying sink
}
// Write XORs b with the keystream into a fresh buffer (b itself is left
// untouched) and writes the result to the underlying writer, returning that
// writer's count and error.
func (cw *cipherWriter) Write(b []byte) (n int, err error) {
	enc := make([]byte, len(b))
	cw.c.XORKeyStream(enc, b)
	if n, err = cw.w.Write(enc); n != len(enc) {
		// The keystream has already advanced past what was actually
		// written, so the cipher no longer matches the stream position.
		// Drop it so it can't be used again.
		cw.c = nil
	}
	return
}
// readY reads exactly 96 bytes from r and interprets them as a big-endian
// unsigned integer. On a read error y is left at zero and the error is
// returned.
func readY(r io.Reader) (y big.Int, err error) {
	var raw [96]byte
	if _, err = io.ReadFull(r, raw[:]); err == nil {
		y.SetBytes(raw[:])
	}
	return
}
// newX returns a non-negative integer built from 20 bytes (160 bits) read
// from crypto/rand. It panics if the random source fails.
func newX() big.Int {
	var raw [20]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(err)
	}
	var X big.Int
	X.SetBytes(raw[:])
	return X
}
// paddedLeft returns b left-padded with zero bytes to exactly _len bytes. When
// b already has that length it is returned as-is (not copied). It panics if b
// is longer than _len.
func paddedLeft(b []byte, _len int) []byte {
	gap := _len - len(b)
	if gap == 0 {
		return b
	}
	out := make([]byte, _len)
	copied := copy(out[gap:], b)
	if copied != len(b) {
		panic(copied)
	}
	return out
}
// Calculate, and send Y, our public key.
func (h *handshake) postY(x *big.Int) error {
var y big.Int
y.Exp(&g, x, &p)
2015-03-18 15:14:57 +08:00
return h.postWrite(paddedLeft(y.Bytes(), 96))
}
func (h *handshake) establishS() (err error) {
x := newX()
h.postY(&x)
var b [96]byte
_, err = io.ReadFull(h.conn, b[:])
if err != nil {
return
}
2015-03-18 15:14:57 +08:00
var Y, S big.Int
Y.SetBytes(b[:])
2015-03-18 15:14:57 +08:00
S.Exp(&Y, &x, &p)
util.CopyExact(&h.s, paddedLeft(S.Bytes(), 96))
return
}
// newPadLen picks a random padding length in [0, maxPadLen] using
// crypto/rand. It panics if the random source fails or yields a value outside
// that range.
func newPadLen() int64 {
	n, err := rand.Int(rand.Reader, &newPadLenMax)
	if err != nil {
		panic(err)
	}
	padLen := n.Int64()
	if padLen >= 0 && padLen <= maxPadLen {
		return padLen
	}
	panic(padLen)
}
// handshake carries the state of a single handshake run on conn. Outgoing
// data is queued with postWrite and flushed by the writer goroutine; reads
// happen directly on conn.
type handshake struct {
	conn   io.ReadWriter
	s      [96]byte // Shared secret S, set by establishS.
	initer bool     // True when we initiated the handshake.
	skeys  [][]byte // Candidate secret keys, matched by the receiver.
	skey   []byte   // Secret key in use (given by the initiator, matched by the receiver).
	ia     []byte   // Initial payload. Only used by the initiator.

	// writeMu guards writes, writeErr and writeClose. writeCond is signalled
	// when a write is queued or the queue is closed.
	writeMu    sync.Mutex
	writes     [][]byte
	writeErr   error
	writeCond  sync.Cond
	writeClose bool

	// writerMu guards writerDone. writerCond is signalled when the writer
	// goroutine has exited.
	writerMu   sync.Mutex
	writerCond sync.Cond
	writerDone bool
}
// finishWriting marks the write queue as closed, wakes the writer goroutine,
// and then blocks until that goroutine reports it has exited.
func (h *handshake) finishWriting() {
	// Tell the writer no further writes are coming.
	h.writeMu.Lock()
	h.writeClose = true
	h.writeCond.Broadcast()
	h.writeMu.Unlock()

	// Wait for the writer to signal completion.
	h.writerMu.Lock()
	defer h.writerMu.Unlock()
	for !h.writerDone {
		h.writerCond.Wait()
	}
}
// writer drains the queue of pending writes onto conn, one buffer at a time.
// It exits when the queue is empty and has been closed, or when a write to
// conn fails, in which case the error is stored in writeErr. On exit it sets
// writerDone and wakes anyone waiting on writerCond.
func (h *handshake) writer() {
	defer func() {
		h.writerMu.Lock()
		defer h.writerMu.Unlock()
		h.writerDone = true
		h.writerCond.Broadcast()
	}()
	for {
		h.writeMu.Lock()
		// Sleep until there is something to send, or the queue is closed.
		for len(h.writes) == 0 {
			if h.writeClose {
				h.writeMu.Unlock()
				return
			}
			h.writeCond.Wait()
		}
		next := h.writes[0]
		h.writes = h.writes[1:]
		h.writeMu.Unlock()
		// The lock is not held during the actual write to conn.
		if _, err := h.conn.Write(next); err != nil {
			h.writeMu.Lock()
			h.writeErr = err
			h.writeMu.Unlock()
			return
		}
	}
}
// postWrite appends b to the queue of pending writes and wakes the writer. If
// an earlier write has already failed, nothing is queued and that error is
// returned instead.
func (h *handshake) postWrite(b []byte) error {
	h.writeMu.Lock()
	defer h.writeMu.Unlock()
	if prior := h.writeErr; prior != nil {
		return prior
	}
	h.writes = append(h.writes, b)
	h.writeCond.Signal()
	return nil
}
// xor returns a new slice holding the byte-wise XOR of dst and src. The result
// is as long as the shorter input; neither input is modified.
func xor(dst, src []byte) (ret []byte) {
	n := len(src)
	if len(dst) < n {
		n = len(dst)
	}
	ret = make([]byte, n)
	for i := range iter.N(n) {
		ret[i] = dst[i] ^ src[i]
	}
	return
}
// marshal writes each value in data to w in big-endian binary form, in order.
// It stops at, and returns, the first error.
func marshal(w io.Writer, data ...interface{}) (err error) {
	for _, v := range data {
		if err = binary.Write(w, binary.BigEndian, v); err != nil {
			return err
		}
	}
	return err
}
// unmarshal fills each value in data from r, decoding big-endian binary, in
// order. It stops at, and returns, the first error.
func unmarshal(r io.Reader, data ...interface{}) (err error) {
	for _, v := range data {
		if err = binary.Read(r, binary.BigEndian, v); err != nil {
			return err
		}
	}
	return err
}
// suffixMatchLen looks for b at the end of a. It returns the length of the
// longest prefix of b that is also a suffix of a, or 0 if there is none.
func suffixMatchLen(a, b []byte) int {
	// A match can't be longer than a itself.
	if len(a) < len(b) {
		b = b[:len(a)]
	}
	// Try the longest candidate first and shrink until one fits.
	for n := len(b); n > 0; n-- {
		if string(a[len(a)-n:]) == string(b[:n]) {
			return n
		}
	}
	return 0
}
// readUntil consumes bytes from r until the most recently read len(b) bytes
// equal b, then returns nil. Any read error is returned as-is.
func readUntil(r io.Reader, b []byte) error {
	window := make([]byte, len(b))
	// kept is how many bytes at the front of window are carried over from the
	// previous round as a partial match.
	for kept := 0; ; {
		if _, err := io.ReadFull(r, window[kept:]); err != nil {
			return err
		}
		kept = suffixMatchLen(window, b)
		if kept == len(b) {
			return nil
		}
		// Slide the partially matching tail to the front and refill the rest.
		if copy(window, window[len(window)-kept:]) != kept {
			panic("wat")
		}
	}
}
// readWriter glues an independent io.Reader and io.Writer together into a
// single io.ReadWriter.
type readWriter struct {
	io.Reader
	io.Writer
}
// newEncrypt builds an RC4 cipher for the given direction from this
// handshake's shared secret and secret key.
func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
	secret, key := h.s[:], h.skey
	return newEncrypt(initer, secret, key)
}
func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
2015-03-18 15:14:57 +08:00
h.postWrite(hash(req1, h.s[:]))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
2015-03-13 03:16:49 +08:00
buf := &bytes.Buffer{}
2015-03-18 15:14:57 +08:00
padLen := uint16(newPadLen())
err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
2015-03-13 03:16:49 +08:00
if err != nil {
return
}
2015-03-18 15:14:57 +08:00
e := h.newEncrypt(true)
2015-03-13 03:16:49 +08:00
be := make([]byte, buf.Len())
e.XORKeyStream(be, buf.Bytes())
h.postWrite(be)
2015-03-18 15:14:57 +08:00
bC := h.newEncrypt(false)
2015-03-13 03:16:49 +08:00
var eVC [8]byte
2015-03-18 15:14:57 +08:00
bC.XORKeyStream(eVC[:], vc[:])
// Read until the all zero VC. At this point we've only read the 96 byte
// public key, Y. There is potentially 512 byte padding, between us and
// the 8 byte verification constant.
err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
2015-03-13 03:16:49 +08:00
if err != nil {
2015-03-18 15:14:57 +08:00
if err == io.EOF {
err = errors.New("failed to synchronize on VC")
} else {
err = fmt.Errorf("error reading until VC: %s", err)
}
2015-03-13 03:16:49 +08:00
return
}
r := &cipherReader{bC, h.conn}
2015-03-18 15:14:57 +08:00
var method uint32
err = unmarshal(r, &method, &padLen)
if err != nil {
return
}
if method != cryptoMethodRC4 {
err = fmt.Errorf("receiver chose unsupported method: %x", method)
return
}
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
2015-03-13 03:16:49 +08:00
if err != nil {
return
}
ret = readWriter{r, &cipherWriter{e, h.conn}}
return
}
// ErrNoSecretKeyMatch is returned by the receiving side of the handshake when
// none of the candidate secret keys matches the hash sent by the initiator.
var ErrNoSecretKeyMatch = errors.New("no skey matched")
func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
2015-03-18 15:14:57 +08:00
// There is up to 512 bytes of padding, then the 20 byte hash.
err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
2015-03-13 03:16:49 +08:00
if err != nil {
2015-03-18 15:14:57 +08:00
if err == io.EOF {
err = errors.New("failed to synchronize on S hash")
}
2015-03-13 03:16:49 +08:00
return
}
var b [20]byte
_, err = io.ReadFull(h.conn, b[:])
if err != nil {
return
}
2015-03-18 15:14:57 +08:00
err = ErrNoSecretKeyMatch
2015-03-13 03:16:49 +08:00
for _, skey := range h.skeys {
2015-03-18 15:14:57 +08:00
if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
2015-03-13 03:16:49 +08:00
h.skey = skey
err = nil
break
}
}
if err != nil {
return
}
2015-03-18 15:14:57 +08:00
r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
var (
vc [8]byte
method uint32
padLen uint16
)
err = unmarshal(r, vc[:], &method, &padLen)
2015-03-13 03:16:49 +08:00
if err != nil {
return
}
2015-03-18 15:14:57 +08:00
cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
if method&cryptoMethodRC4 == 0 {
2015-03-13 03:16:49 +08:00
err = errors.New("no supported crypto methods were provided")
return
}
2015-03-18 15:14:57 +08:00
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
if err != nil {
return
}
var lenIA uint16
unmarshal(r, &lenIA)
if lenIA != 0 {
h.ia = make([]byte, lenIA)
unmarshal(r, h.ia)
}
2015-03-13 03:16:49 +08:00
buf := &bytes.Buffer{}
2015-03-18 15:14:57 +08:00
w := cipherWriter{h.newEncrypt(false), buf}
padLen = uint16(newPadLen())
err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
2015-03-13 03:16:49 +08:00
if err != nil {
return
}
err = h.postWrite(buf.Bytes())
if err != nil {
return
}
ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
2015-03-13 03:16:49 +08:00
return
}
func (h *handshake) Do() (ret io.ReadWriter, err error) {
2015-03-18 15:14:57 +08:00
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu
go h.writer()
defer func() {
h.finishWriting()
if err == nil {
err = h.writeErr
}
}()
err = h.establishS()
if err != nil {
2015-03-13 03:16:49 +08:00
err = fmt.Errorf("error while establishing secret: %s", err)
return
}
pad := make([]byte, newPadLen())
io.ReadFull(rand.Reader, pad)
err = h.postWrite(pad)
if err != nil {
return
}
if h.initer {
2015-03-13 03:16:49 +08:00
ret, err = h.initerSteps()
} else {
2015-03-13 03:16:49 +08:00
ret, err = h.receiverSteps()
}
return
}
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
2015-03-13 03:16:49 +08:00
initer: true,
skey: skey,
ia: initialPayload,
}
return h.Do()
}
func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
2015-03-13 03:16:49 +08:00
h := handshake{
conn: rw,
initer: false,
skeys: skeys,
}
return h.Do()
}