From b9c7d6266b5b4612696f0f1aa3b18ee1f55343aa Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Mon, 21 Jun 2021 13:29:26 +1000 Subject: [PATCH] Tidy up the Dialer interface --- dialer.go | 41 +++++++++++++++-------------------------- socket.go | 5 +++-- test/unix_test.go | 2 +- 3 files changed, 19 insertions(+), 29 deletions(-) diff --git a/dialer.go b/dialer.go index e8126bd6..d499af30 100644 --- a/dialer.go +++ b/dialer.go @@ -3,43 +3,32 @@ package torrent import ( "context" "net" - - "github.com/anacrolix/missinggo/perf" ) +// Dialers have the network locked in. type Dialer interface { Dial(_ context.Context, addr string) (net.Conn, error) DialerNetwork() string } -type NetDialer struct { - Network string - Dialer net.Dialer +// An interface to ease wrapping dialers that explicitly include a network parameter. +type DialContexter interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) } -func (me NetDialer) DialerNetwork() string { +// Used by wrappers of standard library network types. +var DefaultNetDialer = &net.Dialer{} + +// Adapts a DialContexter to the Dial interface in this package. +type NetworkDialer struct { + Network string + Dialer DialContexter +} + +func (me NetworkDialer) DialerNetwork() string { return me.Network } -func (me NetDialer) Dial(ctx context.Context, addr string) (_ net.Conn, err error) { - defer perf.ScopeTimerErr(&err)() +func (me NetworkDialer) Dial(ctx context.Context, addr string) (_ net.Conn, err error) { return me.Dialer.DialContext(ctx, me.Network, addr) } - -func (me NetDialer) LocalAddr() net.Addr { - return netDialerLocalAddr{me.Network, me.Dialer.LocalAddr} -} - -type netDialerLocalAddr struct { - network string - addr net.Addr -} - -func (me netDialerLocalAddr) Network() string { return me.network } - -func (me netDialerLocalAddr) String() string { - if me.addr == nil { - return "" - } - return me.addr.String() -} diff --git a/socket.go b/socket.go index ba2a091b..7313f632 100644 --- a/socket.go +++ b/socket.go @@ -39,15 +39,16 @@ func listenTcp(network, address string) (s socket, err error) { l, err := net.Listen(network, address) return tcpSocket{ Listener: l, - NetDialer: NetDialer{ + NetworkDialer: NetworkDialer{ Network: network, + Dialer: DefaultNetDialer, }, }, err } type tcpSocket struct { net.Listener - NetDialer + NetworkDialer } func listenAll(networks []network, getHost func(string) string, port int, f firewallCallback) ([]socket, error) { diff --git a/test/unix_test.go b/test/unix_test.go index 1e877c0d..d8a3ff9f 100644 --- a/test/unix_test.go +++ b/test/unix_test.go @@ -24,7 +24,7 @@ func TestUnixConns(t *testing.T) { cfg.Debug = true }, Client: func(cl *torrent.Client) { - cl.AddDialer(torrent.NetDialer{Network: "unix"}) + cl.AddDialer(torrent.NetworkDialer{Network: "unix", Dialer: torrent.DefaultNetDialer}) l, err := net.Listen("unix", filepath.Join(t.TempDir(), "socket")) if err != nil { panic(err)