aboutsummaryrefslogtreecommitdiff
path: root/internal/conn/conn.go
blob: 6df7e49146295adeece5c604510d65c53c0330a4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
package conn

import (
	"context"
	"io"
	"net"
	"time"
)

// Conn is a wrapper around net.Conn with convenience helpers for
// deadline-bounded sends and receives.
//
// Create one with Dial or FromNetConn. The zero value wraps no connection:
// Close on it is a no-op, but Send and Receive call methods on the wrapped
// connection directly and panic if it is nil. Conn adds no synchronization
// of its own. NOTE(review): Send and Receive each overwrite a
// connection-wide write/read deadline, so overlapping calls of the same kind
// from different goroutines presumably interfere — confirm callers
// serialize them.
type Conn struct {
	// netConn is the wrapped connection; nil in the zero value.
	netConn net.Conn
}

// FromNetConn returns a new Conn around the already-established connection n.
// It performs no I/O and does not validate n, which makes it convenient for
// wrapping in-memory connections (for example from net.Pipe) in tests.
// Closing the returned Conn closes n.
func FromNetConn(n net.Conn) *Conn {
	wrapped := new(Conn)
	wrapped.netConn = n
	return wrapped
}

// Dial establishes a TCP connection to addr and wraps it in a Conn.
//
// The attempt is abandoned when ctx is done or when timeout elapses,
// whichever comes first. As with net.Dialer, a zero timeout sets no limit
// beyond ctx. On failure Dial returns a nil *Conn together with the dialer's
// error, unwrapped.
func Dial(ctx context.Context, addr string, timeout time.Duration) (*Conn, error) {
	var dialer net.Dialer
	dialer.Timeout = timeout

	nc, dialErr := dialer.DialContext(ctx, "tcp", addr)
	if dialErr != nil {
		return nil, dialErr
	}
	wrapped := &Conn{}
	wrapped.netConn = nc
	return wrapped, nil
}

// Close closes the underlying network connection.
//
// It is safe to call on a nil *Conn (such as the value Dial returns on
// failure) and on a Conn that wraps no connection (such as the zero value).
// Both cases are no-ops that return nil, so cleanup paths need not
// special-case them. Otherwise the error, if any, comes from the wrapped
// connection's Close, including when it was already closed.
func (c *Conn) Close() error {
	if c == nil || c.netConn == nil {
		return nil
	}
	return c.netConn.Close()
}

// Send writes all of data to the connection.
//
// The write is bounded by the earlier of ctx's deadline (if any) and
// now+timeout. A zero timeout adds no limit of its own, the same convention
// Dial uses via net.Dialer. A negative timeout yields an already-expired
// deadline. With neither limit, the write may block indefinitely. Any write
// deadline left over from a previous call is replaced.
//
// If ctx is already done, ctx.Err() is returned and nothing is written.
// Cancellation of a ctx that has no deadline is only observed before the
// write starts; it does not interrupt a write that is already blocked.
func (c *Conn) Send(ctx context.Context, data []byte, timeout time.Duration) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	// The zero Time means "no deadline"; it also clears one set earlier.
	var deadline time.Time
	if timeout != 0 {
		deadline = time.Now().Add(timeout)
	}
	// Honour whichever limit expires first instead of letting the ctx
	// deadline silently override the caller's timeout.
	if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
		deadline = dl
	}
	if err := c.netConn.SetWriteDeadline(deadline); err != nil {
		return err
	}

	// The byte count is not needed: per the io.Writer contract, Write
	// returns a non-nil error whenever it writes fewer than len(data) bytes.
	_, err := c.netConn.Write(data)
	return err
}

// Receive reads from the connection until io.EOF and returns everything read.
//
// The read is bounded by the earlier of ctx's deadline (if any) and
// now+timeout. A zero timeout adds no limit of its own, the same convention
// Dial uses via net.Dialer. A negative timeout yields an already-expired
// deadline. With neither limit, Receive blocks until EOF or a read error.
// The deadline covers the whole call, not each chunk, and any read deadline
// left over from a previous call is replaced.
//
// On a timeout or any other error, the bytes read before the failure are
// returned together with that error (io.ReadAll semantics). If ctx is
// already done, ctx.Err() is returned and nothing is read. Cancellation of a
// ctx that has no deadline is not observed once the read has started.
//
// NOTE(review): the result is unbounded in size; a peer that keeps sending
// can make this buffer arbitrarily much data until the deadline passes.
func (c *Conn) Receive(ctx context.Context, timeout time.Duration) ([]byte, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	// The zero Time means "no deadline"; it also clears one set earlier.
	var deadline time.Time
	if timeout != 0 {
		deadline = time.Now().Add(timeout)
	}
	// Honour whichever limit expires first instead of letting the ctx
	// deadline silently override the caller's timeout.
	if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
		deadline = dl
	}
	if err := c.netConn.SetReadDeadline(deadline); err != nil {
		return nil, err
	}
	return io.ReadAll(c.netConn)
}