diff --git a/dht/closest_nodes.go b/dht/closest_nodes.go index a677a2e6..70366dd5 100644 --- a/dht/closest_nodes.go +++ b/dht/closest_nodes.go @@ -12,7 +12,7 @@ type nodeMaxHeap struct { func (me nodeMaxHeap) Len() int { return len(me.IDs) } func (me nodeMaxHeap) Less(i, j int) bool { - return idDistance(me.IDs[i], me.Target) > idDistance(me.IDs[j], me.Target) + return idDistance(me.IDs[i], me.Target).Cmp(idDistance(me.IDs[j], me.Target)) > 0 } func (me *nodeMaxHeap) Pop() (ret interface{}) { diff --git a/dht/dht.go b/dht/dht.go index 6e5d29cb..094a6bc6 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "math/big" "net" "os" "sync" @@ -901,19 +902,116 @@ func (s *Server) Close() { s.mu.Unlock() } -func idDistance(a, b string) (ret int) { - if len(a) != 20 { - panic(a) +type distance interface { + Cmp(distance) int + BitCount() int + IsZero() bool +} + +type bigIntDistance struct { + *big.Int +} + +// How many bits? +func bitCount(n *big.Int) int { + var count int = 0 + for _, b := range n.Bytes() { + count += int(bitCounts[b]) } - if len(b) != 20 { - panic(b) + return count +} + +// The bit counts for each byte value (0 - 255). +var bitCounts = []int8{ + // Generated by Java BitCount of all values from 0 to 255 + 0, 1, 1, 2, 1, 2, 2, 3, + 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, + 5, 6, 6, 7, 6, 7, 7, 8, +} + +func (me bigIntDistance) BitCount() int { + return bitCount(me.Int) +} + +func (me bigIntDistance) Cmp(d distance) int { + return me.Int.Cmp(d.(bigIntDistance).Int) +} + +func (me bigIntDistance) IsZero() bool { + return me.Int.Cmp(big.NewInt(0)) == 0 +} + +type bitCountDistance int + +func (me bitCountDistance) BitCount() int { return int(me) } + +func (me bitCountDistance) Cmp(rhs distance) int { + rhs_ := rhs.(bitCountDistance) + if me < rhs_ { + return -1 + } else if me == rhs_ { + return 0 + } else { + return 1 } - for i := 0; i < 20; i++ { - for j := uint(0); j < 8; j++ { - ret += int(a[i]>>j&1 ^ b[i]>>j&1) +} + +func (me bitCountDistance) IsZero() bool { + return me == 0 +} + +func idDistance(a, b string) distance { + if true { + if len(a) != 20 { + panic(a) } + if len(b) != 20 { + panic(b) + } + x := new(big.Int) + y := new(big.Int) + x.SetBytes([]byte(a)) + y.SetBytes([]byte(b)) + dist := new(big.Int) + return bigIntDistance{dist.Xor(x, y)} + } else { + ret := 0 + for i := 0; i < 20; i++ { + for j := uint(0); j < 8; j++ { + ret += int(a[i]>>j&1 ^ b[i]>>j&1) + } + } + return bitCountDistance(ret) } - return } func (s *Server) closestGoodNodes(k int, targetID string) []*Node { diff --git a/dht/dht_test.go b/dht/dht_test.go index 71afc5eb..5b5b5d1d 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -1,11 +1,17 @@ package dht import ( + "math/big" "math/rand" "net" "testing" ) +func TestSetNilBigInt(t *testing.T) { + i := new(big.Int) + i.SetBytes(make([]byte, 2)) +} + func TestMarshalCompactNodeInfo(t *testing.T) { cni := NodeInfo{ ID: [20]byte{'a', 'b', 'c'}, @@ -49,13 +55,13 @@ var testIDs = []string{ } func TestDistances(t *testing.T) { - if idDistance(testIDs[3], testIDs[0]) != 4+8+4+4 { + if idDistance(testIDs[3], testIDs[0]).BitCount() != 4+8+4+4 { t.FailNow() } - if idDistance(testIDs[3], testIDs[1]) != 4+8+4+4 { + if idDistance(testIDs[3], testIDs[1]).BitCount() != 4+8+4+4 { t.FailNow() } - if idDistance(testIDs[3], testIDs[2]) != 4+8+8 { + if idDistance(testIDs[3], testIDs[2]).BitCount() != 4+8+8 { t.FailNow() } } @@ -71,17 +77,17 @@ func TestBadIdStrings(t *testing.T) { recoverPanicOrDie(t, func() { idDistance(zeroID, b) }) - if idDistance(zeroID, zeroID) != 0 { - t.FailNow() + if !idDistance(zeroID, zeroID).IsZero() { + t.Fatal("identical IDs should have distance 0") } a = "\x03" + zeroID[1:] b = zeroID - if idDistance(a, b) != 2 { + if idDistance(a, b).BitCount() != 2 { t.FailNow() } a = "\x03" + zeroID[1:18] + "\x55\xf0" b = "\x55" + zeroID[1:17] + "\xff\x55\x0f" - if c := idDistance(a, b); c != 20 { + if c := idDistance(a, b).BitCount(); c != 20 { t.Fatal(c) } }