From 19436c06c2cf13452e1bc05893c23a33bc526c6d Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Thu, 1 Dec 2022 21:36:10 +0100 Subject: [PATCH] Implement unit tests for public packages This commit adds a comprehensive unit test suite for all public packages. --- cmd/clientsim/main.go | 32 +- internal/handler/msgpack.go | 62 +- internal/serial/seriallink.go | 6 +- pkg/protocol/protocol.go | 405 ++++--- pkg/protocol/protocol_test.go | 2151 +++++++++++++++++++++++++++++++++ 5 files changed, 2458 insertions(+), 198 deletions(-) create mode 100644 pkg/protocol/protocol_test.go diff --git a/cmd/clientsim/main.go b/cmd/clientsim/main.go index 19b3913..6aaa72c 100644 --- a/cmd/clientsim/main.go +++ b/cmd/clientsim/main.go @@ -115,13 +115,13 @@ func (c *clientSimulator) Run(ctx context.Context) error { generatorErrors := make(chan error) go func() { - err := c.framer.ReadFrames(os.Stdin, c.framesIn) + err := c.framer.ReadFrames(ctx, os.Stdin, c.framesIn) framerErrors <- err }() go func() { - err := c.framer.WriteFrames(os.Stdout, c.framesOut) + err := c.framer.WriteFrames(ctx, os.Stdout, c.framesOut) framerErrors <- err }() @@ -129,7 +129,7 @@ func (c *clientSimulator) Run(ctx context.Context) error { go func() { clientProtocol := protocol.NewClient(c.clientHandler, c.commandGenerator.commands, c.framesIn, c.framesOut, c.logger) - err := clientProtocol.Handle() + err := clientProtocol.Handle(ctx) protocolErrors <- err }() @@ -170,7 +170,7 @@ type ClientHandler struct { logger *logrus.Logger } -func (c *ClientHandler) Send(command *protocol.Command, out chan []byte) error { +func (c *ClientHandler) Send(ctx context.Context, command *protocol.Command, out chan []byte) error { var ( frame []byte err error @@ -185,7 +185,12 @@ func (c *ClientHandler) Send(command *protocol.Command, out chan []byte) error { c.logger.Trace("writing command announcement") - out <- frame + select { + case <-ctx.Done(): + return nil + case out <- frame: + break + } frame, err = msgpack.Marshal(command.Command) if err != nil { @@ -194,17 +199,24 @@ func (c *ClientHandler) Send(command *protocol.Command, out chan []byte) error { c.logger.WithField("command", command.Command).Info("write command data") - out <- frame + select { + case <-ctx.Done(): + return nil + case out <- frame: + break + } return nil } -func (c *ClientHandler) ResponseAnnounce(in chan []byte) (*protocol.Response, error) { +func (c *ClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*protocol.Response, error) { response := &protocol.Response{} var announce messages.ResponseAnnounce select { + case <-ctx.Done(): + return nil, nil case frame := <-in: if err := msgpack.Unmarshal(frame, &announce); err != nil { return nil, fmt.Errorf("could not unmarshal response announcement: %w", err) @@ -220,8 +232,10 @@ func (c *ClientHandler) ResponseAnnounce(in chan []byte) (*protocol.Response, er } } -func (c *ClientHandler) ResponseData(in chan []byte, response *protocol.Response) error { +func (c *ClientHandler) ResponseData(ctx context.Context, in chan []byte, response *protocol.Response) error { select { + case <-ctx.Done(): + return nil case frame := <-in: switch response.Announce.Code { case messages.RespHealth: @@ -248,7 +262,7 @@ func (c *ClientHandler) ResponseData(in chan []byte, response *protocol.Response return nil } -func (c *ClientHandler) HandleResponse(response *protocol.Response) error { +func (c *ClientHandler) HandleResponse(_ context.Context, response *protocol.Response) error { c.logger.WithField("response", response.Announce).Info("handled response") c.logger.WithField("response", response).Debug("full response") diff --git a/internal/handler/msgpack.go b/internal/handler/msgpack.go index cca5c0d..23d0357 100644 --- a/internal/handler/msgpack.go +++ b/internal/handler/msgpack.go @@ -18,10 +18,10 @@ limitations under the License. package handler import ( + "context" "errors" "fmt" "math/big" - "sync" "time" "github.com/shamaton/msgpackgen/msgpack" @@ -43,35 +43,33 @@ type MsgPackHandler struct { logger *logrus.Logger healthHandler *health.Handler fetchCRLHandler *revoking.FetchCRLHandler - lock sync.Mutex } -func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command, error) { - m.lock.Lock() - defer m.lock.Unlock() +func (m *MsgPackHandler) CommandAnnounce(ctx context.Context, frames chan []byte) (*protocol.Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case frame := <-frames: + var ann messages.CommandAnnounce - frame := <-frames + if err := msgpack.Unmarshal(frame, &ann); err != nil { + return nil, fmt.Errorf("could not unmarshal command announcement: %w", err) + } - var ann messages.CommandAnnounce + if ann.Code == messages.CmdUndef { + return nil, fmt.Errorf("received undefined command announcement: %s", ann) + } - if err := msgpack.Unmarshal(frame, &ann); err != nil { - return nil, fmt.Errorf("could not unmarshal command announcement: %w", err) - } + m.logger.WithField("announcement", &ann).Debug("received command announcement") - if ann.Code == messages.CmdUndef { - return nil, fmt.Errorf("received undefined command announcement: %s", ann) + return &protocol.Command{Announce: &ann}, nil } - - m.logger.WithField("announcement", &ann).Debug("received command announcement") - - return &protocol.Command{Announce: &ann}, nil } -func (m *MsgPackHandler) CommandData(frames chan []byte, command *protocol.Command) error { - m.lock.Lock() - defer m.lock.Unlock() - +func (m *MsgPackHandler) CommandData(ctx context.Context, frames chan []byte, command *protocol.Command) error { select { + case <-ctx.Done(): + return nil case frame := <-frames: err := m.parseCommand(frame, command) if err != nil { @@ -84,10 +82,7 @@ func (m *MsgPackHandler) CommandData(frames chan []byte, command *protocol.Comma } } -func (m *MsgPackHandler) HandleCommand(command *protocol.Command) (*protocol.Response, error) { - m.lock.Lock() - defer m.lock.Unlock() - +func (m *MsgPackHandler) HandleCommand(_ context.Context, command *protocol.Command) (*protocol.Response, error) { var ( response *protocol.Response err error @@ -110,10 +105,7 @@ func (m *MsgPackHandler) logCommandResponse(command *protocol.Command, response m.logger.WithField("command", command).WithField("response", response).Debug("command and response") } -func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) error { - m.lock.Lock() - defer m.lock.Unlock() - +func (m *MsgPackHandler) Respond(ctx context.Context, response *protocol.Response, out chan []byte) error { announce, err := msgpack.Marshal(response.Announce) if err != nil { return fmt.Errorf("could not marshal response announcement: %w", err) @@ -121,7 +113,12 @@ func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) e m.logger.WithField("length", len(announce)).Debug("write response announcement") - out <- announce + select { + case <-ctx.Done(): + return nil + case out <- announce: + break + } data, err := msgpack.Marshal(response.Response) if err != nil { @@ -130,7 +127,12 @@ func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) e m.logger.WithField("length", len(data)).Debug("write response") - out <- data + select { + case <-ctx.Done(): + return nil + case out <- data: + break + } return nil } diff --git a/internal/serial/seriallink.go b/internal/serial/seriallink.go index 9fc0787..71c09f2 100644 --- a/internal/serial/seriallink.go +++ b/internal/serial/seriallink.go @@ -63,13 +63,13 @@ func (h *Handler) Run(ctx context.Context) error { protocolErrors, framerErrors := make(chan error), make(chan error) go func() { - err := h.framer.ReadFrames(h.port, h.framesIn) + err := h.framer.ReadFrames(ctx, h.port, h.framesIn) framerErrors <- err }() go func() { - err := h.framer.WriteFrames(h.port, h.framesOut) + err := h.framer.WriteFrames(ctx, h.port, h.framesOut) framerErrors <- err }() @@ -77,7 +77,7 @@ func (h *Handler) Run(ctx context.Context) error { go func() { serverProtocol := protocol.NewServer(h.serverHandler, h.framesIn, h.framesOut, h.logger) - err := serverProtocol.Handle() + err := serverProtocol.Handle(ctx) protocolErrors <- err }() diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go index 50143ff..49f066f 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/protocol/protocol.go @@ -20,6 +20,7 @@ package protocol import ( "bytes" + "context" "errors" "fmt" "io" @@ -54,25 +55,27 @@ func (r *Response) String() string { // ServerHandler is responsible for parsing incoming frames and calling commands type ServerHandler interface { // CommandAnnounce handles the initial announcement of a command. - CommandAnnounce(chan []byte) (*Command, error) + CommandAnnounce(context.Context, chan []byte) (*Command, error) // CommandData handles the command data. - CommandData(chan []byte, *Command) error + CommandData(context.Context, chan []byte, *Command) error // HandleCommand executes the command, generating a response. - HandleCommand(*Command) (*Response, error) + HandleCommand(context.Context, *Command) (*Response, error) // Respond generates the response for a command. - Respond(*Response, chan []byte) error + Respond(context.Context, *Response, chan []byte) error } type ClientHandler interface { - Send(*Command, chan []byte) error - ResponseAnnounce(chan []byte) (*Response, error) - ResponseData(chan []byte, *Response) error - HandleResponse(*Response) error + Send(context.Context, *Command, chan []byte) error + ResponseAnnounce(context.Context, chan []byte) (*Response, error) + ResponseData(context.Context, chan []byte, *Response) error + HandleResponse(context.Context, *Response) error } var ( - errCommandExpected = errors.New("command must not be nil") - errResponseExpected = errors.New("response must not be nil") + errCommandExpected = errors.New("command must not be nil") + errCommandAnnounceExpected = errors.New("command must have an announcement") + errCommandDataExpected = errors.New("command must have data") + errResponseExpected = errors.New("response must not be nil") ErrResponseAnnounceTimeoutExpired = errors.New("response announce timeout expired") ErrResponseDataTimeoutExpired = errors.New("response data timeout expired") @@ -115,7 +118,7 @@ type ServerProtocol struct { state protocolState } -func (p *ServerProtocol) Handle() error { +func (p *ServerProtocol) Handle(ctx context.Context) error { var ( command *Command response *Response @@ -123,67 +126,107 @@ func (p *ServerProtocol) Handle() error { ) for { - p.logger.Debugf("handling protocol state %s", p.state) + select { + case <-ctx.Done(): + close(p.out) - switch p.state { - case cmdAnnounce: - command, err = p.handler.CommandAnnounce(p.in) - if err != nil { - p.logger.WithError(err).Error("could not handle command announce") - - break + return nil + default: + p.logger.Debugf("handling protocol state %s", p.state) + + switch p.state { + case cmdAnnounce: + command = p.commandAnnounce(ctx) + case cmdData: + err = p.commandData(ctx, command) + if err != nil { + return err + } + case handleCommand: + response, err = p.handleCommand(ctx, command) + if err != nil { + return err + } + case respond: + err = p.respond(ctx, response) + if err != nil { + return err + } + default: + return fmt.Errorf("unknown protocol state %s", p.state) } + } + } +} - p.state = cmdData - case cmdData: - if command == nil { - return errCommandExpected - } +func (p *ServerProtocol) commandAnnounce(ctx context.Context) *Command { + command, err := p.handler.CommandAnnounce(ctx, p.in) + if err != nil { + p.logger.WithError(err).Error("could not handle command announce") - err = p.handler.CommandData(p.in, command) - if err != nil { - p.logger.WithError(err).Error("could not handle command data") + return nil + } - p.state = cmdAnnounce + p.state = cmdData - break - } + return command +} - p.state = handleCommand - case handleCommand: - if command == nil { - return errCommandExpected - } +func (p *ServerProtocol) commandData(ctx context.Context, command *Command) error { + if command == nil || command.Announce == nil { + return errCommandAnnounceExpected + } - response, err = p.handler.HandleCommand(command) - if err != nil { - p.logger.WithError(err).Error("could not handle command") + err := p.handler.CommandData(ctx, p.in, command) + if err != nil { + p.logger.WithError(err).Error("could not handle command data") - p.state = cmdAnnounce + p.state = cmdAnnounce - break - } + return nil + } - p.state = respond - case respond: - if response == nil { - return errResponseExpected - } + p.state = handleCommand - err = p.handler.Respond(response, p.out) - if err != nil { - p.logger.WithError(err).Error("could not respond") + return nil +} - p.state = cmdAnnounce +func (p *ServerProtocol) handleCommand(ctx context.Context, command *Command) (*Response, error) { + if command == nil || command.Announce == nil || command.Command == nil { + return nil, errCommandDataExpected + } - break - } + response, err := p.handler.HandleCommand(ctx, command) + if err != nil { + p.logger.WithError(err).Error("could not handle command") - p.state = cmdAnnounce - default: - return fmt.Errorf("unknown protocol state %s", p.state) - } + p.state = cmdAnnounce + + return nil, nil } + + p.state = respond + + return response, nil +} + +func (p *ServerProtocol) respond(ctx context.Context, response *Response) error { + if response == nil { + return errResponseExpected + } + + err := p.handler.Respond(ctx, response, p.out) + if err != nil { + p.logger.WithError(err).Error("could not respond") + + p.state = cmdAnnounce + + return nil + } + + p.state = cmdAnnounce + + return nil } func NewServer(handler ServerHandler, in, out chan []byte, logger *logrus.Logger) *ServerProtocol { @@ -196,14 +239,6 @@ func NewServer(handler ServerHandler, in, out chan []byte, logger *logrus.Logger } } -// Framer handles bytes on the wire by adding or removing framing information. -type Framer interface { - // ReadFrames reads data frames and publishes unframed data to the channel. - ReadFrames(io.Reader, chan []byte) error - // WriteFrames takes data from the channel and writes framed data to the writer. - WriteFrames(io.Writer, chan []byte) error -} - type ClientProtocol struct { handler ClientHandler commands chan *Command @@ -212,80 +247,121 @@ type ClientProtocol struct { state protocolState } -func (p *ClientProtocol) Handle() error { +func (p *ClientProtocol) Handle(ctx context.Context) error { var ( - command *Command response *Response err error ) for { - p.logger.Debugf("handling protocol state %s", p.state) + select { + case <-ctx.Done(): + return nil + default: + p.logger.Debugf("handling protocol state %s", p.state) - switch p.state { - case cmdAnnounce: - command = <-p.commands - if command == nil { - return errCommandExpected + switch p.state { + case cmdAnnounce: + err = p.cmdAnnounce(ctx) + if err != nil { + return err + } + case respAnnounce: + response = p.respAnnounce(ctx) + case respData: + err = p.respData(ctx, response) + if err != nil { + return err + } + case handleResponse: + err = p.handleResponse(ctx, response) + if err != nil { + return err + } + default: + return fmt.Errorf("unknown protocol state %s", p.state) } + } + } +} - err = p.handler.Send(command, p.out) - if err != nil { - p.logger.WithError(err).Error("could not send command announce") - - break - } +func (p *ClientProtocol) cmdAnnounce(ctx context.Context) error { + select { + case <-ctx.Done(): + return nil + case command := <-p.commands: + if command == nil { + return errCommandExpected + } - p.state = respAnnounce - case respAnnounce: - response, err = p.handler.ResponseAnnounce(p.in) - if err != nil { - p.logger.WithError(err).Error("could not handle response announce") + err := p.handler.Send(ctx, command, p.out) + if err != nil { + p.logger.WithError(err).Error("could not send command announce") - p.state = cmdAnnounce + return nil + } + } - break - } + p.state = respAnnounce - p.state = respData - case respData: - if response == nil { - return errResponseExpected - } + return nil +} - err = p.handler.ResponseData(p.in, response) - if err != nil { - p.logger.WithError(err).Error("could not handle response data") +func (p *ClientProtocol) respAnnounce(ctx context.Context) *Response { + response, err := p.handler.ResponseAnnounce(ctx, p.in) + if err != nil { + p.logger.WithError(err).Error("could not handle response announce") - if errors.Is(err, ErrResponseDataTimeoutExpired) { - p.state = cmdAnnounce - } else { - p.state = respAnnounce - } + p.state = cmdAnnounce - break - } + return nil + } - p.state = handleResponse - case handleResponse: - if response == nil { - return errResponseExpected - } + p.state = respData - err = p.handler.HandleResponse(response) - if err != nil { - p.logger.WithError(err).Error("could not handle response") + return response +} - p.state = cmdAnnounce +func (p *ClientProtocol) respData(ctx context.Context, response *Response) error { + if response == nil || response.Announce == nil { + return errResponseExpected + } - break - } + err := p.handler.ResponseData(ctx, p.in, response) + if err != nil { + p.logger.WithError(err).Error("could not handle response data") + if errors.Is(err, ErrResponseDataTimeoutExpired) { p.state = cmdAnnounce - default: - return fmt.Errorf("unknown protocol state %s", p.state) + } else { + p.state = respAnnounce } + + return nil + } + + p.state = handleResponse + + return nil +} + +func (p *ClientProtocol) handleResponse(ctx context.Context, response *Response) error { + if response == nil || response.Announce == nil || response.Response == nil { + return errResponseExpected + } + + err := p.handler.HandleResponse(ctx, response) + if err != nil { + p.logger.WithError(err).Error("could not handle response") + + p.state = cmdAnnounce + + return nil } + + p.state = cmdAnnounce + + return nil } func NewClient( @@ -304,6 +380,14 @@ func NewClient( } } +// Framer handles bytes on the wire by adding or removing framing information. +type Framer interface { + // ReadFrames reads data frames and publishes unframed data to the channel. + ReadFrames(context.Context, io.Reader, chan []byte) error + // WriteFrames takes data from the channel and writes framed data to the writer. + WriteFrames(context.Context, io.Writer, chan []byte) error +} + const bufferSize = 1024 const readInterval = 50 * time.Millisecond @@ -319,7 +403,7 @@ func NewCOBSFramer(logger *logrus.Logger) *COBSFramer { } } -func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error { +func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan chan []byte) error { var ( err error raw, data, frame []byte @@ -328,49 +412,55 @@ func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error { buffer := &bytes.Buffer{} for { - raw, err = c.readRaw(reader) - if err != nil { - close(frameChan) + select { + case <-ctx.Done(): + return nil + default: + raw, err = c.readRaw(reader) + if err != nil { + close(frameChan) - return err - } + return err + } - if len(raw) == 0 { - time.Sleep(readInterval) + if len(raw) == 0 { + time.Sleep(readInterval) - continue - } + continue + } - c.logger.Tracef("read %d raw bytes", len(raw)) + c.logger.Tracef("read %d raw bytes", len(raw)) - buffer.Write(raw) + buffer.Write(raw) - for { - data, err = buffer.ReadBytes(c.config.SpecialByte) - if err != nil { - if errors.Is(err, io.EOF) { - buffer.Write(data) + for { + data, err = buffer.ReadBytes(c.config.SpecialByte) + if err != nil { + if errors.Is(err, io.EOF) { + buffer.Write(data) + + break + } - break + // this is a safety measure, buffer.ReadBytes should only return io.EOF + return fmt.Errorf("could not read from buffer: %w", err) } - return fmt.Errorf("could not read from buffer: %w", err) - } + if err = cobs.Verify(data, c.config); err != nil { + c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data)) - if err = cobs.Verify(data, c.config); err != nil { - c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data)) + continue + } - continue - } + frame = cobs.Decode(data, c.config) - frame = cobs.Decode(data, c.config) + c.logger.Tracef("frame decoded to length %d", len(frame)) - c.logger.Tracef("frame decoded to length %d", len(frame)) + frameChan <- frame + } - frameChan <- frame + c.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) } - - c.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) } } @@ -380,7 +470,7 @@ func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) { count, err := reader.Read(buf) if err != nil { if errors.Is(err, io.EOF) { - return []byte{}, nil + return buf[:count], nil } return nil, fmt.Errorf("could not read data: %w", err) @@ -391,23 +481,26 @@ func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) { return raw, nil } -func (c *COBSFramer) WriteFrames(writer io.Writer, frameChan chan []byte) error { +func (c *COBSFramer) WriteFrames(ctx context.Context, writer io.Writer, frameChan chan []byte) error { for { - frame := <-frameChan + select { + case <-ctx.Done(): + return nil + case frame := <-frameChan: + if frame == nil { + c.logger.Debug("channel closed") - if frame == nil { - c.logger.Debug("channel closed") + return nil + } - return nil - } + encoded := cobs.Encode(frame, c.config) - encoded := cobs.Encode(frame, c.config) + n, err := io.Copy(writer, bytes.NewReader(encoded)) + if err != nil { + return fmt.Errorf("cold not write data: %w", err) + } - n, err := io.Copy(writer, bytes.NewReader(encoded)) - if err != nil { - return fmt.Errorf("cold not write data: %w", err) + c.logger.Tracef("wrote %d bytes", n) } - - c.logger.Tracef("wrote %d bytes", n) } } diff --git a/pkg/protocol/protocol_test.go b/pkg/protocol/protocol_test.go new file mode 100644 index 0000000..b5ecbfb --- /dev/null +++ b/pkg/protocol/protocol_test.go @@ -0,0 +1,2151 @@ +/* +Copyright 2022 CAcert Inc. +SPDX-License-Identifier: Apache-2.0 + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package protocol + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "testing" + "testing/iotest" + "time" + + "github.com/google/uuid" + "github.com/justincpresley/go-cobs" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "git.cacert.org/cacert-gosigner/pkg/messages" +) + +type expectedLogs struct { + level logrus.Level + message string +} + +func assertLogs(t *testing.T, hook *test.Hook, expected []expectedLogs) { + t.Helper() + + logEntries := hook.AllEntries() + assert.Len(t, logEntries, len(expected)) + + for i, e := range expected { + assert.Equal(t, e.level, logEntries[i].Level) + assert.Equal(t, e.message, logEntries[i].Message) + } +} + +func TestCommand_String(t *testing.T) { + c := &Command{ + Announce: messages.BuildCommandAnnounce(messages.CmdUndef), + Command: "my undefined command", + } + + str := c.String() + + assert.NotEmpty(t, str) + assert.Contains(t, str, c.Announce.String()) + assert.Contains(t, str, c.Command) + assert.Contains(t, str, "announce") + assert.Contains(t, str, "data") + assert.Contains(t, str, "Cmd[") + assert.True(t, strings.HasSuffix(str, "}]")) +} + +func TestResponse_String(t *testing.T) { + r := &Response{ + Announce: messages.BuildResponseAnnounce(messages.RespUndef, uuid.NewString()), + Response: "my undefined response", + } + + str := r.String() + + assert.NotEmpty(t, str) + assert.Contains(t, str, r.Announce.String()) + assert.Contains(t, str, r.Response) + assert.Contains(t, str, "announce") + assert.Contains(t, str, "data") + assert.Contains(t, str, "Rsp[") + assert.True(t, strings.HasSuffix(str, "}]")) +} + +func TestProtocolState_String(t *testing.T) { + goodStates := []struct { + name string + state protocolState + }{ + {"command announce", cmdAnnounce}, + {"command data", cmdData}, + {"handle command", handleCommand}, + {"respond", respond}, + {"response announce", respAnnounce}, + {"response data", respData}, + {"handle response", handleResponse}, + } + + for _, s := range goodStates { + t.Run(s.name, func(t *testing.T) { + str := s.state.String() + + assert.NotEmpty(t, str) + assert.NotContains(t, str, "unknown") + }) + } + + t.Run("unsupported state", func(t *testing.T) { + str := protocolState(-1).String() + + assert.NotEmpty(t, str) + assert.Contains(t, str, "unknown") + assert.Contains(t, str, "-1") + }) +} + +type noopServerHandler struct{} + +func (h *noopServerHandler) CommandAnnounce(context.Context, chan []byte) (*Command, error) { + return nil, nil +} + +func (h *noopServerHandler) CommandData(context.Context, chan []byte, *Command) error { + return nil +} + +func (h *noopServerHandler) HandleCommand(context.Context, *Command) (*Response, error) { + return nil, nil +} + +func (h *noopServerHandler) Respond(context.Context, *Response, chan []byte) error { + return nil +} + +type testServerHandler struct { + logger *logrus.Logger +} + +func (h *testServerHandler) CommandAnnounce(ctx context.Context, in chan []byte) (*Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case frame := <-in: + h.logger.Infof("announce frame %s", string(frame)) + + return &Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef)}, nil + } +} + +func (h *testServerHandler) CommandData(ctx context.Context, in chan []byte, command *Command) error { + select { + case <-ctx.Done(): + return nil + case frame := <-in: + h.logger.Infof("command frame %s", string(frame)) + + command.Command = frame + + return nil + } +} + +func (h *testServerHandler) HandleCommand(_ context.Context, command *Command) (*Response, error) { + h.logger.Info("handle command") + + return &Response{ + Announce: messages.BuildResponseAnnounce(messages.RespUndef, command.Announce.ID), + Response: fmt.Sprintf("response for command %s", command.Command), + }, nil +} + +func (h *testServerHandler) Respond(ctx context.Context, response *Response, out chan []byte) error { + h.logger.Info("send response") + + buf := bytes.NewBuffer([]byte("test-response-")) + buf.WriteString(response.Announce.ID) + + select { + case <-ctx.Done(): + return nil + case out <- buf.Bytes(): + return nil + } +} + +type commandAnnounceErrServerHandler struct{ testServerHandler } + +func (h *commandAnnounceErrServerHandler) CommandAnnounce(ctx context.Context, in chan []byte) (*Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case announce := <-in: + return nil, fmt.Errorf("failed to handle announce %s", announce) + } +} + +type commandAnnounceNilServerHandler struct{ testServerHandler } + +func (h *commandAnnounceNilServerHandler) CommandAnnounce(ctx context.Context, in chan []byte) (*Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case <-in: + return nil, nil + } +} + +type commandDataErrServerHandler struct{ testServerHandler } + +func (h *commandDataErrServerHandler) CommandData(ctx context.Context, in chan []byte, _ *Command) error { + select { + case <-ctx.Done(): + return nil + case data := <-in: + return fmt.Errorf("failed to handle command data %s", data) + } +} + +type commandDataNilAnnouncementServerHandler struct{ testServerHandler } + +func (h *commandDataNilAnnouncementServerHandler) CommandAnnounce( + ctx context.Context, + in chan []byte, +) (*Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case <-in: + return &Command{}, nil + } +} + +type handleCommandErrServerHandler struct{ testServerHandler } + +func (h *handleCommandErrServerHandler) HandleCommand(_ context.Context, command *Command) (*Response, error) { + return nil, fmt.Errorf("failed to handle command %s", command) +} + +type handleCommandNilCommandServerHandler struct{ testServerHandler } + +func (h *handleCommandNilCommandServerHandler) CommandData(ctx context.Context, in chan []byte, _ *Command) error { + select { + case <-ctx.Done(): + return nil + case <-in: + return nil + } +} + +type respondErrServerHandler struct{ testServerHandler } + +func (h *respondErrServerHandler) Respond(_ context.Context, r *Response, _ chan []byte) error { + return fmt.Errorf("failed to respond %s", r) +} + +type respondNilResponseServerHandler struct{ testServerHandler } + +func (h *respondNilResponseServerHandler) HandleCommand(context.Context, *Command) (*Response, error) { + return nil, nil +} + +func TestNewServer(t *testing.T) { + logger, _ := test.NewNullLogger() + + in := make(chan []byte) + out := make(chan []byte) + + server := NewServer(&noopServerHandler{}, in, out, logger) + + assert.NotNil(t, server) + assert.IsType(t, (*ServerProtocol)(nil), server) +} + +func TestServerProtocol_Handle(t *testing.T) { + protoLog, pHook := test.NewNullLogger() + protoLog.SetLevel(logrus.TraceLevel) + + handlerLog, hHook := test.NewNullLogger() + + t.Run("initialization", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + + assertLogs(t, hHook, []expectedLogs{}) + }) + + t.Run("happy-path", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "test command") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.NoError(t, err) + + assert.NotNil(t, response) + assert.NotEmpty(t, response) + assert.True(t, strings.HasPrefix(string(response), "test-response")) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respond)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "announce frame dropped announcement"}, + {logrus.InfoLevel, "command frame test command"}, + {logrus.InfoLevel, "handle command"}, + {logrus.InfoLevel, "send response"}, + }) + }) + + t.Run("command-announce-nil", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&commandAnnounceNilServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.ErrorIs(t, err, errCommandAnnounceExpected) + assert.Nil(t, response) + + assertLogs(t, hHook, []expectedLogs{}) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, + }) + }) + + t.Run("command-announce-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&commandAnnounceErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, response) + + assertLogs(t, hHook, []expectedLogs{}) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.ErrorLevel, "could not handle command announce"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("command-data-nil-announce", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&commandDataNilAnnouncementServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + select { + case <-ctx.Done(): + break + case response = <-out: + break + } + wg.Done() + }() + + in <- []byte("dropped announcement") + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.ErrorIs(t, err, errCommandAnnounceExpected) + assert.Nil(t, response) + + assertLogs(t, hHook, []expectedLogs{}) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, + }) + }) + + t.Run("command-data-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&commandDataErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "command fails") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, response) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "announce frame dropped announcement"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, + {logrus.ErrorLevel, "could not handle command data"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("handle-command-nil-command", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&handleCommandNilCommandServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "dropped command") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.ErrorIs(t, err, errCommandDataExpected) + assert.Nil(t, response) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "announce frame dropped announcement"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, + }) + }) + + t.Run("handle-command-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&handleCommandErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "command fails") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, response) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "announce frame dropped announcement"}, + {logrus.InfoLevel, "command frame command fails"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, + {logrus.ErrorLevel, "could not handle command"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("respond-nil-response", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&respondNilResponseServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "dropped command") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.ErrorIs(t, err, errResponseExpected) + assert.Nil(t, response) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "announce frame dropped announcement"}, + {logrus.InfoLevel, "command frame dropped command"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respond)}, + }) + }) + + t.Run("respond-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&respondErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "command data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, response) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "announce frame dropped announcement"}, + {logrus.InfoLevel, "command frame command data"}, + {logrus.InfoLevel, "handle command"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respond)}, + {logrus.ErrorLevel, "could not respond"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("unknown-protocol-state", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) + p.state = protocolState(100) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.ErrorContains(t, err, "unknown protocol state") + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, "handling protocol state unknown 100"}, + }) + + assertLogs(t, hHook, []expectedLogs{}) + }) +} + +func sendServerCommand(ctx context.Context, in chan []byte, announce string, command string) { + select { + case <-ctx.Done(): + return + case in <- []byte(announce): + break + } + + if command == "" { + return + } + + select { + case <-ctx.Done(): + return + case in <- []byte(command): + break + } +} + +func readServerResponse(ctx context.Context, out chan []byte) []byte { + select { + case <-ctx.Done(): + return nil + case response := <-out: + return response + } +} + +type noopClientHandler struct{} + +func (h *noopClientHandler) Send(context.Context, *Command, chan []byte) error { + return nil +} + +func (h *noopClientHandler) ResponseAnnounce(context.Context, chan []byte) (*Response, error) { + return nil, nil +} + +func (h *noopClientHandler) ResponseData(context.Context, chan []byte, *Response) error { + return nil +} + +func (h *noopClientHandler) HandleResponse(context.Context, *Response) error { + return nil +} + +type testClientHandler struct{ logger *logrus.Logger } + +func (h *testClientHandler) Send(ctx context.Context, command *Command, out chan []byte) error { + h.logger.Infof("send command %s", command.Announce.Code) + + select { + case <-ctx.Done(): + return nil + case out <- []byte(command.Announce.String()): + break + } + + select { + case <-ctx.Done(): + return nil + case out <- []byte(command.Command.(string)): //nolint:forcetypeassert + break + } + + return nil +} + +func (h *testClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*Response, error) { + select { + case <-ctx.Done(): + return nil, nil + case announce := <-in: + h.logger.Infof("received response announce %s", announce) + + response := &Response{Announce: messages.BuildResponseAnnounce(messages.RespUndef, string(announce))} + + return response, nil + } +} + +func (h *testClientHandler) ResponseData(ctx context.Context, in chan []byte, response *Response) error { + select { + case <-ctx.Done(): + return nil + case data := <-in: + h.logger.Infof("received response data %s", string(data)) + + response.Response = data + + return nil + } +} + +func (h *testClientHandler) HandleResponse(_ context.Context, response *Response) error { + h.logger.Infof("handle response %s", response.Announce.ID) + + return nil +} + +type commandAnnounceErrClientHandler struct{ testClientHandler } + +func (h *commandAnnounceErrClientHandler) Send(context.Context, *Command, chan []byte) error { + return errors.New("failed sending command") +} + +func TestNewClient(t *testing.T) { + logger, _ := test.NewNullLogger() + + in := make(chan []byte) + out := make(chan []byte) + + commands := make(chan *Command, 10) + + client := NewClient(&noopClientHandler{}, commands, in, out, logger) + + assert.NotNil(t, client) + assert.IsType(t, (*ClientProtocol)(nil), client) +} + +type responseAnnounceErrClientHandler struct{ testClientHandler } + +func (h *responseAnnounceErrClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*Response, error) { + select { + case <-ctx.Done(): + return nil, nil + case <-in: + return nil, errors.New("failed receiving response announce") + } +} + +type responseAnnounceNilClientHandler struct{ testClientHandler } + +func (h *responseAnnounceNilClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*Response, error) { + select { + case <-ctx.Done(): + return nil, nil + case <-in: + return nil, nil + } +} + +type responseDataErrClientHandler struct{ testClientHandler } + +func (h *responseDataErrClientHandler) ResponseData(ctx context.Context, in chan []byte, _ *Response) error { + select { + case <-ctx.Done(): + return nil + case <-in: + return errors.New("failed to handle response data") + } +} + +type responseDataTimeoutErrClientHandler struct{ testClientHandler } + +func (h *responseDataTimeoutErrClientHandler) ResponseData(ctx context.Context, in chan []byte, _ *Response) error { + select { + case <-ctx.Done(): + return nil + case <-in: + return ErrResponseDataTimeoutExpired + } +} + +type responseDataNilClientHandler struct{ testClientHandler } + +func (h *responseDataNilClientHandler) ResponseData(ctx context.Context, in chan []byte, response *Response) error { + select { + case <-ctx.Done(): + return nil + case <-in: + response.Response = nil + + return nil + } +} + +type handleResponseErrClientHandler struct{ testClientHandler } + +func (h *handleResponseErrClientHandler) HandleResponse(context.Context, *Response) error { + return errors.New("failed to handle response") +} + +func TestClientProtocol_Handle(t *testing.T) { //nolint:cyclop + protoLog, pHook := test.NewNullLogger() + protoLog.SetLevel(logrus.TraceLevel) + + handlerLog, hHook := test.NewNullLogger() + + t.Run("initialize", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + sent []byte + ) + + c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + return + case sent = <-out: + return + } + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, sent) + + assertLogs(t, hHook, []expectedLogs{}) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("happy-case", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "send command UNDEFINED"}, + {logrus.InfoLevel, "received response announce test response announce"}, + {logrus.InfoLevel, "received response data test response data"}, + {logrus.InfoLevel, "handle response test response announce"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("command-announce-nil", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + close(commands) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.Error(t, err) + assert.Nil(t, announce) + assert.Nil(t, command) + + assertLogs(t, hHook, []expectedLogs{}) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("command-announce-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&commandAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, announce) + assert.Nil(t, command) + + assertLogs(t, hHook, []expectedLogs{}) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.ErrorLevel, "could not send command announce"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("response-announce-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&responseAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "send command UNDEFINED"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, + {logrus.ErrorLevel, "could not handle response announce"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("response-data-nil", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&responseAnnounceNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.ErrorIs(t, err, errResponseExpected) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "send command UNDEFINED"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, + }) + }) + + t.Run("response-data-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&responseDataErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "send command UNDEFINED"}, + {logrus.InfoLevel, "received response announce test response announce"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, + {logrus.ErrorLevel, "could not handle response data"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, + }) + }) + + t.Run("response-data-timeout-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&responseDataTimeoutErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "send command UNDEFINED"}, + {logrus.InfoLevel, "received response announce test response announce"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, + {logrus.ErrorLevel, "could not handle response data"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("handle-response-nil", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&responseDataNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.ErrorIs(t, err, errResponseExpected) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "send command UNDEFINED"}, + {logrus.InfoLevel, "received response announce test response announce"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)}, + }) + }) + + t.Run("handle-response-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + announce, command []byte + ) + + c := NewClient(&handleResponseErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + assertLogs(t, hHook, []expectedLogs{ + {logrus.InfoLevel, "send command UNDEFINED"}, + {logrus.InfoLevel, "received response announce test response announce"}, + {logrus.InfoLevel, "received response data test response data"}, + }) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)}, + {logrus.ErrorLevel, "could not handle response"}, + {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, + }) + }) + + t.Run("unknown-protocol-state", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) + + var ( + err error + sent []byte + ) + + c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) + c.state = protocolState(100) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + return + case sent = <-out: + return + } + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.ErrorContains(t, err, "unknown protocol state") + assert.Nil(t, sent) + + assertLogs(t, hHook, []expectedLogs{}) + + assertLogs(t, pHook, []expectedLogs{ + {logrus.DebugLevel, "handling protocol state unknown 100"}, + }) + }) +} + +func prepareClientCommand(ctx context.Context, commands chan *Command, command string) { + select { + case <-ctx.Done(): + return + case commands <- &Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef), Command: command}: + return + } +} + +func receiveClientAnnounce(ctx context.Context, in chan []byte, announce, data string) { + select { + case <-ctx.Done(): + return + case in <- []byte(announce): + break + } + + select { + case <-ctx.Done(): + return + case in <- []byte(data): + break + } +} + +func sendClientCommand(ctx context.Context, out chan []byte) ([]byte, []byte) { + var announce, command []byte + + select { + case <-ctx.Done(): + return nil, nil + case announce = <-out: + break + } + + select { + case <-ctx.Done(): + return announce, nil + case command = <-out: + return announce, command + } +} + +func TestNewCOBSFramer(t *testing.T) { + logger, _ := test.NewNullLogger() + + framer := NewCOBSFramer(logger) + + require.NotNil(t, framer) + assert.IsType(t, (*COBSFramer)(nil), framer) + assert.Implements(t, (*Framer)(nil), framer) +} + +func TestCOBSFramer_ReadFrames(t *testing.T) { + logger, loggerHook := test.NewNullLogger() + logger.SetLevel(logrus.TraceLevel) + + framer := NewCOBSFramer(logger) + + t.Run("read error", func(t *testing.T) { + t.Cleanup(func() { + loggerHook.Reset() + }) + + readFrames := make(chan []byte) + + testError := errors.New("test error") + + reader := iotest.ErrReader(testError) + + err := framer.ReadFrames(context.Background(), reader, readFrames) + + assert.ErrorIs(t, err, testError) + + frame := <-readFrames + + assert.Nil(t, frame) + + assert.Empty(t, loggerHook.AllEntries()) + }) + + t.Run("no bytes", func(t *testing.T) { + t.Cleanup(func() { + loggerHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + readFrames := make(chan []byte) + + reader := &bytes.Buffer{} + + var err error + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = framer.ReadFrames(ctx, reader, readFrames) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + + assert.Empty(t, loggerHook.AllEntries()) + }) + + t.Run("incomplete bytes", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + readFrames := make(chan []byte) + + reader := strings.NewReader("some bytes") + + var err error + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = framer.ReadFrames(ctx, reader, readFrames) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.Nil(t, err) + + logEntries := loggerHook.AllEntries() + + require.Len(t, logEntries, 2) + assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) + assert.Equal(t, "read 10 raw bytes", logEntries[0].Message) + assert.Equal(t, logrus.TraceLevel, logEntries[1].Level) + assert.Equal(t, "read buffer is now 10 bytes long", logEntries[1].Message) + }) + + t.Run("invalid bytes", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + readFrames := make(chan []byte) + + reader := bytes.NewBuffer([]byte("some bytes")) + + reader.WriteByte(CobsDelimiter) + + var err error + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = framer.ReadFrames(ctx, reader, readFrames) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + + logEntries := loggerHook.AllEntries() + + require.Len(t, logEntries, 3) + assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) + assert.Equal(t, "read 11 raw bytes", logEntries[0].Message) + assert.Equal(t, logrus.WarnLevel, logEntries[1].Level) + assert.Equal(t, "skipping invalid frame of 11 bytes", logEntries[1].Message) + assert.Equal(t, logrus.TraceLevel, logEntries[2].Level) + assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message) + }) + + t.Run("valid frame", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + readFrames := make(chan []byte) + + reader := bytes.NewBuffer(cobs.Encode([]byte("some bytes"), framer.config)) + + var ( + err error + frame []byte + ) + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + err = framer.ReadFrames(ctx, reader, readFrames) + wg.Done() + }() + + go func() { + frame = <-readFrames + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.NoError(t, err) + + assert.NotNil(t, frame) + if frame != nil { + assert.Equal(t, []byte("some bytes"), frame) + } + + logEntries := loggerHook.AllEntries() + + require.Len(t, logEntries, 3) + assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) + assert.Contains(t, logEntries[0].Message, "raw bytes") + assert.Equal(t, logrus.TraceLevel, logEntries[1].Level) + assert.Equal(t, "frame decoded to length 10", logEntries[1].Message) + assert.Equal(t, logrus.TraceLevel, logEntries[2].Level) + assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message) + }) +} + +type brokenWriter struct{} + +var errB0rk3d = errors.New("you b0rk3d it") + +func (b brokenWriter) Write([]byte) (int, error) { + return 0, errB0rk3d +} + +func TestCOBSFramer_WriteFrames(t *testing.T) { + logger, loggerHook := test.NewNullLogger() + logger.SetLevel(logrus.TraceLevel) + + framer := NewCOBSFramer(logger) + + t.Run("closed channel", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + var result error + + in := make(chan []byte) + + close(in) + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + + wg.Add(1) + + go func() { + result = framer.WriteFrames(ctx, io.Discard, in) + + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.Nil(t, result) + + logEntries := loggerHook.AllEntries() + + assert.Len(t, logEntries, 1) + assert.Equal(t, logrus.DebugLevel, logEntries[0].Level) + assert.Equal(t, "channel closed", logEntries[0].Message) + }) + + t.Run("closed writer", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + var result error + + in := make(chan []byte) + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + + wg.Add(1) + + go func() { + result = framer.WriteFrames(ctx, &brokenWriter{}, in) + + wg.Done() + }() + + in <- []byte("test message") + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.NotNil(t, result) + assert.True(t, errors.Is(result, errB0rk3d)) + + assert.Len(t, loggerHook.AllEntries(), 0) + }) + + t.Run("valid frame", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + var err error + + in := make(chan []byte) + + out := &bytes.Buffer{} + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = framer.WriteFrames(ctx, out, in) + + wg.Done() + }() + + in <- []byte("test message") + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.Nil(t, err) + + logEntries := loggerHook.AllEntries() + assert.Len(t, logEntries, 1) + assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) + + frame := out.Bytes() + assert.NoError(t, cobs.Verify(frame, framer.config)) + assert.Equal(t, []byte("test message"), cobs.Decode(out.Bytes(), framer.config)) + }) +}