/* Copyright 2022 CAcert Inc. SPDX-License-Identifier: Apache-2.0 Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Package protocol_test provides tests for the public API of the protocol package. package protocol_test import ( "bytes" "context" "errors" "fmt" "io" "strings" "sync" "testing" "testing/iotest" "time" "github.com/justincpresley/go-cobs" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/protocol" ) func TestNewServer(t *testing.T) { logger, _ := test.NewNullLogger() in := make(chan []byte) out := make(chan []byte) server := protocol.NewServer(&protocol.NoopServerHandler{}, in, out, logger) assert.NotNil(t, server) assert.IsType(t, (*protocol.ServerProtocol)(nil), server) } type testServerHandler struct { logger *logrus.Logger } func (h *testServerHandler) CommandAnnounce(ctx context.Context, in <-chan []byte) (*protocol.Command, error) { select { case <-ctx.Done(): return nil, nil case frame := <-in: h.logger.Infof("announce frame %s", string(frame)) return &protocol.Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef)}, nil } } func (h *testServerHandler) CommandData(ctx context.Context, in <-chan []byte, command *protocol.Command) error { select { case <-ctx.Done(): return nil case frame := <-in: h.logger.Infof("command frame %s", string(frame)) command.Command = frame return nil } } func (h 
*testServerHandler) HandleCommand(_ context.Context, command *protocol.Command) (*protocol.Response, error) { h.logger.Info("handle command") return &protocol.Response{ Announce: messages.BuildResponseAnnounce(messages.RespUndef, command.Announce.ID), Response: fmt.Sprintf("response for command %s", command.Command), }, nil } func (h *testServerHandler) Respond(ctx context.Context, response *protocol.Response, out chan<- []byte) error { h.logger.Info("send response") buf := bytes.NewBuffer([]byte("test-response-")) buf.WriteString(response.Announce.ID) select { case <-ctx.Done(): return nil case out <- buf.Bytes(): return nil } } type commandAnnounceErrServerHandler struct{ testServerHandler } func (h *commandAnnounceErrServerHandler) CommandAnnounce( ctx context.Context, in <-chan []byte, ) (*protocol.Command, error) { select { case <-ctx.Done(): return nil, nil case announce := <-in: return nil, fmt.Errorf("failed to handle announce %s", announce) } } type commandAnnounceNilServerHandler struct{ testServerHandler } func (h *commandAnnounceNilServerHandler) CommandAnnounce( ctx context.Context, in <-chan []byte, ) (*protocol.Command, error) { select { case <-ctx.Done(): return nil, nil case <-in: return nil, nil } } type commandDataErrServerHandler struct{ testServerHandler } func (h *commandDataErrServerHandler) CommandData(ctx context.Context, in <-chan []byte, _ *protocol.Command) error { select { case <-ctx.Done(): return nil case data := <-in: return fmt.Errorf("failed to handle command data %s", data) } } type commandDataNilAnnouncementServerHandler struct{ testServerHandler } func (h *commandDataNilAnnouncementServerHandler) CommandAnnounce( ctx context.Context, in <-chan []byte, ) (*protocol.Command, error) { select { case <-ctx.Done(): return nil, nil case <-in: return &protocol.Command{}, nil } } type handleCommandErrServerHandler struct{ testServerHandler } func (h *handleCommandErrServerHandler) HandleCommand( _ context.Context, command 
*protocol.Command, ) (*protocol.Response, error) { return nil, fmt.Errorf("failed to handle command %s", command) } type handleCommandNilCommandServerHandler struct{ testServerHandler } func (h *handleCommandNilCommandServerHandler) CommandData( ctx context.Context, in <-chan []byte, _ *protocol.Command, ) error { select { case <-ctx.Done(): return nil case <-in: return nil } } type respondErrServerHandler struct{ testServerHandler } func (h *respondErrServerHandler) Respond(_ context.Context, r *protocol.Response, _ chan<- []byte) error { return fmt.Errorf("failed to respond %s", r) } type respondNilResponseServerHandler struct{ testServerHandler } func (h *respondNilResponseServerHandler) HandleCommand( context.Context, *protocol.Command, ) (*protocol.Response, error) { return nil, nil } func TestServerProtocol_Handle(t *testing.T) { protoLog, pHook := test.NewNullLogger() protoLog.SetLevel(logrus.TraceLevel) handlerLog, hHook := test.NewNullLogger() t.Run("initialization", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(1) go func() { err = p.Handle(ctx) wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) protocol.AssertLogs(t, pHook, []logrus.Entry{ { Level: logrus.DebugLevel, Message: "handling protocol state", }, }) protocol.AssertLogs(t, hHook, []logrus.Entry{}) }) t.Run("happy-path", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog) var response []byte ctx, cancel := 
context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = p.Handle(ctx) wg.Done() }() go func() { response = readServerResponse(ctx, out) wg.Done() }() go func() { sendServerCommand(ctx, in, "dropped announcement", "test command") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.NotNil(t, response) assert.NotEmpty(t, response) assert.True(t, strings.HasPrefix(string(response), "test-response")) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, {Level: logrus.InfoLevel, Message: "command frame test command"}, {Level: logrus.InfoLevel, Message: "handle command"}, {Level: logrus.InfoLevel, Message: "send response"}, }) }) t.Run("command-announce-nil", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer(&commandAnnounceNilServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) var response []byte ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = p.Handle(ctx) wg.Done() }() go func() { response = readServerResponse(ctx, out) wg.Done() }() go func() { sendServerCommand(ctx, in, "dropped announcement", "") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.ErrorIs(t, err, protocol.ErrCommandAnnounceExpected) assert.Nil(t, response) protocol.AssertLogs(t, hHook, []logrus.Entry{}) protocol.AssertLogs(t, 
pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("command-announce-error", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer(&commandAnnounceErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) var response []byte ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = p.Handle(ctx) wg.Done() }() go func() { response = readServerResponse(ctx, out) wg.Done() }() go func() { sendServerCommand(ctx, in, "dropped announcement", "") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.Nil(t, response) protocol.AssertLogs(t, hHook, []logrus.Entry{}) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.ErrorLevel, Message: "could not handle command announce"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("command-data-nil-announce", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer( &commandDataNilAnnouncementServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog, ) var response []byte ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(2) go func() { err = p.Handle(ctx) wg.Done() }() go func() { select { case <-ctx.Done(): break case response = <-out: break } wg.Done() }() in <- []byte("dropped announcement") time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.ErrorIs(t, err, protocol.ErrCommandAnnounceExpected) assert.Nil(t, response) 
protocol.AssertLogs(t, hHook, []logrus.Entry{}) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("command-data-error", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer(&commandDataErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) var response []byte ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = p.Handle(ctx) wg.Done() }() go func() { response = readServerResponse(ctx, out) wg.Done() }() go func() { sendServerCommand(ctx, in, "dropped announcement", "command fails") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.Nil(t, response) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.ErrorLevel, Message: "could not handle command data"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("handle-command-nil-command", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer( &handleCommandNilCommandServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog, ) var response []byte ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = p.Handle(ctx) wg.Done() }() go func() { response = readServerResponse(ctx, out) wg.Done() }() go func() 
{ sendServerCommand(ctx, in, "dropped announcement", "dropped command") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.ErrorIs(t, err, protocol.ErrCommandDataExpected) assert.Nil(t, response) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("handle-command-error", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer(&handleCommandErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) var response []byte ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = p.Handle(ctx) wg.Done() }() go func() { response = readServerResponse(ctx, out) wg.Done() }() go func() { sendServerCommand(ctx, in, "dropped announcement", "command fails") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.Nil(t, response) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, {Level: logrus.InfoLevel, Message: "command frame command fails"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.ErrorLevel, Message: "could not handle command"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("respond-nil-response", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() 
hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer( &respondNilResponseServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog, ) var response []byte ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = p.Handle(ctx) wg.Done() }() go func() { response = readServerResponse(ctx, out) wg.Done() }() go func() { sendServerCommand(ctx, in, "dropped announcement", "dropped command") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.ErrorIs(t, err, protocol.ErrResponseExpected) assert.Nil(t, response) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "announce frame dropped announcement"}, {Level: logrus.InfoLevel, Message: "command frame dropped command"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("respond-error", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) var err error p := protocol.NewServer(&respondErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog) var response []byte ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = p.Handle(ctx) wg.Done() }() go func() { response = readServerResponse(ctx, out) wg.Done() }() go func() { sendServerCommand(ctx, in, "dropped announcement", "command data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.Nil(t, response) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: 
logrus.InfoLevel, Message: "announce frame dropped announcement"}, {Level: logrus.InfoLevel, Message: "command frame command data"}, {Level: logrus.InfoLevel, Message: "handle command"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.ErrorLevel, Message: "could not respond"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) } func sendServerCommand(ctx context.Context, in chan []byte, announce string, command string) { select { case <-ctx.Done(): return case in <- []byte(announce): break } if command == "" { return } select { case <-ctx.Done(): return case in <- []byte(command): break } } func readServerResponse(ctx context.Context, out chan []byte) []byte { select { case <-ctx.Done(): return nil case response := <-out: return response } } type testClientHandler struct{ logger *logrus.Logger } func (h *testClientHandler) Send(ctx context.Context, command *protocol.Command, out chan<- []byte) error { h.logger.Infof("send command %s", command.Announce.Code) select { case <-ctx.Done(): return nil case out <- []byte(command.Announce.String()): break } outBytes, ok := command.Command.(string) if !ok { return fmt.Errorf("could not use command '%+v' as string", command.Command) } select { case <-ctx.Done(): return nil case out <- []byte(outBytes): break } return nil } func (h *testClientHandler) ResponseAnnounce(ctx context.Context, in <-chan []byte) (*protocol.Response, error) { select { case <-ctx.Done(): return nil, nil case announce := <-in: h.logger.Infof("received response announce %s", announce) response := &protocol.Response{Announce: messages.BuildResponseAnnounce(messages.RespUndef, string(announce))} return response, nil } } func (h *testClientHandler) ResponseData(ctx 
context.Context, in <-chan []byte, response *protocol.Response) error { select { case <-ctx.Done(): return nil case data := <-in: h.logger.Infof("received response data %s", string(data)) response.Response = data return nil } } func (h *testClientHandler) HandleResponse(_ context.Context, response *protocol.Response) error { h.logger.Infof("handle response %s", response.Announce.ID) return nil } type commandAnnounceErrClientHandler struct{ testClientHandler } func (h *commandAnnounceErrClientHandler) Send(context.Context, *protocol.Command, chan<- []byte) error { return errors.New("failed sending command") } func TestNewClient(t *testing.T) { logger, _ := test.NewNullLogger() in := make(chan []byte) out := make(chan []byte) commands := make(chan *protocol.Command, 10) client := protocol.NewClient(&protocol.NoopClientHandler{}, commands, in, out, logger) assert.NotNil(t, client) assert.IsType(t, (*protocol.ClientProtocol)(nil), client) } type responseAnnounceErrClientHandler struct{ testClientHandler } func (h *responseAnnounceErrClientHandler) ResponseAnnounce( ctx context.Context, in <-chan []byte, ) (*protocol.Response, error) { select { case <-ctx.Done(): return nil, nil case <-in: return nil, errors.New("failed receiving response announce") } } type responseAnnounceNilClientHandler struct{ testClientHandler } func (h *responseAnnounceNilClientHandler) ResponseAnnounce( ctx context.Context, in <-chan []byte, ) (*protocol.Response, error) { select { case <-ctx.Done(): return nil, nil case <-in: return nil, nil } } type responseDataErrClientHandler struct{ testClientHandler } func (h *responseDataErrClientHandler) ResponseData(ctx context.Context, in <-chan []byte, _ *protocol.Response) error { select { case <-ctx.Done(): return nil case <-in: return errors.New("failed to handle response data") } } type responseDataTimeoutErrClientHandler struct{ testClientHandler } func (h *responseDataTimeoutErrClientHandler) ResponseData( ctx context.Context, in <-chan []byte, 
_ *protocol.Response, ) error { select { case <-ctx.Done(): return nil case <-in: return protocol.ErrResponseDataTimeoutExpired } } type responseDataNilClientHandler struct{ testClientHandler } func (h *responseDataNilClientHandler) ResponseData( ctx context.Context, in <-chan []byte, response *protocol.Response, ) error { select { case <-ctx.Done(): return nil case <-in: response.Response = nil return nil } } type handleResponseErrClientHandler struct{ testClientHandler } func (h *handleResponseErrClientHandler) HandleResponse(context.Context, *protocol.Response) error { return errors.New("failed to handle response") } func TestClientProtocol_Handle(t *testing.T) { protoLog, pHook := test.NewNullLogger() protoLog.SetLevel(logrus.TraceLevel) handlerLog, hHook := test.NewNullLogger() t.Run("initialize", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error sent []byte ) c := protocol.NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(2) go func() { err = c.Handle(ctx) wg.Done() }() go func() { defer wg.Done() select { case <-ctx.Done(): return case sent = <-out: return } }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.Nil(t, sent) protocol.AssertLogs(t, hHook, []logrus.Entry{}) protocol.AssertLogs(t, pHook, []logrus.Entry{ { Level: logrus.DebugLevel, Message: "handling protocol state", }, }) }) t.Run("happy-case", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error announce, command []byte ) c := protocol.NewClient(&testClientHandler{handlerLog}, 
commands, in, out, protoLog) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(4) go func() { err = c.Handle(ctx) wg.Done() }() go func() { prepareClientCommand(ctx, commands, "test command") wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.NotNil(t, announce) assert.NotNil(t, command) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, {Level: logrus.InfoLevel, Message: "received response data test response data"}, {Level: logrus.InfoLevel, Message: "handle response test response announce"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.TraceLevel, Message: "handled command"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("command-announce-nil", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) close(commands) var ( err error announce, command []byte ) c := protocol.NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(3) go func() { err = c.Handle(ctx) wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test 
response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.Error(t, err) assert.Nil(t, announce) assert.Nil(t, command) protocol.AssertLogs(t, hHook, []logrus.Entry{}) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("command-announce-error", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error announce, command []byte ) c := protocol.NewClient(&commandAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(4) go func() { err = c.Handle(ctx) wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { prepareClientCommand(ctx, commands, "test command") wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.Nil(t, announce) assert.Nil(t, command) protocol.AssertLogs(t, hHook, []logrus.Entry{}) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.ErrorLevel, Message: "could not send command announce"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("response-announce-error", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error announce, command []byte ) c := protocol.NewClient(&responseAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) ctx, 
cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(4) go func() { err = c.Handle(ctx) wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { prepareClientCommand(ctx, commands, "test command") wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.NotNil(t, announce) assert.NotNil(t, command) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.TraceLevel, Message: "handled command"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.ErrorLevel, Message: "could not handle response announce"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("response-data-nil", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error announce, command []byte ) c := protocol.NewClient(&responseAnnounceNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(4) go func() { err = c.Handle(ctx) wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { prepareClientCommand(ctx, commands, "test command") wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.ErrorIs(t, err, protocol.ErrResponseExpected) assert.NotNil(t, announce) assert.NotNil(t, command) protocol.AssertLogs(t, hHook, 
[]logrus.Entry{ {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, }) protocol.AssertLogs( t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.TraceLevel, Message: "handled command"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }, ) }) t.Run("response-data-error", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error announce, command []byte ) c := protocol.NewClient(&responseDataErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(4) go func() { err = c.Handle(ctx) wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { prepareClientCommand(ctx, commands, "test command") wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.NotNil(t, announce) assert.NotNil(t, command) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.TraceLevel, Message: "handled command"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.ErrorLevel, Message: "could not handle response data"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("response-data-timeout-error", func(t 
*testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error announce, command []byte ) c := protocol.NewClient( &responseDataTimeoutErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog, ) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(4) go func() { err = c.Handle(ctx) wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { prepareClientCommand(ctx, commands, "test command") wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) assert.NotNil(t, announce) assert.NotNil(t, command) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.TraceLevel, Message: "handled command"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.ErrorLevel, Message: "could not handle response data"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("handle-response-nil", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error announce, command []byte ) c := protocol.NewClient(&responseDataNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) ctx, cancel := 
context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(4) go func() { err = c.Handle(ctx) wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { prepareClientCommand(ctx, commands, "test command") wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.ErrorIs(t, err, protocol.ErrResponseExpected) assert.NotNil(t, announce) assert.NotNil(t, command) protocol.AssertLogs(t, hHook, []logrus.Entry{ {Level: logrus.InfoLevel, Message: "send command UNDEFINED"}, {Level: logrus.InfoLevel, Message: "received response announce test response announce"}, }) protocol.AssertLogs(t, pHook, []logrus.Entry{ {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.TraceLevel, Message: "handled command"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, {Level: logrus.DebugLevel, Message: "handling protocol state"}, }) }) t.Run("handle-response-error", func(t *testing.T) { t.Cleanup(func() { pHook.Reset() hHook.Reset() }) if testing.Short() { t.Skip("skipping test in short mode") } in, out := make(chan []byte), make(chan []byte) commands := make(chan *protocol.Command, 10) var ( err error announce, command []byte ) c := protocol.NewClient(&handleResponseErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog) ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(4) go func() { err = c.Handle(ctx) wg.Done() }() go func() { announce, command = sendClientCommand(ctx, out) wg.Done() }() go func() { prepareClientCommand(ctx, commands, "test command") wg.Done() }() go func() { receiveClientAnnounce(ctx, in, "test response announce", "test response data") wg.Done() }() time.Sleep(10 * time.Millisecond) cancel() wg.Wait() assert.NoError(t, err) 
assert.NotNil(t, announce)
		assert.NotNil(t, command)

		protocol.AssertLogs(t, hHook, []logrus.Entry{
			{Level: logrus.InfoLevel, Message: "send command UNDEFINED"},
			{Level: logrus.InfoLevel, Message: "received response announce test response announce"},
			{Level: logrus.InfoLevel, Message: "received response data test response data"},
		})
		protocol.AssertLogs(t, pHook, []logrus.Entry{
			{Level: logrus.DebugLevel, Message: "handling protocol state"},
			{Level: logrus.TraceLevel, Message: "handled command"},
			{Level: logrus.DebugLevel, Message: "handling protocol state"},
			{Level: logrus.DebugLevel, Message: "handling protocol state"},
			{Level: logrus.DebugLevel, Message: "handling protocol state"},
			{Level: logrus.ErrorLevel, Message: "could not handle response"},
			{Level: logrus.DebugLevel, Message: "handling protocol state"},
		})
	})
}

// prepareClientCommand enqueues one command with an undefined command type
// and the given payload on commands. It returns without enqueueing if ctx is
// cancelled before the send can proceed.
func prepareClientCommand(ctx context.Context, commands chan *protocol.Command, command string) {
	select {
	case <-ctx.Done():
	case commands <- &protocol.Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef), Command: command}:
	}
}

// receiveClientAnnounce writes the announce frame followed by the data frame
// to in, as the client's peer would when answering a command. It stops early
// if ctx is cancelled before a frame has been accepted.
func receiveClientAnnounce(ctx context.Context, in chan []byte, announce, data string) {
	for _, frame := range [...]string{announce, data} {
		select {
		case <-ctx.Done():
			return
		case in <- []byte(frame):
		}
	}
}

// sendClientCommand reads the two frames a client emits for one command from
// out: first the announce frame, then the command data frame. A frame that has
// not arrived by the time ctx is cancelled is returned as nil.
func sendClientCommand(ctx context.Context, out chan []byte) ([]byte, []byte) {
	var announce []byte

	select {
	case <-ctx.Done():
		return nil, nil
	case announce = <-out:
	}

	select {
	case <-ctx.Done():
		return announce, nil
	case command := <-out:
		return announce, command
	}
}

// TestNewCOBSFramer checks that the constructor returns a non-nil
// *protocol.COBSFramer that satisfies the protocol.Framer interface.
func TestNewCOBSFramer(t *testing.T) {
	logger, _ := test.NewNullLogger()

	framer, err := protocol.NewCOBSFramer(logger)
	require.NoError(t, err)
	require.NotNil(t, framer)

	assert.IsType(t, (*protocol.COBSFramer)(nil), framer)
	assert.Implements(t, (*protocol.Framer)(nil), framer)
}

// slowTestReader is an io.Reader whose Read blocks for one second and then
// fails. It is used to cancel the context while a read is still in progress.
type slowTestReader struct{}

// Read ignores its buffer, sleeps for one second and returns an error.
func (s slowTestReader) Read(_ []byte) (n int, err error) {
	time.Sleep(1 * time.Second)

	return 0, errors.New("you waited too long")
}

// TestCOBSFramer_ReadFrames exercises COBSFramer.ReadFrames against readers
// that fail, deliver nothing, deliver partial or invalid data, block, or
// deliver one valid COBS frame. Log output is captured at trace level and
// asserted where it documents the framer's behavior.
func TestCOBSFramer_ReadFrames(t *testing.T) {
	logger, loggerHook := test.NewNullLogger()
	logger.SetLevel(logrus.TraceLevel)

	framer, err := protocol.NewCOBSFramer(logger)
	require.NoError(t, err)
	require.NotNil(t, framer)

	t.Run("read error", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		readFrames := make(chan []byte)
		testError := errors.New("test error")
		reader := iotest.ErrReader(testError)

		err := framer.ReadFrames(context.Background(), reader, readFrames)
		assert.ErrorIs(t, err, testError)

		// readFrames is unbuffered and ReadFrames has already returned, so this
		// receive can only complete if ReadFrames closed the channel; the nil
		// frame is the closed-channel zero value.
		frame := <-readFrames
		assert.Nil(t, frame)

		assert.Empty(t, loggerHook.AllEntries())
	})

	t.Run("no bytes", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		readFrames := make(chan []byte)
		reader := &bytes.Buffer{}

		var err error

		ctx, cancel := context.WithCancel(context.Background())

		wg := sync.WaitGroup{}
		wg.Add(1)

		go func() {
			err = framer.ReadFrames(ctx, reader, readFrames)
			wg.Done()
		}()

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		assert.NoError(t, err)
		assert.Empty(t, loggerHook.AllEntries())
	})

	t.Run("incomplete bytes", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		readFrames := make(chan []byte)
		reader := strings.NewReader("some bytes")

		var err error

		ctx, cancel := context.WithCancel(context.Background())

		wg := sync.WaitGroup{}
		wg.Add(1)

		go func() {
			err = framer.ReadFrames(ctx, reader, readFrames)
			wg.Done()
		}()

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		assert.NoError(t, err)

		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 2)
		assert.Equal(t, logrus.TraceLevel, logEntries[0].Level)
		assert.Equal(t, "read 10 raw bytes", logEntries[0].Message)
		assert.Equal(t, logrus.TraceLevel, logEntries[1].Level)
		assert.Equal(t, "read buffer is now 10 bytes long", logEntries[1].Message)
	})

	t.Run("invalid bytes", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		readFrames := make(chan []byte)
		reader := bytes.NewBuffer([]byte("some bytes"))
		reader.WriteByte(protocol.COBSDelimiter)

		var err error

		ctx, cancel := context.WithCancel(context.Background())

		wg := sync.WaitGroup{}
		wg.Add(1)

		go func() {
			err = framer.ReadFrames(ctx, reader, readFrames)
			wg.Done()
		}()

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		assert.NoError(t, err)

		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 3)
		assert.Equal(t, logrus.TraceLevel, logEntries[0].Level)
		assert.Equal(t, "read 11 raw bytes", logEntries[0].Message)
		assert.Equal(t, logrus.WarnLevel, logEntries[1].Level)
		assert.Equal(t, "skipping invalid frame of 11 bytes", logEntries[1].Message)
		assert.Equal(t, logrus.TraceLevel, logEntries[2].Level)
		assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message)
	})

	t.Run("closed context", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		// Timing-dependent like the other sleeping subtests.
		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		framer, err := protocol.NewCOBSFramer(logger)
		// require: a nil framer would otherwise panic inside the goroutine.
		require.NoError(t, err)

		ctx, cancel := context.WithCancel(context.Background())

		wg := sync.WaitGroup{}
		wg.Add(1)

		go func() {
			readFrames := make(chan []byte)
			err = framer.ReadFrames(ctx, &slowTestReader{}, readFrames)
			wg.Done()
		}()

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		require.NoError(t, err)
	})

	t.Run("valid frame", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		encoder, encErr := cobs.NewEncoder(protocol.COBSConfig)
		require.NoError(t, encErr)

		readFrames := make(chan []byte)
		reader := bytes.NewBuffer(encoder.Encode([]byte("some bytes")))

		var (
			err   error
			frame []byte
		)

		ctx, cancel := context.WithCancel(context.Background())

		wg := sync.WaitGroup{}
		wg.Add(2)

		go func() {
			err = framer.ReadFrames(ctx, reader, readFrames)
			wg.Done()
		}()

		go func() {
			frame = <-readFrames
			wg.Done()
		}()

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		assert.NoError(t, err)
		assert.Equal(t, []byte("some bytes"), frame)

		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 3)
		assert.Equal(t, logrus.TraceLevel, logEntries[0].Level)
		assert.Contains(t, logEntries[0].Message, "raw bytes")
		assert.Equal(t, logrus.TraceLevel, logEntries[1].Level)
		assert.Equal(t, "frame decoded to length 10", logEntries[1].Message)
		assert.Equal(t, logrus.TraceLevel, logEntries[2].Level)
		assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message)
	})
}

// brokenWriter is an io.Writer whose every Write fails with errB0rk3d.
type brokenWriter struct{}

// errB0rk3d is the sentinel error returned by brokenWriter.Write.
var errB0rk3d = errors.New("you b0rk3d it")

// Write always fails with errB0rk3d without consuming any bytes.
func (b brokenWriter) Write([]byte) (int, error) {
	return 0, errB0rk3d
}

// slowTestWriter is an io.Writer whose Write blocks for one second and then
// fails. It is used to cancel the context while a write is still in progress.
type slowTestWriter struct{}

// Write ignores its input, sleeps for one second and returns an error.
func (s slowTestWriter) Write(_ []byte) (n int, err error) {
	time.Sleep(1 * time.Second)

	return 0, errors.New("you waited too long")
}

// TestCOBSFramer_WriteFrames exercises COBSFramer.WriteFrames with a closed
// input channel, a failing writer, a blocking writer, and a successful write
// whose output must verify and decode as COBS back to the input.
func TestCOBSFramer_WriteFrames(t *testing.T) {
	logger, loggerHook := test.NewNullLogger()
	logger.SetLevel(logrus.TraceLevel)

	framer, err := protocol.NewCOBSFramer(logger)
	require.NoError(t, err)
	require.NotNil(t, framer)

	t.Run("closed channel", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		var result error

		in := make(chan []byte)
		close(in)

		ctx, cancel := context.WithCancel(context.Background())

		wg := sync.WaitGroup{}
		wg.Add(1)

		go func() {
			result = framer.WriteFrames(ctx, io.Discard, in)
			wg.Done()
		}()

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		assert.NoError(t, result)

		// require, not assert: logEntries[0] is indexed below.
		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 1)
		assert.Equal(t, logrus.DebugLevel, logEntries[0].Level)
		assert.Equal(t, "channel closed", logEntries[0].Message)
	})

	t.Run("closed writer", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		var result error

		in := make(chan []byte)

		ctx, cancel := context.WithCancel(context.Background())

		wg := sync.WaitGroup{}
		wg.Add(1)

		go func() {
			result = framer.WriteFrames(ctx, &brokenWriter{}, in)
			wg.Done()
		}()

		in <- []byte("test message")

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		assert.ErrorIs(t, result, errB0rk3d)
		assert.Empty(t, loggerHook.AllEntries())
	})

	t.Run("closed context", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		framer, err := protocol.NewCOBSFramer(logger)
		// require: a nil framer would otherwise panic inside the goroutine.
		require.NoError(t, err)

		ctx, cancel := context.WithCancel(context.Background())
		frames := make(chan []byte)

		wg := sync.WaitGroup{}
		wg.Add(2)

		go func() {
			err = framer.WriteFrames(ctx, &slowTestWriter{}, frames)
			wg.Done()
		}()

		go func() {
			// Offer one frame, but give up once the context is cancelled.
			select {
			case <-ctx.Done():
			case frames <- []byte("my test frame that shall never go out"):
			}
			wg.Done()
		}()

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		assert.NoError(t, err)
	})

	t.Run("valid frame", func(t *testing.T) {
		t.Cleanup(loggerHook.Reset)

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		var err error

		in := make(chan []byte)
		out := &bytes.Buffer{}

		ctx, cancel := context.WithCancel(context.Background())

		wg := sync.WaitGroup{}
		wg.Add(1)

		go func() {
			err = framer.WriteFrames(ctx, out, in)
			wg.Done()
		}()

		in <- []byte("test message")

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		assert.NoError(t, err)

		// require, not assert: logEntries[0] is indexed below.
		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 1)
		assert.Equal(t, logrus.TraceLevel, logEntries[0].Level)

		encoder, encErr := cobs.NewEncoder(protocol.COBSConfig)
		require.NoError(t, encErr)

		frame := out.Bytes()
		assert.NoError(t, encoder.Verify(frame))
		assert.Equal(t, []byte("test message"), encoder.Decode(frame))
	})
}