iplist: Fail invalid IPs, they were always passing

This commit is contained in:
Matt Joiner 2015-03-28 02:54:17 +11:00
parent 5ecde3a874
commit 4084cad34b
2 changed files with 31 additions and 6 deletions

View File

@ -45,15 +45,22 @@ func (me *IPList) Lookup(ip net.IP) (r *Range) {
// TODO: Perhaps all addresses should be converted to IPv6, if the future // TODO: Perhaps all addresses should be converted to IPv6, if the future
// of IP is to always be backwards compatible. But this will cost 4x the // of IP is to always be backwards compatible. But this will cost 4x the
// memory for IPv4 addresses? // memory for IPv4 addresses?
if v4 := ip.To4(); v4 != nil { v4 := ip.To4()
if v4 != nil {
r = me.lookup(v4) r = me.lookup(v4)
if r != nil { if r != nil {
return return
} }
} }
if v6 := ip.To16(); v6 != nil { v6 := ip.To16()
if v6 != nil {
return me.lookup(v6) return me.lookup(v6)
} }
if v4 == nil && v6 == nil {
return &Range{
Description: fmt.Sprintf("unsupported IP: %s", ip),
}
}
return nil return nil
} }

View File

@ -73,6 +73,22 @@ func connRemoteAddrIP(network, laddr string, dialHost string) net.IP {
return ret return ret
} }
func TestBadIP(t *testing.T) {
iplist := New(nil)
if iplist.Lookup(net.IP(make([]byte, 4))) != nil {
t.FailNow()
}
if iplist.Lookup(net.IP(make([]byte, 16))) != nil {
t.FailNow()
}
if iplist.Lookup(nil) == nil {
t.FailNow()
}
if iplist.Lookup(net.IP(make([]byte, 5))) == nil {
t.FailNow()
}
}
func TestSimple(t *testing.T) { func TestSimple(t *testing.T) {
ranges, err := sampleRanges(t) ranges, err := sampleRanges(t)
if err != nil { if err != nil {
@ -90,14 +106,16 @@ func TestSimple(t *testing.T) {
{"1.2.3.255", false, ""}, {"1.2.3.255", false, ""},
{"1.2.8.0", true, "b"}, {"1.2.8.0", true, "b"},
{"1.2.4.255", true, "a"}, {"1.2.4.255", true, "a"},
// Try to roll over to the next octet on the parse. // Try to roll over to the next octet on the parse. Note the final
{"1.2.7.256", false, ""}, // octet is overbounds. In the next case.
{"1.2.7.256", true, "unsupported IP: <nil>"},
{"1.2.8.254", true, "b"}, {"1.2.8.254", true, "b"},
} { } {
r := iplist.Lookup(net.ParseIP(_case.IP)) ip := net.ParseIP(_case.IP)
r := iplist.Lookup(ip)
if !_case.Hit { if !_case.Hit {
if r != nil { if r != nil {
t.Fatalf("got hit when none was expected") t.Fatalf("got hit when none was expected: %s", ip)
} }
continue continue
} }