Expose DHT ID distances as an interface and switch to big.Int and possibly the correct algorithm

This commit is contained in:
Matt Joiner 2014-11-17 01:47:24 -06:00
parent c1049d0605
commit 02160eb8bc
3 changed files with 121 additions and 17 deletions

View File

@ -12,7 +12,7 @@ type nodeMaxHeap struct {
func (me nodeMaxHeap) Len() int { return len(me.IDs) }
func (me nodeMaxHeap) Less(i, j int) bool {
return idDistance(me.IDs[i], me.Target) > idDistance(me.IDs[j], me.Target)
return idDistance(me.IDs[i], me.Target).Cmp(idDistance(me.IDs[j], me.Target)) > 0
}
func (me *nodeMaxHeap) Pop() (ret interface{}) {

View File

@ -8,6 +8,7 @@ import (
"fmt"
"io"
"log"
"math/big"
"net"
"os"
"sync"
@ -901,19 +902,116 @@ func (s *Server) Close() {
s.mu.Unlock()
}
func idDistance(a, b string) (ret int) {
if len(a) != 20 {
panic(a)
type distance interface {
Cmp(distance) int
BitCount() int
IsZero() bool
}
type bigIntDistance struct {
*big.Int
}
// How many bits?
func bitCount(n *big.Int) int {
var count int = 0
for _, b := range n.Bytes() {
count += int(bitCounts[b])
}
if len(b) != 20 {
panic(b)
return count
}
// The bit counts for each byte value (0 - 255).
var bitCounts = []int8{
// Generated by Java BitCount of all values from 0 to 255
0, 1, 1, 2, 1, 2, 2, 3,
1, 2, 2, 3, 2, 3, 3, 4,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7,
5, 6, 6, 7, 6, 7, 7, 8,
}
func (me bigIntDistance) BitCount() int {
return bitCount(me.Int)
}
func (me bigIntDistance) Cmp(d distance) int {
return me.Int.Cmp(d.(bigIntDistance).Int)
}
func (me bigIntDistance) IsZero() bool {
return me.Int.Cmp(big.NewInt(0)) == 0
}
type bitCountDistance int
func (me bitCountDistance) BitCount() int { return int(me) }
func (me bitCountDistance) Cmp(rhs distance) int {
rhs_ := rhs.(bitCountDistance)
if me < rhs_ {
return -1
} else if me == rhs_ {
return 0
} else {
return 1
}
for i := 0; i < 20; i++ {
for j := uint(0); j < 8; j++ {
ret += int(a[i]>>j&1 ^ b[i]>>j&1)
}
func (me bitCountDistance) IsZero() bool {
return me == 0
}
func idDistance(a, b string) distance {
if true {
if len(a) != 20 {
panic(a)
}
if len(b) != 20 {
panic(b)
}
x := new(big.Int)
y := new(big.Int)
x.SetBytes([]byte(a))
y.SetBytes([]byte(b))
dist := new(big.Int)
return bigIntDistance{dist.Xor(x, y)}
} else {
ret := 0
for i := 0; i < 20; i++ {
for j := uint(0); j < 8; j++ {
ret += int(a[i]>>j&1 ^ b[i]>>j&1)
}
}
return bitCountDistance(ret)
}
return
}
func (s *Server) closestGoodNodes(k int, targetID string) []*Node {

View File

@ -1,11 +1,17 @@
package dht
import (
"math/big"
"math/rand"
"net"
"testing"
)
func TestSetNilBigInt(t *testing.T) {
i := new(big.Int)
i.SetBytes(make([]byte, 2))
}
func TestMarshalCompactNodeInfo(t *testing.T) {
cni := NodeInfo{
ID: [20]byte{'a', 'b', 'c'},
@ -49,13 +55,13 @@ var testIDs = []string{
}
func TestDistances(t *testing.T) {
if idDistance(testIDs[3], testIDs[0]) != 4+8+4+4 {
if idDistance(testIDs[3], testIDs[0]).BitCount() != 4+8+4+4 {
t.FailNow()
}
if idDistance(testIDs[3], testIDs[1]) != 4+8+4+4 {
if idDistance(testIDs[3], testIDs[1]).BitCount() != 4+8+4+4 {
t.FailNow()
}
if idDistance(testIDs[3], testIDs[2]) != 4+8+8 {
if idDistance(testIDs[3], testIDs[2]).BitCount() != 4+8+8 {
t.FailNow()
}
}
@ -71,17 +77,17 @@ func TestBadIdStrings(t *testing.T) {
recoverPanicOrDie(t, func() {
idDistance(zeroID, b)
})
if idDistance(zeroID, zeroID) != 0 {
t.FailNow()
if !idDistance(zeroID, zeroID).IsZero() {
t.Fatal("identical IDs should have distance 0")
}
a = "\x03" + zeroID[1:]
b = zeroID
if idDistance(a, b) != 2 {
if idDistance(a, b).BitCount() != 2 {
t.FailNow()
}
a = "\x03" + zeroID[1:18] + "\x55\xf0"
b = "\x55" + zeroID[1:17] + "\xff\x55\x0f"
if c := idDistance(a, b); c != 20 {
if c := idDistance(a, b).BitCount(); c != 20 {
t.Fatal(c)
}
}