diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/.gitkeep | 0 | ||||
-rw-r--r-- | internal/conn/conn.go | 50 | ||||
-rw-r--r-- | internal/proto/proto.go | 65 | ||||
-rw-r--r-- | internal/proto/proto_test.go | 39 |
4 files changed, 154 insertions, 0 deletions
diff --git a/internal/.gitkeep b/internal/.gitkeep deleted file mode 100644 index e69de29..0000000 --- a/internal/.gitkeep +++ /dev/null diff --git a/internal/conn/conn.go b/internal/conn/conn.go new file mode 100644 index 0000000..a1da6f1 --- /dev/null +++ b/internal/conn/conn.go @@ -0,0 +1,50 @@ +package conn + +import ( + "context" + "io" + "net" + "time" +) + +type Conn struct { + netConn net.Conn +} + +// FromNetConn wraps an existing net.Conn. Useful for tests. +func FromNetConn(n net.Conn) *Conn { return &Conn{netConn: n} } + +func Dial(ctx context.Context, addr string, timeout time.Duration) (*Conn, error) { + d := net.Dialer{Timeout: timeout} + c, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + return &Conn{netConn: c}, nil +} + +func (c *Conn) Close() error { + if c.netConn == nil { + return nil + } + return c.netConn.Close() +} + +func (c *Conn) Send(ctx context.Context, data []byte, timeout time.Duration) error { + if dl, ok := ctx.Deadline(); ok { + c.netConn.SetWriteDeadline(dl) + } else { + c.netConn.SetWriteDeadline(time.Now().Add(timeout)) + } + _, err := c.netConn.Write(data) + return err +} + +func (c *Conn) Receive(ctx context.Context, timeout time.Duration) ([]byte, error) { + if dl, ok := ctx.Deadline(); ok { + c.netConn.SetReadDeadline(dl) + } else { + c.netConn.SetReadDeadline(time.Now().Add(timeout)) + } + return io.ReadAll(c.netConn) +} diff --git a/internal/proto/proto.go b/internal/proto/proto.go new file mode 100644 index 0000000..d86ce75 --- /dev/null +++ b/internal/proto/proto.go @@ -0,0 +1,65 @@ +package proto + +import ( + "encoding/base64" + "fmt" + "sort" + "strings" + "unicode" + + "golang.org/x/text/encoding/charmap" +) + +// EncodeParams converts params map into a sorted base64-encoded string using Windows-1251 encoding. +func EncodeParams(params map[string]string) (string, error) { + keys := make([]string, 0, len(params)) + for k := range params { + keys = append(keys, k) + } + sort.Strings(keys) + + var sb strings.Builder + for i, k := range keys { + if i > 0 { + sb.WriteByte('|') + } + sb.WriteString(k) + sb.WriteByte('=') + sb.WriteString(params[k]) + } + sb.WriteByte('|') + + enc := charmap.Windows1251.NewEncoder() + encoded, err := enc.String(sb.String()) + if err != nil { + return "", fmt.Errorf("encode params: %w", err) + } + return base64.StdEncoding.EncodeToString([]byte(encoded)), nil +} + +// DecodeResponse decodes base64-encoded Windows-1251 text to UTF-8 and removes control characters. +func DecodeResponse(data string) (string, error) { + raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(data)) + if err != nil { + return "", fmt.Errorf("base64 decode: %w", err) + } + decoded, err := charmap.Windows1251.NewDecoder().Bytes(raw) + if err != nil { + return "", fmt.Errorf("decode charset: %w", err) + } + cleaned := strings.Map(func(r rune) rune { + if unicode.IsPrint(r) || r == '\n' || r == '\r' || r == '\t' { + return r + } + return -1 + }, string(decoded)) + return cleaned, nil +} + +// BuildRequest returns byte slice representing the command and parameters. +func BuildRequest(command, encodedParams string, quit bool) []byte { + if quit { + return []byte(fmt.Sprintf("%s %s\nQUIT\n", command, encodedParams)) + } + return []byte(fmt.Sprintf("%s %s\n", command, encodedParams)) +} diff --git a/internal/proto/proto_test.go b/internal/proto/proto_test.go new file mode 100644 index 0000000..0e7e644 --- /dev/null +++ b/internal/proto/proto_test.go @@ -0,0 +1,39 @@ +package proto + +import ( + "strings" + "testing" +) + +func TestEncodeParamsOrder(t *testing.T) { + params := map[string]string{"B": "2", "A": "1"} + encoded1, err := EncodeParams(params) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // encode again with different map order + encoded2, err := EncodeParams(map[string]string{"A": "1", "B": "2"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if encoded1 != encoded2 { + t.Fatalf("expected deterministic encode, got %s vs %s", encoded1, encoded2) + } +} + +func TestDecodeResponse(t *testing.T) { + // "привет" in Cyrillic + original := "привет" + params := map[string]string{"MSG": original} + enc, err := EncodeParams(params) + if err != nil { + t.Fatalf("encode params: %v", err) + } + dec, err := DecodeResponse(enc) + if err != nil { + t.Fatalf("decode: %v", err) + } + if !strings.Contains(dec, original) { + t.Fatalf("expected to contain %q, got %q", original, dec) + } +} |