// path: root/internal/conn/conn.go
// blob: a1da6f19c6014f0457c9dd489da20dfb150353f3

package conn

import (
	"context"
	"io"
	"net"
	"time"
)

// Conn is a thin wrapper around a net.Conn.
type Conn struct {
	// netConn is the wrapped connection. Close treats a nil value as
	// nothing to close and returns nil.
	netConn net.Conn
}

// FromNetConn builds a Conn on top of an already-established net.Conn,
// which is handy in tests. The connection n is stored as given.
func FromNetConn(n net.Conn) *Conn {
	wrapped := new(Conn)
	wrapped.netConn = n
	return wrapped
}

// Dial opens a TCP connection to addr and returns it wrapped in a Conn.
// ctx is handed to DialContext and timeout becomes the dialer's Timeout.
// On failure the dial error is returned unchanged together with a nil Conn.
func Dial(ctx context.Context, addr string, timeout time.Duration) (*Conn, error) {
	dialer := &net.Dialer{Timeout: timeout}

	raw, dialErr := dialer.DialContext(ctx, "tcp", addr)
	if dialErr != nil {
		return nil, dialErr
	}

	wrapped := new(Conn)
	wrapped.netConn = raw
	return wrapped, nil
}

// Close closes the wrapped connection and returns its Close error.
// When no connection is set it does nothing and returns nil.
func (c *Conn) Close() error {
	if nc := c.netConn; nc != nil {
		return nc.Close()
	}
	return nil
}

// Send writes data to the wrapped connection.
//
// The write deadline comes from ctx when ctx carries one; otherwise it is
// time.Now().Add(timeout). Only one of the two applies: timeout is ignored
// whenever ctx has a deadline. ctx is checked once before writing, so a
// cancellation that arrives while the write is in flight is not observed;
// only the deadline bounds the write.
//
// Send returns net.ErrClosed when the Conn has no underlying connection,
// ctx.Err() when ctx is already done, and otherwise the error from setting
// the deadline or from the write itself.
func (c *Conn) Send(ctx context.Context, data []byte, timeout time.Duration) error {
	// Close tolerates a missing connection; fail cleanly here instead of
	// dereferencing a nil interface.
	if c.netConn == nil {
		return net.ErrClosed
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	deadline, ok := ctx.Deadline()
	if !ok {
		deadline = time.Now().Add(timeout)
	}
	// A failure to arm the deadline would leave the write unbounded (or
	// bounded by a stale deadline), so surface it rather than ignore it.
	if err := c.netConn.SetWriteDeadline(deadline); err != nil {
		return err
	}

	_, err := c.netConn.Write(data)
	return err
}

// Receive reads from the wrapped connection until EOF or an error.
//
// The read deadline comes from ctx when ctx carries one; otherwise it is
// time.Now().Add(timeout). Only one of the two applies: timeout is ignored
// whenever ctx has a deadline. ctx is checked once before reading, so a
// cancellation that arrives while the read is in flight is not observed;
// only the deadline bounds the read.
//
// Receive returns net.ErrClosed when the Conn has no underlying connection
// and ctx.Err() when ctx is already done. Otherwise it returns whatever
// io.ReadAll returns: a nil error only once the peer has closed the stream,
// and on any read failure (including the deadline expiring) the bytes read
// so far together with that error.
func (c *Conn) Receive(ctx context.Context, timeout time.Duration) ([]byte, error) {
	// Close tolerates a missing connection; fail cleanly here instead of
	// dereferencing a nil interface.
	if c.netConn == nil {
		return nil, net.ErrClosed
	}
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	deadline, ok := ctx.Deadline()
	if !ok {
		deadline = time.Now().Add(timeout)
	}
	// A failure to arm the deadline would leave the read unbounded (or
	// bounded by a stale deadline), so surface it rather than ignore it.
	if err := c.netConn.SetReadDeadline(deadline); err != nil {
		return nil, err
	}

	// NOTE(review): ReadAll only finishes at EOF, so against a peer that
	// keeps the connection open this always runs until the deadline and
	// reports a timeout. Confirm the protocol closes the stream after each
	// message, or switch to framed reads.
	return io.ReadAll(c.netConn)
}