From 51ca9cc69d8c35878f75058b1d7d58ffcb293fb7 Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Sun, 4 Dec 2022 19:17:36 +0100 Subject: [PATCH] Refactor public API tests for protocol - move tests for public API to protocol_test package - add tests for context handling of COBSFramer --- pkg/protocol/protocol.go | 28 +- pkg/protocol/protocol_it_test.go | 2079 ++++++++++++++++++++++++++++++ pkg/protocol/protocol_test.go | 2057 +---------------------------- 3 files changed, 2165 insertions(+), 1999 deletions(-) create mode 100644 pkg/protocol/protocol_it_test.go diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go index 044c4d2..59b9f98 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/protocol/protocol.go @@ -32,7 +32,7 @@ import ( "git.cacert.org/cacert-gosigner/pkg/messages" ) -const CobsDelimiter = 0x00 +const COBSDelimiter = 0x00 type Command struct { Announce *messages.CommandAnnounce @@ -73,9 +73,9 @@ type ClientHandler interface { var ( errCommandExpected = errors.New("command must not be nil") - errCommandAnnounceExpected = errors.New("command must have an announcement") - errCommandDataExpected = errors.New("command must have data") - errResponseExpected = errors.New("response must not be nil") + ErrCommandAnnounceExpected = errors.New("command must have an announcement") + ErrCommandDataExpected = errors.New("command must have data") + ErrResponseExpected = errors.New("response must not be nil") ErrResponseAnnounceTimeoutExpired = errors.New("response announce timeout expired") ErrResponseDataTimeoutExpired = errors.New("response data timeout expired") @@ -133,7 +133,7 @@ func (p *ServerProtocol) Handle(ctx context.Context) error { return nil default: - p.logger.Debugf("handling protocol state %s", p.state) + p.logger.WithField("state", p.state).Debug("handling protocol state") switch p.state { case cmdAnnounce: @@ -175,7 +175,7 @@ func (p *ServerProtocol) commandAnnounce(ctx context.Context) *Command { func (p *ServerProtocol) commandData(ctx context.Context, command *Command) error { if command == nil || command.Announce == nil { - return errCommandAnnounceExpected + return ErrCommandAnnounceExpected } err := p.handler.CommandData(ctx, p.in, command) @@ -194,7 +194,7 @@ func (p *ServerProtocol) commandData(ctx context.Context, command *Command) erro func (p *ServerProtocol) handleCommand(ctx context.Context, command *Command) (*Response, error) { if command == nil || command.Announce == nil || command.Command == nil { - return nil, errCommandDataExpected + return nil, ErrCommandDataExpected } response, err := p.handler.HandleCommand(ctx, command) @@ -213,7 +213,7 @@ func (p *ServerProtocol) handleCommand(ctx context.Context, command *Command) (* func (p *ServerProtocol) respond(ctx context.Context, response *Response) error { if response == nil { - return errResponseExpected + return ErrResponseExpected } err := p.handler.Respond(ctx, response, p.out) @@ -260,7 +260,7 @@ func (p *ClientProtocol) Handle(ctx context.Context) error { case <-ctx.Done(): return nil default: - p.logger.Debugf("handling protocol state %s", p.state) + p.logger.WithField("state", p.state).Debug("handling protocol state") switch p.state { case cmdAnnounce: @@ -331,7 +331,7 @@ func (p *ClientProtocol) respAnnounce(ctx context.Context) *Response { func (p *ClientProtocol) respData(ctx context.Context, response *Response) error { if response == nil || response.Announce == nil { - return errResponseExpected + return ErrResponseExpected } err := p.handler.ResponseData(ctx, p.in, response) @@ -354,7 +354,7 @@ func (p *ClientProtocol) respData(ctx context.Context, response *Response) error func (p *ClientProtocol) handleResponse(ctx context.Context, response *Response) error { if response == nil || response.Announce == nil || response.Response == nil { - return errResponseExpected + return ErrResponseExpected } err := p.handler.HandleResponse(ctx, response) @@ -404,8 +404,10 @@ type COBSFramer struct { encoder cobs.Encoder } +var COBSConfig = cobs.Config{SpecialByte: COBSDelimiter, Delimiter: true, EndingSave: true} + func NewCOBSFramer(logger *logrus.Logger) (*COBSFramer, error) { - encoder, err := cobs.NewEncoder(cobs.Config{SpecialByte: CobsDelimiter, Delimiter: true, EndingSave: true}) + encoder, err := cobs.NewEncoder(COBSConfig) if err != nil { return nil, fmt.Errorf("could not setup encoder: %w", err) } @@ -447,7 +449,7 @@ func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan buffer.Write(raw) for { - data, err = buffer.ReadBytes(CobsDelimiter) + data, err = buffer.ReadBytes(COBSDelimiter) if err != nil { if errors.Is(err, io.EOF) { buffer.Write(data) diff --git a/pkg/protocol/protocol_it_test.go b/pkg/protocol/protocol_it_test.go new file mode 100644 index 0000000..53a602c --- /dev/null +++ b/pkg/protocol/protocol_it_test.go @@ -0,0 +1,2079 @@ +/* +Copyright 2022 CAcert Inc. +SPDX-License-Identifier: Apache-2.0 + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package protocol_test provides tests for the public API of the protocol package. +package protocol_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "testing" + "testing/iotest" + "time" + + "github.com/justincpresley/go-cobs" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "git.cacert.org/cacert-gosigner/pkg/messages" + "git.cacert.org/cacert-gosigner/pkg/protocol" +) + +func TestNewServer(t *testing.T) { + logger, _ := test.NewNullLogger() + + in := make(chan []byte) + out := make(chan []byte) + + server := protocol.NewServer(&protocol.NoopServerHandler{}, in, out, logger) + + assert.NotNil(t, server) + assert.IsType(t, (*protocol.ServerProtocol)(nil), server) +} + +type testServerHandler struct { + logger *logrus.Logger +} + +func (h *testServerHandler) CommandAnnounce(ctx context.Context, in <-chan []byte) (*protocol.Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case frame := <-in: + h.logger.Infof("announce frame %s", string(frame)) + + return &protocol.Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef)}, nil + } +} + +func (h *testServerHandler) CommandData(ctx context.Context, in <-chan []byte, command *protocol.Command) error { + select { + case <-ctx.Done(): + return nil + case frame := <-in: + h.logger.Infof("command frame %s", string(frame)) + + command.Command = frame + + return nil + } +} + +func (h *testServerHandler) HandleCommand(_ context.Context, command *protocol.Command) (*protocol.Response, error) { + h.logger.Info("handle command") + + return &protocol.Response{ + Announce: messages.BuildResponseAnnounce(messages.RespUndef, command.Announce.ID), + Response: fmt.Sprintf("response for command %s", command.Command), + }, nil +} + +func (h *testServerHandler) Respond(ctx context.Context, response *protocol.Response, out chan<- []byte) error { + h.logger.Info("send response") + + buf := bytes.NewBuffer([]byte("test-response-")) + buf.WriteString(response.Announce.ID) + + select { + case <-ctx.Done(): + return nil + case out <- buf.Bytes(): + return nil + } +} + +type commandAnnounceErrServerHandler struct{ testServerHandler } + +func (h *commandAnnounceErrServerHandler) CommandAnnounce( + ctx context.Context, in <-chan []byte, +) (*protocol.Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case announce := <-in: + return nil, fmt.Errorf("failed to handle announce %s", announce) + } +} + +type commandAnnounceNilServerHandler struct{ testServerHandler } + +func (h *commandAnnounceNilServerHandler) CommandAnnounce( + ctx context.Context, in <-chan []byte, +) (*protocol.Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case <-in: + return nil, nil + } +} + +type commandDataErrServerHandler struct{ testServerHandler } + +func (h *commandDataErrServerHandler) CommandData(ctx context.Context, in <-chan []byte, _ *protocol.Command) error { + select { + case <-ctx.Done(): + return nil + case data := <-in: + return fmt.Errorf("failed to handle command data %s", data) + } +} + +type commandDataNilAnnouncementServerHandler struct{ testServerHandler } + +func (h *commandDataNilAnnouncementServerHandler) CommandAnnounce( + ctx context.Context, in <-chan []byte, +) (*protocol.Command, error) { + select { + case <-ctx.Done(): + return nil, nil + case <-in: + return &protocol.Command{}, nil + } +} + +type handleCommandErrServerHandler struct{ testServerHandler } + +func (h *handleCommandErrServerHandler) HandleCommand( + _ context.Context, command *protocol.Command, +) (*protocol.Response, error) { + return nil, fmt.Errorf("failed to handle command %s", command) +} + +type handleCommandNilCommandServerHandler struct{ testServerHandler } + +func (h *handleCommandNilCommandServerHandler) CommandData( + ctx context.Context, in <-chan []byte, _ *protocol.Command, +) error { + select { + case <-ctx.Done(): + return nil + case <-in: + return nil + } +} + +type respondErrServerHandler struct{ testServerHandler } + +func (h *respondErrServerHandler) Respond(_ context.Context, r *protocol.Response, _ chan<- []byte) error { + return fmt.Errorf("failed to respond %s", r) +} + +type respondNilResponseServerHandler struct{ testServerHandler } + +func (h *respondNilResponseServerHandler) HandleCommand( + context.Context, *protocol.Command, +) (*protocol.Response, error) { + return nil, nil +} + +func TestServerProtocol_Handle(t *testing.T) { + protoLog, pHook := test.NewNullLogger() + protoLog.SetLevel(logrus.TraceLevel) + + handlerLog, hHook := test.NewNullLogger() + + t.Run("initialization", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + { + Level: logrus.DebugLevel, + Message: "handling protocol state", + }, + }) + + protocol.AssertLogs(t, hHook, []logrus.Entry{}) + }) + + t.Run("happy-path", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "test command") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.NoError(t, err) + + assert.NotNil(t, response) + assert.NotEmpty(t, response) + assert.True(t, strings.HasPrefix(string(response), "test-response")) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, + {Level: logrus.InfoLevel, Message: "command frame test command"}, + {Level: logrus.InfoLevel, Message: "handle command"}, + {Level: logrus.InfoLevel, Message: "send response"}, + }) + }) + + t.Run("command-announce-nil", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer(&commandAnnounceNilServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.ErrorIs(t, err, protocol.ErrCommandAnnounceExpected) + assert.Nil(t, response) + + protocol.AssertLogs(t, hHook, []logrus.Entry{}) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("command-announce-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer(&commandAnnounceErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, response) + + protocol.AssertLogs(t, hHook, []logrus.Entry{}) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not handle command announce"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("command-data-nil-announce", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer( + &commandDataNilAnnouncementServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog, + ) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + select { + case <-ctx.Done(): + break + case response = <-out: + break + } + wg.Done() + }() + + in <- []byte("dropped announcement") + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.ErrorIs(t, err, protocol.ErrCommandAnnounceExpected) + assert.Nil(t, response) + + protocol.AssertLogs(t, hHook, []logrus.Entry{}) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("command-data-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer(&commandDataErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "command fails") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, response) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not handle command data"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("handle-command-nil-command", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer( + &handleCommandNilCommandServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog, + ) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "dropped command") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.ErrorIs(t, err, protocol.ErrCommandDataExpected) + assert.Nil(t, response) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("handle-command-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer(&handleCommandErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "command fails") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, response) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, + {Level: logrus.InfoLevel, Message: "command frame command fails"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not handle command"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("respond-nil-response", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer( + &respondNilResponseServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog, + ) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "dropped command") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.ErrorIs(t, err, protocol.ErrResponseExpected) + assert.Nil(t, response) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, + {Level: logrus.InfoLevel, Message: "command frame dropped command"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("respond-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + + var err error + + p := protocol.NewServer(&respondErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + + var response []byte + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = p.Handle(ctx) + wg.Done() + }() + + go func() { + response = readServerResponse(ctx, out) + wg.Done() + }() + + go func() { + sendServerCommand(ctx, in, "dropped announcement", "command data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, response) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, + {Level: logrus.InfoLevel, Message: "command frame command data"}, + {Level: logrus.InfoLevel, Message: "handle command"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not respond"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) +} + +func sendServerCommand(ctx context.Context, in chan []byte, announce string, command string) { + select { + case <-ctx.Done(): + return + case in <- []byte(announce): + break + } + + if command == "" { + return + } + + select { + case <-ctx.Done(): + return + case in <- []byte(command): + break + } +} + +func readServerResponse(ctx context.Context, out chan []byte) []byte { + select { + case <-ctx.Done(): + return nil + case response := <-out: + return response + } +} + +type testClientHandler struct{ logger *logrus.Logger } + +func (h *testClientHandler) Send(ctx context.Context, command *protocol.Command, out chan<- []byte) error { + h.logger.Infof("send command %s", command.Announce.Code) + + select { + case <-ctx.Done(): + return nil + case out <- []byte(command.Announce.String()): + break + } + + outBytes, ok := command.Command.(string) + if !ok { + return fmt.Errorf("could not use command '%+v' as string", command.Command) + } + + select { + case <-ctx.Done(): + return nil + case out <- []byte(outBytes): + break + } + + return nil +} + +func (h *testClientHandler) ResponseAnnounce(ctx context.Context, in <-chan []byte) (*protocol.Response, error) { + select { + case <-ctx.Done(): + return nil, nil + case announce := <-in: + h.logger.Infof("received response announce %s", announce) + + response := &protocol.Response{Announce: messages.BuildResponseAnnounce(messages.RespUndef, string(announce))} + + return response, nil + } +} + +func (h *testClientHandler) ResponseData(ctx context.Context, in <-chan []byte, response *protocol.Response) error { + select { + case <-ctx.Done(): + return nil + case data := <-in: + h.logger.Infof("received response data %s", string(data)) + + response.Response = data + + return nil + } +} + +func (h *testClientHandler) HandleResponse(_ context.Context, response *protocol.Response) error { + h.logger.Infof("handle response %s", response.Announce.ID) + + return nil +} + +type commandAnnounceErrClientHandler struct{ testClientHandler } + +func (h *commandAnnounceErrClientHandler) Send(context.Context, *protocol.Command, chan<- []byte) error { + return errors.New("failed sending command") +} + +func TestNewClient(t *testing.T) { + logger, _ := test.NewNullLogger() + + in := make(chan []byte) + out := make(chan []byte) + + commands := make(chan *protocol.Command, 10) + + client := protocol.NewClient(&protocol.NoopClientHandler{}, commands, in, out, logger) + + assert.NotNil(t, client) + assert.IsType(t, (*protocol.ClientProtocol)(nil), client) +} + +type responseAnnounceErrClientHandler struct{ testClientHandler } + +func (h *responseAnnounceErrClientHandler) ResponseAnnounce( + ctx context.Context, in <-chan []byte, +) (*protocol.Response, error) { + select { + case <-ctx.Done(): + return nil, nil + case <-in: + return nil, errors.New("failed receiving response announce") + } +} + +type responseAnnounceNilClientHandler struct{ testClientHandler } + +func (h *responseAnnounceNilClientHandler) ResponseAnnounce( + ctx context.Context, in <-chan []byte, +) (*protocol.Response, error) { + select { + case <-ctx.Done(): + return nil, nil + case <-in: + return nil, nil + } +} + +type responseDataErrClientHandler struct{ testClientHandler } + +func (h *responseDataErrClientHandler) ResponseData(ctx context.Context, in <-chan []byte, _ *protocol.Response) error { + select { + case <-ctx.Done(): + return nil + case <-in: + return errors.New("failed to handle response data") + } +} + +type responseDataTimeoutErrClientHandler struct{ testClientHandler } + +func (h *responseDataTimeoutErrClientHandler) ResponseData( + ctx context.Context, in <-chan []byte, _ *protocol.Response, +) error { + select { + case <-ctx.Done(): + return nil + case <-in: + return protocol.ErrResponseDataTimeoutExpired + } +} + +type responseDataNilClientHandler struct{ testClientHandler } + +func (h *responseDataNilClientHandler) ResponseData( + ctx context.Context, in <-chan []byte, response *protocol.Response, +) error { + select { + case <-ctx.Done(): + return nil + case <-in: + response.Response = nil + + return nil + } +} + +type handleResponseErrClientHandler struct{ testClientHandler } + +func (h *handleResponseErrClientHandler) HandleResponse(context.Context, *protocol.Response) error { + return errors.New("failed to handle response") +} + +func TestClientProtocol_Handle(t *testing.T) { + protoLog, pHook := test.NewNullLogger() + protoLog.SetLevel(logrus.TraceLevel) + + handlerLog, hHook := test.NewNullLogger() + + t.Run("initialize", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + sent []byte + ) + + c := protocol.NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + return + case sent = <-out: + return + } + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, sent) + + protocol.AssertLogs(t, hHook, []logrus.Entry{}) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + { + Level: logrus.DebugLevel, + Message: "handling protocol state", + }, + }) + }) + + t.Run("happy-case", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, + {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, + {Level: logrus.InfoLevel, Message: "received response data test response data"}, + {Level: logrus.InfoLevel, Message: "handle response test response announce"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.TraceLevel, Message: "handled command"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("command-announce-nil", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + close(commands) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.Error(t, err) + assert.Nil(t, announce) + assert.Nil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{}) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("command-announce-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient(&commandAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.Nil(t, announce) + assert.Nil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{}) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not send command announce"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("response-announce-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient(&responseAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.TraceLevel, Message: "handled command"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not handle response announce"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("response-data-nil", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient(&responseAnnounceNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.ErrorIs(t, err, protocol.ErrResponseExpected) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, + }) + + protocol.AssertLogs( + t, + pHook, + []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.TraceLevel, Message: "handled command"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }, + ) + }) + + t.Run("response-data-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient(&responseDataErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, + {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.TraceLevel, Message: "handled command"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not handle response data"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("response-data-timeout-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient( + &responseDataTimeoutErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog, + ) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, + {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.TraceLevel, Message: "handled command"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not handle response data"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("handle-response-nil", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient(&responseDataNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.ErrorIs(t, err, protocol.ErrResponseExpected) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, + {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.TraceLevel, Message: "handled command"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) + + t.Run("handle-response-error", func(t *testing.T) { + t.Cleanup(func() { + pHook.Reset() + hHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + in, out := make(chan []byte), make(chan []byte) + commands := make(chan *protocol.Command, 10) + + var ( + err error + announce, command []byte + ) + + c := protocol.NewClient(&handleResponseErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(4) + + go func() { + err = c.Handle(ctx) + wg.Done() + }() + + go func() { + announce, command = sendClientCommand(ctx, out) + wg.Done() + }() + + go func() { + prepareClientCommand(ctx, commands, "test command") + wg.Done() + }() + + go func() { + receiveClientAnnounce(ctx, in, "test response announce", "test response data") + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + cancel() + wg.Wait() + + assert.NoError(t, err) + assert.NotNil(t, announce) + assert.NotNil(t, command) + + protocol.AssertLogs(t, hHook, []logrus.Entry{ + {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, + {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, + {Level: logrus.InfoLevel, Message: "received response data test response data"}, + }) + + protocol.AssertLogs(t, pHook, []logrus.Entry{ + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.TraceLevel, Message: "handled command"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + {Level: logrus.ErrorLevel, Message: "could not handle response"}, + {Level: logrus.DebugLevel, Message: "handling protocol state"}, + }) + }) +} + +func prepareClientCommand(ctx context.Context, commands chan *protocol.Command, command string) { + select { + case <-ctx.Done(): + return + case commands <- &protocol.Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef), Command: command}: + return + } +} + +func receiveClientAnnounce(ctx context.Context, in chan []byte, announce, data string) { + select { + case <-ctx.Done(): + return + case in <- []byte(announce): + break + } + + select { + case <-ctx.Done(): + return + case in <- []byte(data): + break + } +} + +func sendClientCommand(ctx context.Context, out chan []byte) ([]byte, []byte) { + var announce, command []byte + + select { + case <-ctx.Done(): + return nil, nil + case announce = <-out: + break + } + + select { + case <-ctx.Done(): + return announce, nil + case command = <-out: + return announce, command + } +} + +func TestNewCOBSFramer(t *testing.T) { + logger, _ := test.NewNullLogger() + + framer, err := protocol.NewCOBSFramer(logger) + + assert.NoError(t, err) + + require.NotNil(t, framer) + assert.IsType(t, (*protocol.COBSFramer)(nil), framer) + assert.Implements(t, (*protocol.Framer)(nil), framer) +} + +type slowTestReader struct{} + +func (s slowTestReader) Read(_ []byte) (n int, err error) { + time.Sleep(1 * time.Second) + + return 0, errors.New("you waited too long") +} + +func TestCOBSFramer_ReadFrames(t *testing.T) { + logger, loggerHook := test.NewNullLogger() + logger.SetLevel(logrus.TraceLevel) + + framer, err := protocol.NewCOBSFramer(logger) + + require.NoError(t, err) + require.NotNil(t, framer) + + t.Run("read error", func(t *testing.T) { + t.Cleanup(func() { + loggerHook.Reset() + }) + + readFrames := make(chan []byte) + + testError := errors.New("test error") + + reader := iotest.ErrReader(testError) + + err := framer.ReadFrames(context.Background(), reader, readFrames) + + assert.ErrorIs(t, err, testError) + + frame := <-readFrames + + assert.Nil(t, frame) + + assert.Empty(t, loggerHook.AllEntries()) + }) + + t.Run("no bytes", func(t *testing.T) { + t.Cleanup(func() { + loggerHook.Reset() + }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + readFrames := make(chan []byte) + + reader := &bytes.Buffer{} + + var err error + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = framer.ReadFrames(ctx, reader, readFrames) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + + assert.Empty(t, loggerHook.AllEntries()) + }) + + t.Run("incomplete bytes", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + readFrames := make(chan []byte) + + reader := strings.NewReader("some bytes") + + var err error + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = framer.ReadFrames(ctx, reader, readFrames) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.Nil(t, err) + + logEntries := loggerHook.AllEntries() + + require.Len(t, logEntries, 2) + assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) + assert.Equal(t, "read 10 raw bytes", logEntries[0].Message) + assert.Equal(t, logrus.TraceLevel, logEntries[1].Level) + assert.Equal(t, "read buffer is now 10 bytes long", logEntries[1].Message) + }) + + t.Run("invalid bytes", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + readFrames := make(chan []byte) + + reader := bytes.NewBuffer([]byte("some bytes")) + + reader.WriteByte(protocol.COBSDelimiter) + + var err error + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = framer.ReadFrames(ctx, reader, readFrames) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + + logEntries := loggerHook.AllEntries() + + require.Len(t, logEntries, 3) + assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) + assert.Equal(t, "read 11 raw bytes", logEntries[0].Message) + assert.Equal(t, logrus.WarnLevel, logEntries[1].Level) + assert.Equal(t, "skipping invalid frame of 11 bytes", logEntries[1].Message) + assert.Equal(t, logrus.TraceLevel, logEntries[2].Level) + assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message) + }) + + t.Run("closed context", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + framer, err := protocol.NewCOBSFramer(logger) + + assert.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + readFrames := make(chan []byte) + err = framer.ReadFrames(ctx, &slowTestReader{}, readFrames) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + require.NoError(t, err) + }) + + t.Run("valid frame", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + readFrames := make(chan []byte) + + encoder, _ := cobs.NewEncoder(protocol.COBSConfig) + + reader := bytes.NewBuffer(encoder.Encode([]byte("some bytes"))) + + var ( + err error + frame []byte + ) + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + err = framer.ReadFrames(ctx, reader, readFrames) + wg.Done() + }() + + go func() { + frame = <-readFrames + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.NoError(t, err) + + assert.NotNil(t, frame) + if frame != nil { + assert.Equal(t, []byte("some bytes"), frame) + } + + logEntries := loggerHook.AllEntries() + + require.Len(t, logEntries, 3) + assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) + assert.Contains(t, logEntries[0].Message, "raw bytes") + assert.Equal(t, logrus.TraceLevel, logEntries[1].Level) + assert.Equal(t, "frame decoded to length 10", logEntries[1].Message) + assert.Equal(t, logrus.TraceLevel, logEntries[2].Level) + assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message) + }) +} + +type brokenWriter struct{} + +var errB0rk3d = errors.New("you b0rk3d it") + +func (b brokenWriter) Write([]byte) (int, error) { + return 0, errB0rk3d +} + +type slowTestWriter struct{} + +func (s slowTestWriter) Write(_ []byte) (n int, err error) { + time.Sleep(1 * time.Second) + + return 0, errors.New("you waited too long") +} + +func TestCOBSFramer_WriteFrames(t *testing.T) { + logger, loggerHook := test.NewNullLogger() + logger.SetLevel(logrus.TraceLevel) + + framer, err := protocol.NewCOBSFramer(logger) + + require.NoError(t, err) + require.NotNil(t, framer) + + t.Run("closed channel", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + var result error + + in := make(chan []byte) + + close(in) + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + + wg.Add(1) + + go func() { + result = framer.WriteFrames(ctx, io.Discard, in) + + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.Nil(t, result) + + logEntries := loggerHook.AllEntries() + + assert.Len(t, logEntries, 1) + assert.Equal(t, logrus.DebugLevel, logEntries[0].Level) + assert.Equal(t, "channel closed", logEntries[0].Message) + }) + + t.Run("closed writer", func(t *testing.T) { + t.Cleanup(loggerHook.Reset) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + var result error + + in := make(chan []byte) + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + + wg.Add(1) + + go func() { + result = framer.WriteFrames(ctx, &brokenWriter{}, in) + + wg.Done() + }() + + in <- []byte("test message") + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.NotNil(t, result) + assert.True(t, errors.Is(result, errB0rk3d)) + + assert.Len(t, loggerHook.AllEntries(), 0) + }) + + t.Run("closed context", func(t *testing.T) { + t.Cleanup(loggerHook.Reset) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + framer, err := protocol.NewCOBSFramer(logger) + + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(2) + + frames := make(chan []byte) + + go func() { + err = framer.WriteFrames(ctx, &slowTestWriter{}, frames) + + wg.Done() + }() + + go func() { + select { + case <-ctx.Done(): + case frames <- []byte("my test frame that shall never go out"): + } + + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + + cancel() + wg.Wait() + + assert.NoError(t, err) + }) + + t.Run("valid frame", func(t *testing.T) { + t.Cleanup(func() { loggerHook.Reset() }) + + if testing.Short() { + t.Skip("skipping test in short mode") + } + + var err error + + in := make(chan []byte) + + out := &bytes.Buffer{} + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + err = framer.WriteFrames(ctx, out, in) + + wg.Done() + }() + + in <- []byte("test message") + + time.Sleep(10 * time.Millisecond) + + cancel() + + wg.Wait() + + assert.Nil(t, err) + + logEntries := loggerHook.AllEntries() + assert.Len(t, logEntries, 1) + assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) + + frame := out.Bytes() + + encoder, _ := cobs.NewEncoder(protocol.COBSConfig) + + assert.NoError(t, encoder.Verify(frame)) + assert.Equal(t, []byte("test message"), encoder.Decode(out.Bytes())) + }) +} diff --git a/pkg/protocol/protocol_test.go b/pkg/protocol/protocol_test.go index a99fbdc..64075a5 100644 --- a/pkg/protocol/protocol_test.go +++ b/pkg/protocol/protocol_test.go @@ -18,40 +18,34 @@ limitations under the License. package protocol import ( - "bytes" "context" - "errors" - "fmt" - "io" "strings" "sync" "testing" - "testing/iotest" "time" "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "git.cacert.org/cacert-gosigner/pkg/messages" ) -type expectedLogs struct { - level logrus.Level - message string -} - -func assertLogs(t *testing.T, hook *test.Hook, expected []expectedLogs) { +func AssertLogs(t *testing.T, hook *test.Hook, expected []logrus.Entry) { t.Helper() logEntries := hook.AllEntries() assert.Len(t, logEntries, len(expected)) for i, e := range expected { - assert.Equal(t, e.level, logEntries[i].Level) - assert.Equal(t, e.message, logEntries[i].Message) + assert.Equal(t, e.Level, logEntries[i].Level) + assert.Equal(t, e.Message, logEntries[i].Message) + + for k, v := range e.Data { + assert.Contains(t, logEntries[i].Data, k) + assert.Equal(t, logEntries[i].Data[k], v) + } } } @@ -121,386 +115,90 @@ func TestProtocolState_String(t *testing.T) { }) } -type noopServerHandler struct{} +type NoopServerHandler struct{} -func (h *noopServerHandler) CommandAnnounce(context.Context, <-chan []byte) (*Command, error) { +func (h *NoopServerHandler) CommandAnnounce(context.Context, <-chan []byte) (*Command, error) { return nil, nil } -func (h *noopServerHandler) CommandData(context.Context, <-chan []byte, *Command) error { +func (h *NoopServerHandler) CommandData(context.Context, <-chan []byte, *Command) error { return nil } -func (h *noopServerHandler) HandleCommand(context.Context, *Command) (*Response, error) { +func (h *NoopServerHandler) HandleCommand(context.Context, *Command) (*Response, error) { return nil, nil } -func (h *noopServerHandler) Respond(context.Context, *Response, chan<- []byte) error { +func (h *NoopServerHandler) Respond(context.Context, *Response, chan<- []byte) error { return nil } -type testServerHandler struct { - logger *logrus.Logger -} - -func (h *testServerHandler) CommandAnnounce(ctx context.Context, in <-chan []byte) (*Command, error) { - select { - case <-ctx.Done(): - return nil, nil - case frame := <-in: - h.logger.Infof("announce frame %s", string(frame)) - - return &Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef)}, nil - } -} - -func (h *testServerHandler) CommandData(ctx context.Context, in <-chan []byte, command *Command) error { - select { - case <-ctx.Done(): - return nil - case frame := <-in: - h.logger.Infof("command frame %s", string(frame)) - - command.Command = frame - - return nil - } -} - -func (h *testServerHandler) HandleCommand(_ context.Context, command *Command) (*Response, error) { - h.logger.Info("handle command") - - return &Response{ - Announce: messages.BuildResponseAnnounce(messages.RespUndef, command.Announce.ID), - Response: fmt.Sprintf("response for command %s", command.Command), - }, nil -} - -func (h *testServerHandler) Respond(ctx context.Context, response *Response, out chan<- []byte) error { - h.logger.Info("send response") - - buf := bytes.NewBuffer([]byte("test-response-")) - buf.WriteString(response.Announce.ID) - - select { - case <-ctx.Done(): - return nil - case out <- buf.Bytes(): - return nil - } -} - -type commandAnnounceErrServerHandler struct{ testServerHandler } +func TestServerProtocol_HandleUnknownProtocolState(t *testing.T) { + protoLog, pHook := test.NewNullLogger() + protoLog.SetLevel(logrus.TraceLevel) -func (h *commandAnnounceErrServerHandler) CommandAnnounce(ctx context.Context, in <-chan []byte) (*Command, error) { - select { - case <-ctx.Done(): - return nil, nil - case announce := <-in: - return nil, fmt.Errorf("failed to handle announce %s", announce) + if testing.Short() { + t.Skip("skipping test in short mode") } -} -type commandAnnounceNilServerHandler struct{ testServerHandler } + in, out := make(chan []byte), make(chan []byte) -func (h *commandAnnounceNilServerHandler) CommandAnnounce(ctx context.Context, in <-chan []byte) (*Command, error) { - select { - case <-ctx.Done(): - return nil, nil - case <-in: - return nil, nil - } -} + var err error -type commandDataErrServerHandler struct{ testServerHandler } + p := NewServer(&NoopServerHandler{}, in, out, protoLog) + p.state = protocolState(100) -func (h *commandDataErrServerHandler) CommandData(ctx context.Context, in <-chan []byte, _ *Command) error { - select { - case <-ctx.Done(): - return nil - case data := <-in: - return fmt.Errorf("failed to handle command data %s", data) - } -} + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) -type commandDataNilAnnouncementServerHandler struct{ testServerHandler } + go func() { + err = p.Handle(ctx) -func (h *commandDataNilAnnouncementServerHandler) CommandAnnounce( - ctx context.Context, in <-chan []byte, -) (*Command, error) { - select { - case <-ctx.Done(): - return nil, nil - case <-in: - return &Command{}, nil - } -} + wg.Done() + }() -type handleCommandErrServerHandler struct{ testServerHandler } + time.Sleep(10 * time.Millisecond) -func (h *handleCommandErrServerHandler) HandleCommand(_ context.Context, command *Command) (*Response, error) { - return nil, fmt.Errorf("failed to handle command %s", command) -} + cancel() + wg.Wait() -type handleCommandNilCommandServerHandler struct{ testServerHandler } + assert.ErrorContains(t, err, "unknown protocol state") -func (h *handleCommandNilCommandServerHandler) CommandData(ctx context.Context, in <-chan []byte, _ *Command) error { - select { - case <-ctx.Done(): - return nil - case <-in: - return nil - } + AssertLogs(t, pHook, []logrus.Entry{ + { + Level: logrus.DebugLevel, + Message: "handling protocol state", + Data: map[string]interface{}{"state": protocolState(100)}, + }, + }) } -type respondErrServerHandler struct{ testServerHandler } +type NoopClientHandler struct{} -func (h *respondErrServerHandler) Respond(_ context.Context, r *Response, _ chan<- []byte) error { - return fmt.Errorf("failed to respond %s", r) +func (h *NoopClientHandler) Send(context.Context, *Command, chan<- []byte) error { + return nil } -type respondNilResponseServerHandler struct{ testServerHandler } - -func (h *respondNilResponseServerHandler) HandleCommand(context.Context, *Command) (*Response, error) { +func (h *NoopClientHandler) ResponseAnnounce(context.Context, <-chan []byte) (*Response, error) { return nil, nil } -func TestNewServer(t *testing.T) { - logger, _ := test.NewNullLogger() - - in := make(chan []byte) - out := make(chan []byte) - - server := NewServer(&noopServerHandler{}, in, out, logger) +func (h *NoopClientHandler) ResponseData(context.Context, <-chan []byte, *Response) error { + return nil +} - assert.NotNil(t, server) - assert.IsType(t, (*ServerProtocol)(nil), server) +func (h *NoopClientHandler) HandleResponse(context.Context, *Response) error { + return nil } -func TestServerProtocol_Handle(t *testing.T) { +func TestClientProtocol_HandleUnknownProtocolState(t *testing.T) { protoLog, pHook := test.NewNullLogger() protoLog.SetLevel(logrus.TraceLevel) - handlerLog, hHook := test.NewNullLogger() - - t.Run("initialization", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.NoError(t, err) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - - assertLogs(t, hHook, []expectedLogs{}) - }) - - t.Run("happy-path", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) - - var response []byte - - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - go func() { - response = readServerResponse(ctx, out) - wg.Done() - }() - - go func() { - sendServerCommand(ctx, in, "dropped announcement", "test command") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - - wg.Wait() - - assert.NoError(t, err) - - assert.NotNil(t, response) - assert.NotEmpty(t, response) - assert.True(t, strings.HasPrefix(string(response), "test-response")) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respond)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "announce frame dropped announcement"}, - {logrus.InfoLevel, "command frame test command"}, - {logrus.InfoLevel, "handle command"}, - {logrus.InfoLevel, "send response"}, - }) - }) - - t.Run("command-announce-nil", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&commandAnnounceNilServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) - - var response []byte - - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - go func() { - response = readServerResponse(ctx, out) - wg.Done() - }() - - go func() { - sendServerCommand(ctx, in, "dropped announcement", "") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - - wg.Wait() - - assert.ErrorIs(t, err, errCommandAnnounceExpected) - assert.Nil(t, response) - - assertLogs(t, hHook, []expectedLogs{}) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, - }) - }) - - t.Run("command-announce-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&commandAnnounceErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) - - var response []byte - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - go func() { - response = readServerResponse(ctx, out) - wg.Done() - }() - - go func() { - sendServerCommand(ctx, in, "dropped announcement", "") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.Nil(t, response) - - assertLogs(t, hHook, []expectedLogs{}) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.ErrorLevel, "could not handle command announce"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("command-data-nil-announce", func(t *testing.T) { + t.Run("unknown-protocol-state", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() - hHook.Reset() }) if testing.Short() { @@ -508,1662 +206,49 @@ func TestServerProtocol_Handle(t *testing.T) { } in, out := make(chan []byte), make(chan []byte) + commands := make(chan *Command, 10) - var err error - - p := NewServer(&commandDataNilAnnouncementServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) + var ( + err error + sent []byte + ) - var response []byte + c := NewClient(&NoopClientHandler{}, commands, in, out, protoLog) + c.state = protocolState(100) ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} wg.Add(2) go func() { - err = p.Handle(ctx) + err = c.Handle(ctx) wg.Done() }() go func() { + defer wg.Done() select { case <-ctx.Done(): - break - case response = <-out: - break + return + case sent = <-out: + return } - wg.Done() }() - in <- []byte("dropped announcement") - time.Sleep(10 * time.Millisecond) - cancel() - - wg.Wait() - - assert.ErrorIs(t, err, errCommandAnnounceExpected) - assert.Nil(t, response) - - assertLogs(t, hHook, []expectedLogs{}) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, - }) - }) - - t.Run("command-data-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&commandDataErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) - - var response []byte - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - go func() { - response = readServerResponse(ctx, out) - wg.Done() - }() - - go func() { - sendServerCommand(ctx, in, "dropped announcement", "command fails") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) cancel() wg.Wait() - assert.NoError(t, err) - assert.Nil(t, response) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "announce frame dropped announcement"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, - {logrus.ErrorLevel, "could not handle command data"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) + assert.ErrorContains(t, err, "unknown protocol state") + assert.Nil(t, sent) - t.Run("handle-command-nil-command", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() + AssertLogs(t, pHook, []logrus.Entry{ + { + Level: logrus.DebugLevel, + Message: "handling protocol state", + Data: map[string]interface{}{"state": protocolState(100)}, + }, }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&handleCommandNilCommandServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) - - var response []byte - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - go func() { - response = readServerResponse(ctx, out) - wg.Done() - }() - - go func() { - sendServerCommand(ctx, in, "dropped announcement", "dropped command") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.ErrorIs(t, err, errCommandDataExpected) - assert.Nil(t, response) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "announce frame dropped announcement"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, - }) - }) - - t.Run("handle-command-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&handleCommandErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) - - var response []byte - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - go func() { - response = readServerResponse(ctx, out) - wg.Done() - }() - - go func() { - sendServerCommand(ctx, in, "dropped announcement", "command fails") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.Nil(t, response) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "announce frame dropped announcement"}, - {logrus.InfoLevel, "command frame command fails"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, - {logrus.ErrorLevel, "could not handle command"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("respond-nil-response", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&respondNilResponseServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) - - var response []byte - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - go func() { - response = readServerResponse(ctx, out) - wg.Done() - }() - - go func() { - sendServerCommand(ctx, in, "dropped announcement", "dropped command") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.ErrorIs(t, err, errResponseExpected) - assert.Nil(t, response) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "announce frame dropped announcement"}, - {logrus.InfoLevel, "command frame dropped command"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respond)}, - }) - }) - - t.Run("respond-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&respondErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) - - var response []byte - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - go func() { - response = readServerResponse(ctx, out) - wg.Done() - }() - - go func() { - sendServerCommand(ctx, in, "dropped announcement", "command data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.Nil(t, response) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "announce frame dropped announcement"}, - {logrus.InfoLevel, "command frame command data"}, - {logrus.InfoLevel, "handle command"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdData)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleCommand)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respond)}, - {logrus.ErrorLevel, "could not respond"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("unknown-protocol-state", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - - var err error - - p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) - p.state = protocolState(100) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - err = p.Handle(ctx) - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.ErrorContains(t, err, "unknown protocol state") - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, "handling protocol state unknown 100"}, - }) - - assertLogs(t, hHook, []expectedLogs{}) - }) -} - -func sendServerCommand(ctx context.Context, in chan []byte, announce string, command string) { - select { - case <-ctx.Done(): - return - case in <- []byte(announce): - break - } - - if command == "" { - return - } - - select { - case <-ctx.Done(): - return - case in <- []byte(command): - break - } -} - -func readServerResponse(ctx context.Context, out chan []byte) []byte { - select { - case <-ctx.Done(): - return nil - case response := <-out: - return response - } -} - -type noopClientHandler struct{} - -func (h *noopClientHandler) Send(context.Context, *Command, chan<- []byte) error { - return nil -} - -func (h *noopClientHandler) ResponseAnnounce(context.Context, <-chan []byte) (*Response, error) { - return nil, nil -} - -func (h *noopClientHandler) ResponseData(context.Context, <-chan []byte, *Response) error { - return nil -} - -func (h *noopClientHandler) HandleResponse(context.Context, *Response) error { - return nil -} - -type testClientHandler struct{ logger *logrus.Logger } - -func (h *testClientHandler) Send(ctx context.Context, command *Command, out chan<- []byte) error { - h.logger.Infof("send command %s", command.Announce.Code) - - select { - case <-ctx.Done(): - return nil - case out <- []byte(command.Announce.String()): - break - } - - outBytes, ok := command.Command.(string) - if !ok { - return fmt.Errorf("could not use command '%+v' as string", command.Command) - } - - select { - case <-ctx.Done(): - return nil - case out <- []byte(outBytes): - break - } - - return nil -} - -func (h *testClientHandler) ResponseAnnounce(ctx context.Context, in <-chan []byte) (*Response, error) { - select { - case <-ctx.Done(): - return nil, nil - case announce := <-in: - h.logger.Infof("received response announce %s", announce) - - response := &Response{Announce: messages.BuildResponseAnnounce(messages.RespUndef, string(announce))} - - return response, nil - } -} - -func (h *testClientHandler) ResponseData(ctx context.Context, in <-chan []byte, response *Response) error { - select { - case <-ctx.Done(): - return nil - case data := <-in: - h.logger.Infof("received response data %s", string(data)) - - response.Response = data - - return nil - } -} - -func (h *testClientHandler) HandleResponse(_ context.Context, response *Response) error { - h.logger.Infof("handle response %s", response.Announce.ID) - - return nil -} - -type commandAnnounceErrClientHandler struct{ testClientHandler } - -func (h *commandAnnounceErrClientHandler) Send(context.Context, *Command, chan<- []byte) error { - return errors.New("failed sending command") -} - -func TestNewClient(t *testing.T) { - logger, _ := test.NewNullLogger() - - in := make(chan []byte) - out := make(chan []byte) - - commands := make(chan *Command, 10) - - client := NewClient(&noopClientHandler{}, commands, in, out, logger) - - assert.NotNil(t, client) - assert.IsType(t, (*ClientProtocol)(nil), client) -} - -type responseAnnounceErrClientHandler struct{ testClientHandler } - -func (h *responseAnnounceErrClientHandler) ResponseAnnounce(ctx context.Context, in <-chan []byte) (*Response, error) { - select { - case <-ctx.Done(): - return nil, nil - case <-in: - return nil, errors.New("failed receiving response announce") - } -} - -type responseAnnounceNilClientHandler struct{ testClientHandler } - -func (h *responseAnnounceNilClientHandler) ResponseAnnounce(ctx context.Context, in <-chan []byte) (*Response, error) { - select { - case <-ctx.Done(): - return nil, nil - case <-in: - return nil, nil - } -} - -type responseDataErrClientHandler struct{ testClientHandler } - -func (h *responseDataErrClientHandler) ResponseData(ctx context.Context, in <-chan []byte, _ *Response) error { - select { - case <-ctx.Done(): - return nil - case <-in: - return errors.New("failed to handle response data") - } -} - -type responseDataTimeoutErrClientHandler struct{ testClientHandler } - -func (h *responseDataTimeoutErrClientHandler) ResponseData(ctx context.Context, in <-chan []byte, _ *Response) error { - select { - case <-ctx.Done(): - return nil - case <-in: - return ErrResponseDataTimeoutExpired - } -} - -type responseDataNilClientHandler struct{ testClientHandler } - -func (h *responseDataNilClientHandler) ResponseData(ctx context.Context, in <-chan []byte, response *Response) error { - select { - case <-ctx.Done(): - return nil - case <-in: - response.Response = nil - - return nil - } -} - -type handleResponseErrClientHandler struct{ testClientHandler } - -func (h *handleResponseErrClientHandler) HandleResponse(context.Context, *Response) error { - return errors.New("failed to handle response") -} - -func TestClientProtocol_Handle(t *testing.T) { //nolint:cyclop - protoLog, pHook := test.NewNullLogger() - protoLog.SetLevel(logrus.TraceLevel) - - handlerLog, hHook := test.NewNullLogger() - - t.Run("initialize", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - sent []byte - ) - - c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(2) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - defer wg.Done() - select { - case <-ctx.Done(): - return - case sent = <-out: - return - } - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.Nil(t, sent) - - assertLogs(t, hHook, []expectedLogs{}) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("happy-case", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(4) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - prepareClientCommand(ctx, commands, "test command") - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.NotNil(t, announce) - assert.NotNil(t, command) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "send command UNDEFINED"}, - {logrus.InfoLevel, "received response announce test response announce"}, - {logrus.InfoLevel, "received response data test response data"}, - {logrus.InfoLevel, "handle response test response announce"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.TraceLevel, "handled command"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("command-announce-nil", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - close(commands) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(3) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.Error(t, err) - assert.Nil(t, announce) - assert.Nil(t, command) - - assertLogs(t, hHook, []expectedLogs{}) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("command-announce-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&commandAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(4) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - prepareClientCommand(ctx, commands, "test command") - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.Nil(t, announce) - assert.Nil(t, command) - - assertLogs(t, hHook, []expectedLogs{}) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.ErrorLevel, "could not send command announce"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("response-announce-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&responseAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(4) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - prepareClientCommand(ctx, commands, "test command") - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.NotNil(t, announce) - assert.NotNil(t, command) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "send command UNDEFINED"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.TraceLevel, "handled command"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, - {logrus.ErrorLevel, "could not handle response announce"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("response-data-nil", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&responseAnnounceNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(4) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - prepareClientCommand(ctx, commands, "test command") - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.ErrorIs(t, err, errResponseExpected) - assert.NotNil(t, announce) - assert.NotNil(t, command) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "send command UNDEFINED"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.TraceLevel, "handled command"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, - }) - }) - - t.Run("response-data-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&responseDataErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(4) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - prepareClientCommand(ctx, commands, "test command") - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.NotNil(t, announce) - assert.NotNil(t, command) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "send command UNDEFINED"}, - {logrus.InfoLevel, "received response announce test response announce"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.TraceLevel, "handled command"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, - {logrus.ErrorLevel, "could not handle response data"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, - }) - }) - - t.Run("response-data-timeout-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&responseDataTimeoutErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(4) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - prepareClientCommand(ctx, commands, "test command") - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.NotNil(t, announce) - assert.NotNil(t, command) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "send command UNDEFINED"}, - {logrus.InfoLevel, "received response announce test response announce"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.TraceLevel, "handled command"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, - {logrus.ErrorLevel, "could not handle response data"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("handle-response-nil", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&responseDataNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(4) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - prepareClientCommand(ctx, commands, "test command") - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.ErrorIs(t, err, errResponseExpected) - assert.NotNil(t, announce) - assert.NotNil(t, command) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "send command UNDEFINED"}, - {logrus.InfoLevel, "received response announce test response announce"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.TraceLevel, "handled command"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)}, - }) - }) - - t.Run("handle-response-error", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - announce, command []byte - ) - - c := NewClient(&handleResponseErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(4) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - announce, command = sendClientCommand(ctx, out) - wg.Done() - }() - - go func() { - prepareClientCommand(ctx, commands, "test command") - wg.Done() - }() - - go func() { - receiveClientAnnounce(ctx, in, "test response announce", "test response data") - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - cancel() - wg.Wait() - - assert.NoError(t, err) - assert.NotNil(t, announce) - assert.NotNil(t, command) - - assertLogs(t, hHook, []expectedLogs{ - {logrus.InfoLevel, "send command UNDEFINED"}, - {logrus.InfoLevel, "received response announce test response announce"}, - {logrus.InfoLevel, "received response data test response data"}, - }) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - {logrus.TraceLevel, "handled command"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)}, - {logrus.ErrorLevel, "could not handle response"}, - {logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)}, - }) - }) - - t.Run("unknown-protocol-state", func(t *testing.T) { - t.Cleanup(func() { - pHook.Reset() - hHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - in, out := make(chan []byte), make(chan []byte) - commands := make(chan *Command, 10) - - var ( - err error - sent []byte - ) - - c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) - c.state = protocolState(100) - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(2) - - go func() { - err = c.Handle(ctx) - wg.Done() - }() - - go func() { - defer wg.Done() - select { - case <-ctx.Done(): - return - case sent = <-out: - return - } - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.ErrorContains(t, err, "unknown protocol state") - assert.Nil(t, sent) - - assertLogs(t, hHook, []expectedLogs{}) - - assertLogs(t, pHook, []expectedLogs{ - {logrus.DebugLevel, "handling protocol state unknown 100"}, - }) - }) -} - -func prepareClientCommand(ctx context.Context, commands chan *Command, command string) { - select { - case <-ctx.Done(): - return - case commands <- &Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef), Command: command}: - return - } -} - -func receiveClientAnnounce(ctx context.Context, in chan []byte, announce, data string) { - select { - case <-ctx.Done(): - return - case in <- []byte(announce): - break - } - - select { - case <-ctx.Done(): - return - case in <- []byte(data): - break - } -} - -func sendClientCommand(ctx context.Context, out chan []byte) ([]byte, []byte) { - var announce, command []byte - - select { - case <-ctx.Done(): - return nil, nil - case announce = <-out: - break - } - - select { - case <-ctx.Done(): - return announce, nil - case command = <-out: - return announce, command - } -} - -func TestNewCOBSFramer(t *testing.T) { - logger, _ := test.NewNullLogger() - - framer, err := NewCOBSFramer(logger) - - assert.NoError(t, err) - - require.NotNil(t, framer) - assert.IsType(t, (*COBSFramer)(nil), framer) - assert.Implements(t, (*Framer)(nil), framer) -} - -func TestCOBSFramer_ReadFrames(t *testing.T) { - logger, loggerHook := test.NewNullLogger() - logger.SetLevel(logrus.TraceLevel) - - framer, err := NewCOBSFramer(logger) - - require.NoError(t, err) - require.NotNil(t, framer) - - t.Run("read error", func(t *testing.T) { - t.Cleanup(func() { - loggerHook.Reset() - }) - - readFrames := make(chan []byte) - - testError := errors.New("test error") - - reader := iotest.ErrReader(testError) - - err := framer.ReadFrames(context.Background(), reader, readFrames) - - assert.ErrorIs(t, err, testError) - - frame := <-readFrames - - assert.Nil(t, frame) - - assert.Empty(t, loggerHook.AllEntries()) - }) - - t.Run("no bytes", func(t *testing.T) { - t.Cleanup(func() { - loggerHook.Reset() - }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - readFrames := make(chan []byte) - - reader := &bytes.Buffer{} - - var err error - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - err = framer.ReadFrames(ctx, reader, readFrames) - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.NoError(t, err) - - assert.Empty(t, loggerHook.AllEntries()) - }) - - t.Run("incomplete bytes", func(t *testing.T) { - t.Cleanup(func() { loggerHook.Reset() }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - readFrames := make(chan []byte) - - reader := strings.NewReader("some bytes") - - var err error - - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - err = framer.ReadFrames(ctx, reader, readFrames) - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.Nil(t, err) - - logEntries := loggerHook.AllEntries() - - require.Len(t, logEntries, 2) - assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) - assert.Equal(t, "read 10 raw bytes", logEntries[0].Message) - assert.Equal(t, logrus.TraceLevel, logEntries[1].Level) - assert.Equal(t, "read buffer is now 10 bytes long", logEntries[1].Message) - }) - - t.Run("invalid bytes", func(t *testing.T) { - t.Cleanup(func() { loggerHook.Reset() }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - readFrames := make(chan []byte) - - reader := bytes.NewBuffer([]byte("some bytes")) - - reader.WriteByte(CobsDelimiter) - - var err error - - ctx, cancel := context.WithCancel(context.Background()) - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - err = framer.ReadFrames(ctx, reader, readFrames) - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - wg.Wait() - - assert.NoError(t, err) - - logEntries := loggerHook.AllEntries() - - require.Len(t, logEntries, 3) - assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) - assert.Equal(t, "read 11 raw bytes", logEntries[0].Message) - assert.Equal(t, logrus.WarnLevel, logEntries[1].Level) - assert.Equal(t, "skipping invalid frame of 11 bytes", logEntries[1].Message) - assert.Equal(t, logrus.TraceLevel, logEntries[2].Level) - assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message) - }) - - t.Run("valid frame", func(t *testing.T) { - t.Cleanup(func() { loggerHook.Reset() }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - readFrames := make(chan []byte) - - reader := bytes.NewBuffer(framer.encoder.Encode([]byte("some bytes"))) - - var ( - err error - frame []byte - ) - - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - wg.Add(2) - - go func() { - err = framer.ReadFrames(ctx, reader, readFrames) - wg.Done() - }() - - go func() { - frame = <-readFrames - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - - wg.Wait() - - assert.NoError(t, err) - - assert.NotNil(t, frame) - if frame != nil { - assert.Equal(t, []byte("some bytes"), frame) - } - - logEntries := loggerHook.AllEntries() - - require.Len(t, logEntries, 3) - assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) - assert.Contains(t, logEntries[0].Message, "raw bytes") - assert.Equal(t, logrus.TraceLevel, logEntries[1].Level) - assert.Equal(t, "frame decoded to length 10", logEntries[1].Message) - assert.Equal(t, logrus.TraceLevel, logEntries[2].Level) - assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message) - }) -} - -type brokenWriter struct{} - -var errB0rk3d = errors.New("you b0rk3d it") - -func (b brokenWriter) Write([]byte) (int, error) { - return 0, errB0rk3d -} - -func TestCOBSFramer_WriteFrames(t *testing.T) { - logger, loggerHook := test.NewNullLogger() - logger.SetLevel(logrus.TraceLevel) - - framer, err := NewCOBSFramer(logger) - - require.NoError(t, err) - require.NotNil(t, framer) - - t.Run("closed channel", func(t *testing.T) { - t.Cleanup(func() { loggerHook.Reset() }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - var result error - - in := make(chan []byte) - - close(in) - - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - - wg.Add(1) - - go func() { - result = framer.WriteFrames(ctx, io.Discard, in) - - wg.Done() - }() - - time.Sleep(10 * time.Millisecond) - - cancel() - - wg.Wait() - - assert.Nil(t, result) - - logEntries := loggerHook.AllEntries() - - assert.Len(t, logEntries, 1) - assert.Equal(t, logrus.DebugLevel, logEntries[0].Level) - assert.Equal(t, "channel closed", logEntries[0].Message) - }) - - t.Run("closed writer", func(t *testing.T) { - t.Cleanup(func() { loggerHook.Reset() }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - var result error - - in := make(chan []byte) - - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - - wg.Add(1) - - go func() { - result = framer.WriteFrames(ctx, &brokenWriter{}, in) - - wg.Done() - }() - - in <- []byte("test message") - - time.Sleep(10 * time.Millisecond) - - cancel() - - wg.Wait() - - assert.NotNil(t, result) - assert.True(t, errors.Is(result, errB0rk3d)) - - assert.Len(t, loggerHook.AllEntries(), 0) - }) - - t.Run("valid frame", func(t *testing.T) { - t.Cleanup(func() { loggerHook.Reset() }) - - if testing.Short() { - t.Skip("skipping test in short mode") - } - - var err error - - in := make(chan []byte) - - out := &bytes.Buffer{} - - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - err = framer.WriteFrames(ctx, out, in) - - wg.Done() - }() - - in <- []byte("test message") - - time.Sleep(10 * time.Millisecond) - - cancel() - - wg.Wait() - - assert.Nil(t, err) - - logEntries := loggerHook.AllEntries() - assert.Len(t, logEntries, 1) - assert.Equal(t, logrus.TraceLevel, logEntries[0].Level) - - frame := out.Bytes() - assert.NoError(t, framer.encoder.Verify(frame)) - assert.Equal(t, []byte("test message"), framer.encoder.Decode(out.Bytes())) }) }