From 333c878d2bb74098614b43f110e349665376eea5 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Tue, 22 Jun 2021 22:36:43 +1000 Subject: [PATCH] Rewrite UDP tracker client --- tracker/server.go | 35 ++--- tracker/tracker.go | 24 +--- tracker/udp.go | 274 ++++-------------------------------- tracker/udp/announce.go | 35 +++++ tracker/udp/client.go | 132 +++++++++++++++++ tracker/udp/dispatcher.go | 64 +++++++++ tracker/udp/options.go | 24 ++++ tracker/udp/protocol.go | 69 +++++++++ tracker/udp/timeout.go | 18 +++ tracker/udp/timeout_test.go | 15 ++ tracker/udp/transaction.go | 23 +++ tracker/udp_test.go | 50 ++++--- 12 files changed, 462 insertions(+), 301 deletions(-) create mode 100644 tracker/udp/announce.go create mode 100644 tracker/udp/client.go create mode 100644 tracker/udp/dispatcher.go create mode 100644 tracker/udp/options.go create mode 100644 tracker/udp/protocol.go create mode 100644 tracker/udp/timeout.go create mode 100644 tracker/udp/timeout_test.go create mode 100644 tracker/udp/transaction.go diff --git a/tracker/server.go b/tracker/server.go index 34417be6..59c64f1a 100644 --- a/tracker/server.go +++ b/tracker/server.go @@ -10,6 +10,7 @@ import ( "github.com/anacrolix/dht/v2/krpc" "github.com/anacrolix/missinggo" + "github.com/anacrolix/torrent/tracker/udp" ) type torrent struct { @@ -36,7 +37,7 @@ func marshal(parts ...interface{}) (ret []byte, err error) { return } -func (s *server) respond(addr net.Addr, rh ResponseHeader, parts ...interface{}) (err error) { +func (s *server) respond(addr net.Addr, rh udp.ResponseHeader, parts ...interface{}) (err error) { b, err := marshal(append([]interface{}{rh}, parts...)...) if err != nil { return @@ -61,34 +62,34 @@ func (s *server) serveOne() (err error) { return } r := bytes.NewReader(b[:n]) - var h RequestHeader - err = readBody(r, &h) + var h udp.RequestHeader + err = udp.Read(r, &h) if err != nil { return } switch h.Action { - case ActionConnect: - if h.ConnectionId != connectRequestConnectionId { + case udp.ActionConnect: + if h.ConnectionId != udp.ConnectRequestConnectionId { return } connId := s.newConn() - err = s.respond(addr, ResponseHeader{ - ActionConnect, + err = s.respond(addr, udp.ResponseHeader{ + udp.ActionConnect, h.TransactionId, - }, ConnectionResponse{ + }, udp.ConnectionResponse{ connId, }) return - case ActionAnnounce: + case udp.ActionAnnounce: if _, ok := s.conns[h.ConnectionId]; !ok { - s.respond(addr, ResponseHeader{ + s.respond(addr, udp.ResponseHeader{ TransactionId: h.TransactionId, - Action: ActionError, + Action: udp.ActionError, }, []byte("not connected")) return } var ar AnnounceRequest - err = readBody(r, &ar) + err = udp.Read(r, &ar) if err != nil { return } @@ -104,10 +105,10 @@ func (s *server) serveOne() (err error) { if err != nil { panic(err) } - err = s.respond(addr, ResponseHeader{ + err = s.respond(addr, udp.ResponseHeader{ TransactionId: h.TransactionId, - Action: ActionAnnounce, - }, AnnounceResponseHeader{ + Action: udp.ActionAnnounce, + }, udp.AnnounceResponseHeader{ Interval: 900, Leechers: t.Leechers, Seeders: t.Seeders, @@ -115,9 +116,9 @@ func (s *server) serveOne() (err error) { return default: err = fmt.Errorf("unhandled action: %d", h.Action) - s.respond(addr, ResponseHeader{ + s.respond(addr, udp.ResponseHeader{ TransactionId: h.TransactionId, - Action: ActionError, + Action: udp.ActionError, }, []byte("unhandled action")) return } diff --git a/tracker/tracker.go b/tracker/tracker.go index 1b6d1412..0a187574 100644 --- a/tracker/tracker.go +++ b/tracker/tracker.go @@ -8,23 +8,10 @@ import ( "time" "github.com/anacrolix/dht/v2/krpc" + "github.com/anacrolix/torrent/tracker/udp" ) -// Marshalled as binary by the UDP client, so be careful making changes. -type AnnounceRequest struct { - InfoHash [20]byte - PeerId [20]byte - Downloaded int64 - Left int64 // If less than 0, math.MaxInt64 will be used for HTTP trackers instead. - Uploaded int64 - // Apparently this is optional. None can be used for announces done at - // regular intervals. - Event AnnounceEvent - IPAddress uint32 - Key int32 - NumWant int32 // How many peer addresses are desired. -1 for default. - Port uint16 -} // 82 bytes +type AnnounceRequest = udp.AnnounceRequest type AnnounceResponse struct { Interval int32 // Minimum seconds the local peer should wait before next announce. @@ -33,12 +20,7 @@ type AnnounceResponse struct { Peers []Peer } -type AnnounceEvent int32 - -func (e AnnounceEvent) String() string { - // See BEP 3, "event", and https://github.com/anacrolix/torrent/issues/416#issuecomment-751427001. - return []string{"", "completed", "started", "stopped"}[e] -} +type AnnounceEvent = udp.AnnounceEvent const ( None AnnounceEvent = iota diff --git a/tracker/udp.go b/tracker/udp.go index 033598e5..4fb00b11 100644 --- a/tracker/udp.go +++ b/tracker/udp.go @@ -1,270 +1,83 @@ package tracker import ( - "bytes" - "context" "encoding" "encoding/binary" - "fmt" - "io" - "math/rand" "net" "net/url" - "time" "github.com/anacrolix/dht/v2/krpc" "github.com/anacrolix/missinggo" - "github.com/anacrolix/missinggo/pproffd" - "github.com/pkg/errors" + "github.com/anacrolix/torrent/tracker/udp" ) -type Action int32 - -const ( - ActionConnect Action = iota - ActionAnnounce - ActionScrape - ActionError - - connectRequestConnectionId = 0x41727101980 - - // BEP 41 - optionTypeEndOfOptions = 0 - optionTypeNOP = 1 - optionTypeURLData = 2 -) - -type ConnectionRequest struct { - ConnectionId int64 - Action int32 - TransctionId int32 -} - -type ConnectionResponse struct { - ConnectionId int64 -} - -type ResponseHeader struct { - Action Action - TransactionId int32 -} - -type RequestHeader struct { - ConnectionId int64 - Action Action - TransactionId int32 -} // 16 bytes - -type AnnounceResponseHeader struct { - Interval int32 - Leechers int32 - Seeders int32 -} - -func newTransactionId() int32 { - return int32(rand.Uint32()) -} - -func timeout(contiguousTimeouts int) (d time.Duration) { - if contiguousTimeouts > 8 { - contiguousTimeouts = 8 - } - d = 15 * time.Second - for ; contiguousTimeouts > 0; contiguousTimeouts-- { - d *= 2 - } - return -} - type udpAnnounce struct { - contiguousTimeouts int - connectionIdReceived time.Time - connectionId int64 - socket net.Conn - url url.URL - a *Announce + url url.URL + a *Announce } func (c *udpAnnounce) Close() error { - if c.socket != nil { - return c.socket.Close() - } return nil } -func (c *udpAnnounce) ipv6() bool { +func (c *udpAnnounce) ipv6(conn net.Conn) bool { if c.a.UdpNetwork == "udp6" { return true } - rip := missinggo.AddrIP(c.socket.RemoteAddr()) + rip := missinggo.AddrIP(conn.RemoteAddr()) return rip.To16() != nil && rip.To4() == nil } func (c *udpAnnounce) Do(req AnnounceRequest) (res AnnounceResponse, err error) { - err = c.connect() + conn, err := net.Dial(c.dialNetwork(), c.url.Host) if err != nil { return } - reqURI := c.url.RequestURI() - if c.ipv6() { + defer conn.Close() + if c.ipv6(conn) { // BEP 15 req.IPAddress = 0 } else if req.IPAddress == 0 && c.a.ClientIp4.IP != nil { req.IPAddress = binary.BigEndian.Uint32(c.a.ClientIp4.IP.To4()) } - // Clearly this limits the request URI to 255 bytes. BEP 41 supports - // longer but I'm not fussed. - options := append([]byte{optionTypeURLData, byte(len(reqURI))}, []byte(reqURI)...) - vars.Add("udp tracker announces", 1) - b, err := c.request(ActionAnnounce, req, options) - if err != nil { - return - } - var h AnnounceResponseHeader - err = readBody(b, &h) - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF + d := udp.Dispatcher{} + go func() { + for { + b := make([]byte, 0x800) + n, err := conn.Read(b) + if err != nil { + break + } + d.Dispatch(b[:n]) } - err = fmt.Errorf("error parsing announce response: %s", err) - return + }() + cl := udp.Client{ + Dispatcher: &d, + Writer: conn, } - res.Interval = h.Interval - res.Leechers = h.Leechers - res.Seeders = h.Seeders nas := func() interface { encoding.BinaryUnmarshaler NodeAddrs() []krpc.NodeAddr } { - if c.ipv6() { + if c.ipv6(conn) { return &krpc.CompactIPv6NodeAddrs{} } else { return &krpc.CompactIPv4NodeAddrs{} } }() - err = nas.UnmarshalBinary(b.Bytes()) + h, err := cl.Announce(c.a.Context, req, nas, udp.Options{RequestUri: c.url.RequestURI()}) if err != nil { return } + res.Interval = h.Interval + res.Leechers = h.Leechers + res.Seeders = h.Seeders for _, cp := range nas.NodeAddrs() { res.Peers = append(res.Peers, Peer{}.FromNodeAddr(cp)) } return } -// body is the binary serializable request body. trailer is optional data -// following it, such as for BEP 41. -func (c *udpAnnounce) write(h *RequestHeader, body interface{}, trailer []byte) (err error) { - var buf bytes.Buffer - err = binary.Write(&buf, binary.BigEndian, h) - if err != nil { - panic(err) - } - if body != nil { - err = binary.Write(&buf, binary.BigEndian, body) - if err != nil { - panic(err) - } - } - _, err = buf.Write(trailer) - if err != nil { - return - } - n, err := c.socket.Write(buf.Bytes()) - if err != nil { - return - } - if n != buf.Len() { - panic("write should send all or error") - } - return -} - -func read(r io.Reader, data interface{}) error { - return binary.Read(r, binary.BigEndian, data) -} - -func write(w io.Writer, data interface{}) error { - return binary.Write(w, binary.BigEndian, data) -} - -// args is the binary serializable request body. trailer is optional data -// following it, such as for BEP 41. -func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (*bytes.Buffer, error) { - tid := newTransactionId() - if err := errors.Wrap( - c.write( - &RequestHeader{ - ConnectionId: c.connectionId, - Action: action, - TransactionId: tid, - }, args, options), - "writing request", - ); err != nil { - return nil, err - } - c.socket.SetReadDeadline(time.Now().Add(timeout(c.contiguousTimeouts))) - b := make([]byte, 0x800) // 2KiB - for { - var ( - n int - readErr error - readDone = make(chan struct{}) - ) - go func() { - defer close(readDone) - n, readErr = c.socket.Read(b) - }() - ctx := c.a.Context - if ctx == nil { - ctx = context.Background() - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-readDone: - } - if opE, ok := readErr.(*net.OpError); ok && opE.Timeout() { - c.contiguousTimeouts++ - } - if readErr != nil { - return nil, errors.Wrap(readErr, "reading from socket") - } - buf := bytes.NewBuffer(b[:n]) - var h ResponseHeader - err := binary.Read(buf, binary.BigEndian, &h) - switch err { - default: - panic(err) - case io.ErrUnexpectedEOF, io.EOF: - continue - case nil: - } - if h.TransactionId != tid { - continue - } - c.contiguousTimeouts = 0 - if h.Action == ActionError { - err = errors.New(buf.String()) - } - return buf, err - } -} - -func readBody(r io.Reader, data ...interface{}) (err error) { - for _, datum := range data { - err = binary.Read(r, binary.BigEndian, datum) - if err != nil { - break - } - } - return -} - -func (c *udpAnnounce) connected() bool { - return !c.connectionIdReceived.IsZero() && time.Now().Before(c.connectionIdReceived.Add(time.Minute)) -} - func (c *udpAnnounce) dialNetwork() string { if c.a.UdpNetwork != "" { return c.a.UdpNetwork @@ -272,40 +85,7 @@ func (c *udpAnnounce) dialNetwork() string { return "udp" } -func (c *udpAnnounce) connect() (err error) { - if c.connected() { - return nil - } - c.connectionId = connectRequestConnectionId - if c.socket == nil { - hmp := missinggo.SplitHostMaybePort(c.url.Host) - if hmp.NoPort { - hmp.NoPort = false - hmp.Port = 80 - } - c.socket, err = net.Dial(c.dialNetwork(), hmp.String()) - if err != nil { - return - } - c.socket = pproffd.WrapNetConn(c.socket) - } - vars.Add("udp tracker connects", 1) - b, err := c.request(ActionConnect, nil, nil) - if err != nil { - return - } - var res ConnectionResponse - err = readBody(b, &res) - if err != nil { - return - } - c.connectionId = res.ConnectionId - c.connectionIdReceived = time.Now() - return -} - -// TODO: Split on IPv6, as BEP 15 says response peer decoding depends on -// network in use. +// TODO: Split on IPv6, as BEP 15 says response peer decoding depends on network in use. func announceUDP(opt Announce, _url *url.URL) (AnnounceResponse, error) { ua := udpAnnounce{ url: *_url, diff --git a/tracker/udp/announce.go b/tracker/udp/announce.go new file mode 100644 index 00000000..1573c275 --- /dev/null +++ b/tracker/udp/announce.go @@ -0,0 +1,35 @@ +package udp + +import ( + "encoding" + + "github.com/anacrolix/dht/v2/krpc" +) + +// Marshalled as binary by the UDP client, so be careful making changes. +type AnnounceRequest struct { + InfoHash [20]byte + PeerId [20]byte + Downloaded int64 + Left int64 // If less than 0, math.MaxInt64 will be used for HTTP trackers instead. + Uploaded int64 + // Apparently this is optional. None can be used for announces done at + // regular intervals. + Event AnnounceEvent + IPAddress uint32 + Key int32 + NumWant int32 // How many peer addresses are desired. -1 for default. + Port uint16 +} // 82 bytes + +type AnnounceEvent int32 + +func (e AnnounceEvent) String() string { + // See BEP 3, "event", and https://github.com/anacrolix/torrent/issues/416#issuecomment-751427001. + return []string{"", "completed", "started", "stopped"}[e] +} + +type AnnounceResponsePeers interface { + encoding.BinaryUnmarshaler + NodeAddrs() []krpc.NodeAddr +} diff --git a/tracker/udp/client.go b/tracker/udp/client.go new file mode 100644 index 00000000..54099ff6 --- /dev/null +++ b/tracker/udp/client.go @@ -0,0 +1,132 @@ +package udp + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "time" +) + +type Client struct { + connId ConnectionId + connIdIssued time.Time + Dispatcher *Dispatcher + Writer io.Writer +} + +func (cl *Client) Announce( + ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options, +) ( + respHdr AnnounceResponseHeader, err error, +) { + body, err := marshal(req) + if err != nil { + return + } + respBody, err := cl.request(ctx, ActionAnnounce, append(body, opts.Encode()...)) + if err != nil { + return + } + r := bytes.NewBuffer(respBody) + err = Read(r, &respHdr) + if err != nil { + err = fmt.Errorf("reading response header: %w", err) + return + } + err = peers.UnmarshalBinary(r.Bytes()) + if err != nil { + err = fmt.Errorf("reading response peers: %w", err) + } + return +} + +func (cl *Client) connect(ctx context.Context) (err error) { + if time.Since(cl.connIdIssued) < time.Minute { + return nil + } + respBody, err := cl.request(ctx, ActionConnect, nil) + if err != nil { + return err + } + var connResp ConnectionResponse + err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp) + if err != nil { + return + } + cl.connId = connResp.ConnectionId + cl.connIdIssued = time.Now() + return +} + +func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) { + if action == ActionConnect { + id = ConnectRequestConnectionId + return + } + err = cl.connect(ctx) + if err != nil { + return + } + id = cl.connId + return +} + +func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) { + var buf bytes.Buffer + for n := 0; ; n++ { + var connId ConnectionId + connId, err = cl.connIdForRequest(ctx, action) + if err != nil { + return + } + buf.Reset() + err = binary.Write(&buf, binary.BigEndian, RequestHeader{ + ConnectionId: connId, + Action: action, + TransactionId: tId, + }) + if err != nil { + panic(err) + } + buf.Write(body) + _, err = cl.Writer.Write(buf.Bytes()) + if err != nil { + return + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(timeout(n)): + } + } +} + +func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) { + respChan := make(chan DispatchedResponse, 1) + t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) { + respChan <- dr + }) + defer t.End() + writeErr := make(chan error, 1) + go func() { + writeErr <- cl.requestWriter(ctx, action, body, t.Id()) + }() + select { + case dr := <-respChan: + if dr.Header.Action == action { + respBody = dr.Body + } else if dr.Header.Action == ActionError { + err = errors.New(string(dr.Body)) + } else { + err = fmt.Errorf("unexpected response action %v", dr.Header.Action) + } + case err = <-writeErr: + err = fmt.Errorf("write error: %w", err) + case <-ctx.Done(): + err = ctx.Err() + } + return +} diff --git a/tracker/udp/dispatcher.go b/tracker/udp/dispatcher.go new file mode 100644 index 00000000..907eb15b --- /dev/null +++ b/tracker/udp/dispatcher.go @@ -0,0 +1,64 @@ +package udp + +import ( + "bytes" + "fmt" + "sync" +) + +type Dispatcher struct { + mu sync.RWMutex + transactions map[TransactionId]Transaction +} + +func (me *Dispatcher) Dispatch(b []byte) error { + buf := bytes.NewBuffer(b) + var rh ResponseHeader + err := Read(buf, &rh) + if err != nil { + return err + } + me.mu.RLock() + defer me.mu.RUnlock() + if t, ok := me.transactions[rh.TransactionId]; ok { + t.h(DispatchedResponse{ + Header: rh, + Body: buf.Bytes(), + }) + return nil + } else { + return fmt.Errorf("unknown transaction id %v", rh.TransactionId) + } +} + +func (me *Dispatcher) forgetTransaction(id TransactionId) { + me.mu.Lock() + defer me.mu.Unlock() + delete(me.transactions, id) +} + +func (me *Dispatcher) NewTransaction(h TransactionResponseHandler) Transaction { + me.mu.Lock() + defer me.mu.Unlock() + for { + id := RandomTransactionId() + if _, ok := me.transactions[id]; ok { + continue + } + t := Transaction{ + d: me, + h: h, + id: id, + } + if me.transactions == nil { + me.transactions = make(map[TransactionId]Transaction) + } + me.transactions[id] = t + return t + } +} + +type DispatchedResponse struct { + Header ResponseHeader + Body []byte +} diff --git a/tracker/udp/options.go b/tracker/udp/options.go new file mode 100644 index 00000000..a2c223d0 --- /dev/null +++ b/tracker/udp/options.go @@ -0,0 +1,24 @@ +package udp + +import ( + "math" +) + +type Options struct { + RequestUri string +} + +func (opts Options) Encode() (ret []byte) { + for { + l := len(opts.RequestUri) + if l == 0 { + break + } + if l > math.MaxUint8 { + l = math.MaxUint8 + } + ret = append(append(ret, optionTypeURLData, byte(l)), opts.RequestUri[:l]...) + opts.RequestUri = opts.RequestUri[l:] + } + return +} diff --git a/tracker/udp/protocol.go b/tracker/udp/protocol.go new file mode 100644 index 00000000..365d3c5c --- /dev/null +++ b/tracker/udp/protocol.go @@ -0,0 +1,69 @@ +package udp + +import ( + "bytes" + "encoding/binary" + "io" +) + +type Action int32 + +const ( + ActionConnect Action = iota + ActionAnnounce + ActionScrape + ActionError + + ConnectRequestConnectionId = 0x41727101980 + + // BEP 41 + optionTypeEndOfOptions = 0 + optionTypeNOP = 1 + optionTypeURLData = 2 +) + +type TransactionId = int32 + +type ConnectionId = int64 + +type ConnectionRequest struct { + ConnectionId ConnectionId + Action Action + TransactionId TransactionId +} + +type ConnectionResponse struct { + ConnectionId ConnectionId +} + +type ResponseHeader struct { + Action Action + TransactionId TransactionId +} + +type RequestHeader struct { + ConnectionId ConnectionId + Action Action + TransactionId TransactionId +} // 16 bytes + +type AnnounceResponseHeader struct { + Interval int32 + Leechers int32 + Seeders int32 +} + +func marshal(data interface{}) (b []byte, err error) { + var buf bytes.Buffer + err = binary.Write(&buf, binary.BigEndian, data) + b = buf.Bytes() + return +} + +func Write(w io.Writer, data interface{}) error { + return binary.Write(w, binary.BigEndian, data) +} + +func Read(r io.Reader, data interface{}) error { + return binary.Read(r, binary.BigEndian, data) +} diff --git a/tracker/udp/timeout.go b/tracker/udp/timeout.go new file mode 100644 index 00000000..b5e18326 --- /dev/null +++ b/tracker/udp/timeout.go @@ -0,0 +1,18 @@ +package udp + +import ( + "time" +) + +const maxTimeout = 3840 * time.Second + +func timeout(contiguousTimeouts int) (d time.Duration) { + if contiguousTimeouts > 8 { + contiguousTimeouts = 8 + } + d = 15 * time.Second + for ; contiguousTimeouts > 0; contiguousTimeouts-- { + d *= 2 + } + return +} diff --git a/tracker/udp/timeout_test.go b/tracker/udp/timeout_test.go new file mode 100644 index 00000000..4bb0dc83 --- /dev/null +++ b/tracker/udp/timeout_test.go @@ -0,0 +1,15 @@ +package udp + +import ( + "math" + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestTimeoutMax(t *testing.T) { + c := qt.New(t) + c.Check(timeout(8), qt.Equals, maxTimeout) + c.Check(timeout(9), qt.Equals, maxTimeout) + c.Check(timeout(math.MaxInt32), qt.Equals, maxTimeout) +} diff --git a/tracker/udp/transaction.go b/tracker/udp/transaction.go new file mode 100644 index 00000000..2018b351 --- /dev/null +++ b/tracker/udp/transaction.go @@ -0,0 +1,23 @@ +package udp + +import "math/rand" + +func RandomTransactionId() TransactionId { + return TransactionId(rand.Uint32()) +} + +type TransactionResponseHandler func(dr DispatchedResponse) + +type Transaction struct { + id int32 + d *Dispatcher + h TransactionResponseHandler +} + +func (t *Transaction) Id() TransactionId { + return t.id +} + +func (t *Transaction) End() { + t.d.forgetTransaction(t.id) +} diff --git a/tracker/udp_test.go b/tracker/udp_test.go index d33550f1..39afa80c 100644 --- a/tracker/udp_test.go +++ b/tracker/udp_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/rand" "encoding/binary" + "errors" "fmt" "io" "io/ioutil" @@ -12,10 +13,11 @@ import ( "net/url" "sync" "testing" + "time" "github.com/anacrolix/dht/v2/krpc" _ "github.com/anacrolix/envpprof" - "github.com/pkg/errors" + "github.com/anacrolix/torrent/tracker/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -47,7 +49,7 @@ func TestMarshalAnnounceResponse(t *testing.T) { require.EqualValues(t, "\x7f\x00\x00\x01\x00\x02\xff\x00\x00\x03\x00\x04", b) - require.EqualValues(t, 12, binary.Size(AnnounceResponseHeader{})) + require.EqualValues(t, 12, binary.Size(udp.AnnounceResponseHeader{})) } // Failure to write an entire packet to UDP is expected to given an error. @@ -74,7 +76,7 @@ func TestLongWriteUDP(t *testing.T) { } func TestShortBinaryRead(t *testing.T) { - var data ResponseHeader + var data udp.ResponseHeader err := binary.Read(bytes.NewBufferString("\x00\x00\x00\x01"), binary.BigEndian, &data) if err != io.ErrUnexpectedEOF { t.FailNow() @@ -137,12 +139,20 @@ func TestUDPTracker(t *testing.T) { } rand.Read(req.PeerId[:]) copy(req.InfoHash[:], []uint8{0xa3, 0x56, 0x41, 0x43, 0x74, 0x23, 0xe6, 0x26, 0xd9, 0x38, 0x25, 0x4a, 0x6b, 0x80, 0x49, 0x10, 0xa6, 0x67, 0xa, 0xc1}) + var ctx context.Context + if dl, ok := t.Deadline(); ok { + var cancel func() + ctx, cancel = context.WithDeadline(context.Background(), dl.Add(-time.Second)) + defer cancel() + } ar, err := Announce{ TrackerUrl: trackers[0], Request: req, + Context: ctx, }.Do() // Skip any net errors as we don't control the server. - if _, ok := errors.Cause(err).(net.Error); ok { + var ne net.Error + if errors.As(err, &ne) { t.Skip(err) } require.NoError(t, err) @@ -163,6 +173,12 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) { rand.Read(req.InfoHash[:]) wg := sync.WaitGroup{} ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if dl, ok := t.Deadline(); ok { + var cancel func() + ctx, cancel = context.WithDeadline(ctx, dl.Add(-time.Second)) + defer cancel() + } for _, url := range trackers { wg.Add(1) go func(url string) { @@ -196,6 +212,7 @@ func TestURLPathOption(t *testing.T) { panic(err) } defer conn.Close() + announceErr := make(chan error) go func() { _, err := Announce{ TrackerUrl: (&url.URL{ @@ -204,34 +221,35 @@ func TestURLPathOption(t *testing.T) { Path: "/announce", }).String(), }.Do() - if err != nil { - defer conn.Close() - } - require.NoError(t, err) + defer conn.Close() + announceErr <- err }() var b [512]byte _, addr, _ := conn.ReadFrom(b[:]) r := bytes.NewReader(b[:]) - var h RequestHeader - read(r, &h) + var h udp.RequestHeader + udp.Read(r, &h) w := &bytes.Buffer{} - write(w, ResponseHeader{ + udp.Write(w, udp.ResponseHeader{ + Action: udp.ActionConnect, TransactionId: h.TransactionId, }) - write(w, ConnectionResponse{42}) + udp.Write(w, udp.ConnectionResponse{42}) conn.WriteTo(w.Bytes(), addr) n, _, _ := conn.ReadFrom(b[:]) r = bytes.NewReader(b[:n]) - read(r, &h) - read(r, &AnnounceRequest{}) + udp.Read(r, &h) + udp.Read(r, &AnnounceRequest{}) all, _ := ioutil.ReadAll(r) if string(all) != "\x02\x09/announce" { t.FailNow() } w = &bytes.Buffer{} - write(w, ResponseHeader{ + udp.Write(w, udp.ResponseHeader{ + Action: udp.ActionAnnounce, TransactionId: h.TransactionId, }) - write(w, AnnounceResponseHeader{}) + udp.Write(w, udp.AnnounceResponseHeader{}) conn.WriteTo(w.Bytes(), addr) + require.NoError(t, <-announceErr) }