Tidy up the Dialer interface

This commit is contained in:
Matt Joiner 2021-06-21 13:29:26 +10:00
parent 88d144e65e
commit b9c7d6266b
3 changed files with 19 additions and 29 deletions

View File

@ -3,43 +3,32 @@ package torrent
import (
"context"
"net"
"github.com/anacrolix/missinggo/perf"
)
// Dialers have the network locked in.
type Dialer interface {
Dial(_ context.Context, addr string) (net.Conn, error)
DialerNetwork() string
}
type NetDialer struct {
Network string
Dialer net.Dialer
// An interface to ease wrapping dialers that explicitly include a network parameter.
type DialContexter interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
func (me NetDialer) DialerNetwork() string {
// Used by wrappers of standard library network types.
var DefaultNetDialer = &net.Dialer{}
// Adapts a DialContexter to the Dial interface in this package.
type NetworkDialer struct {
Network string
Dialer DialContexter
}
func (me NetworkDialer) DialerNetwork() string {
return me.Network
}
func (me NetDialer) Dial(ctx context.Context, addr string) (_ net.Conn, err error) {
defer perf.ScopeTimerErr(&err)()
func (me NetworkDialer) Dial(ctx context.Context, addr string) (_ net.Conn, err error) {
return me.Dialer.DialContext(ctx, me.Network, addr)
}
func (me NetDialer) LocalAddr() net.Addr {
return netDialerLocalAddr{me.Network, me.Dialer.LocalAddr}
}
type netDialerLocalAddr struct {
network string
addr net.Addr
}
func (me netDialerLocalAddr) Network() string { return me.network }
func (me netDialerLocalAddr) String() string {
if me.addr == nil {
return ""
}
return me.addr.String()
}

View File

@ -39,15 +39,16 @@ func listenTcp(network, address string) (s socket, err error) {
l, err := net.Listen(network, address)
return tcpSocket{
Listener: l,
NetDialer: NetDialer{
NetworkDialer: NetworkDialer{
Network: network,
Dialer: DefaultNetDialer,
},
}, err
}
type tcpSocket struct {
net.Listener
NetDialer
NetworkDialer
}
func listenAll(networks []network, getHost func(string) string, port int, f firewallCallback) ([]socket, error) {

View File

@ -24,7 +24,7 @@ func TestUnixConns(t *testing.T) {
cfg.Debug = true
},
Client: func(cl *torrent.Client) {
cl.AddDialer(torrent.NetDialer{Network: "unix"})
cl.AddDialer(torrent.NetworkDialer{Network: "unix", Dialer: torrent.DefaultNetDialer})
l, err := net.Listen("unix", filepath.Join(t.TempDir(), "socket"))
if err != nil {
panic(err)