diff --git a/dht/closest_nodes.go b/dht/closest_nodes.go new file mode 100644 index 00000000..2f98e6d6 --- /dev/null +++ b/dht/closest_nodes.go @@ -0,0 +1,49 @@ +package dht + +import ( + "container/heap" +) + +type nodeMaxHeap struct { + IDs []string + Target string +} + +func (me nodeMaxHeap) Len() int { return len(me.IDs) } + +func (me nodeMaxHeap) Less(i, j int) bool { + return idDistance(me.IDs[i], me.Target) > idDistance(me.IDs[j], me.Target) +} + +func (me *nodeMaxHeap) Pop() (ret interface{}) { + ret, me.IDs = me.IDs[len(me.IDs)-1], me.IDs[:len(me.IDs)-1] + return +} +func (me *nodeMaxHeap) Push(val interface{}) { + me.IDs = append(me.IDs, val.(string)) +} +func (me nodeMaxHeap) Swap(i, j int) { + me.IDs[i], me.IDs[j] = me.IDs[i], me.IDs[j] +} + +type closestNodesSelector struct { + closest nodeMaxHeap + k int +} + +func (me *closestNodesSelector) Push(id string) { + heap.Push(&me.closest, id) + if me.closest.Len() > me.k { + heap.Pop(&me.closest) + } +} + +func (me *closestNodesSelector) IDs() []string { + return me.closest.IDs +} + +func newKClosestNodesSelector(k int, targetID string) (ret closestNodesSelector) { + ret.k = k + ret.closest.Target = targetID + return +} diff --git a/dht/dht.go b/dht/dht.go index 5235a72c..717c45ad 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -422,3 +422,29 @@ func (s *Server) GoodNodes() (nis []NodeInfo) { func (s *Server) StopServing() { s.Socket.Close() } + +func idDistance(a, b string) (ret int) { + if len(a) != 20 { + panic(a) + } + if len(b) != 20 { + panic(b) + } + for i := 0; i < 20; i++ { + for j := uint(0); j < 8; j++ { + ret += int(a[i]>>j&1 ^ b[i]>>j&1) + } + } + return +} + +// func (s *Server) closestNodes(k int) (ret *closestNodes) { +// heap.Init(ret) +// for _, node := range s.nodes { +// heap.Push(ret, node) +// if ret.Len() > k { +// heap.Pop(ret) +// } +// } +// return +// } diff --git a/dht/dht_test.go b/dht/dht_test.go index 7a945f5f..050ced6a 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -1,6 +1,7 @@ package dht import ( + "math/rand" "net" "testing" ) @@ -26,3 +27,65 @@ func TestMarshalCompactNodeInfo(t *testing.T) { t.FailNow() } } + +func recoverPanicOrDie(t *testing.T, f func()) { + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic") + } + }() + f() +} + +const zeroID = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + +var testIDs = []string{ + zeroID, + "\x03" + zeroID[1:], + "\x03" + zeroID[1:18] + "\x55\xf0", + "\x55" + zeroID[1:17] + "\xff\x55\x0f", +} + +func TestBadIdStrings(t *testing.T) { + var a, b string + recoverPanicOrDie(t, func() { + idDistance(a, b) + }) + recoverPanicOrDie(t, func() { + idDistance(a, zeroID) + }) + recoverPanicOrDie(t, func() { + idDistance(zeroID, b) + }) + if idDistance(zeroID, zeroID) != 0 { + t.FailNow() + } + a = "\x03" + zeroID[1:] + b = zeroID + if idDistance(a, b) != 2 { + t.FailNow() + } + a = "\x03" + zeroID[1:18] + "\x55\xf0" + b = "\x55" + zeroID[1:17] + "\xff\x55\x0f" + if c := idDistance(a, b); c != 20 { + t.Fatal(c) + } +} + +func TestClosestNodes(t *testing.T) { + cn := newKClosestNodesSelector(2, zeroID) + for _, i := range rand.Perm(len(testIDs)) { + cn.Push(testIDs[i]) + } + if len(cn.IDs()) != 2 { + t.FailNow() + } + m := map[string]bool{} + for _, id := range cn.IDs() { + m[id] = true + } + if !m[zeroID] || !m[testIDs[1]] { + t.FailNow() + } +}