You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2152 lines
47 KiB
Go

/*
Copyright 2022 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package protocol
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"testing"
"testing/iotest"
"time"
"github.com/google/uuid"
"github.com/justincpresley/go-cobs"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.cacert.org/cacert-gosigner/pkg/messages"
)
// expectedLogs describes one log entry that a test expects a logrus test hook
// to have captured.
type expectedLogs struct {
	level   logrus.Level // severity the entry must have been logged with
	message string       // exact message text of the entry
}

// assertLogs checks that hook captured the entries listed in expected, in the
// same order, comparing the level and the message of every entry.
//
// Failures are reported through assert, so the calling test keeps running.
// When fewer entries were captured than expected, the length mismatch is
// reported and only the captured entries are compared; indexing past the end
// of the captured entries would otherwise panic and abort the whole test
// binary.
func assertLogs(t *testing.T, hook *test.Hook, expected []expectedLogs) {
	t.Helper()

	logEntries := hook.AllEntries()

	assert.Len(t, logEntries, len(expected))

	for i, e := range expected {
		if i >= len(logEntries) {
			// The shortfall was already reported by assert.Len above.
			break
		}

		assert.Equal(t, e.level, logEntries[i].Level)
		assert.Equal(t, e.message, logEntries[i].Message)
	}
}
// TestCommand_String verifies that the string form of a Command contains its
// announcement, its payload, the "announce" and "data" labels and the
// "Cmd[" ... "}]" framing.
func TestCommand_String(t *testing.T) {
	cmd := &Command{
		Announce: messages.BuildCommandAnnounce(messages.CmdUndef),
		Command:  "my undefined command",
	}

	rendered := cmd.String()

	assert.NotEmpty(t, rendered)

	// Every fragment below has to show up somewhere in the rendered text.
	for _, fragment := range []interface{}{
		cmd.Announce.String(),
		cmd.Command,
		"announce",
		"data",
		"Cmd[",
	} {
		assert.Contains(t, rendered, fragment)
	}

	assert.True(t, strings.HasSuffix(rendered, "}]"))
}
// TestResponse_String verifies that the string form of a Response contains its
// announcement, its payload, the "announce" and "data" labels and the
// "Rsp[" ... "}]" framing.
func TestResponse_String(t *testing.T) {
	resp := &Response{
		Announce: messages.BuildResponseAnnounce(messages.RespUndef, uuid.NewString()),
		Response: "my undefined response",
	}

	rendered := resp.String()

	assert.NotEmpty(t, rendered)

	// Every fragment below has to show up somewhere in the rendered text.
	for _, fragment := range []interface{}{
		resp.Announce.String(),
		resp.Response,
		"announce",
		"data",
		"Rsp[",
	} {
		assert.Contains(t, rendered, fragment)
	}

	assert.True(t, strings.HasSuffix(rendered, "}]"))
}
// TestProtocolState_String checks that every defined protocol state renders to
// a non-empty name that does not contain "unknown", and that an undefined
// state renders as "unknown" together with its numeric value.
func TestProtocolState_String(t *testing.T) {
	known := []struct {
		label string
		value protocolState
	}{
		{label: "command announce", value: cmdAnnounce},
		{label: "command data", value: cmdData},
		{label: "handle command", value: handleCommand},
		{label: "respond", value: respond},
		{label: "response announce", value: respAnnounce},
		{label: "response data", value: respData},
		{label: "handle response", value: handleResponse},
	}

	for idx := range known {
		tc := known[idx]

		t.Run(tc.label, func(t *testing.T) {
			text := tc.value.String()

			assert.NotEmpty(t, text)
			assert.NotContains(t, text, "unknown")
		})
	}

	t.Run("unsupported state", func(t *testing.T) {
		text := protocolState(-1).String()

		assert.NotEmpty(t, text)

		for _, fragment := range []string{"unknown", "-1"} {
			assert.Contains(t, text, fragment)
		}
	})
}
// noopServerHandler is a server handler stub: every method ignores its
// arguments and returns zero values.
type noopServerHandler struct{}

// CommandAnnounce reports neither a command nor an error.
func (*noopServerHandler) CommandAnnounce(_ context.Context, _ chan []byte) (*Command, error) {
	return nil, nil
}

// CommandData does nothing and reports success.
func (*noopServerHandler) CommandData(_ context.Context, _ chan []byte, _ *Command) error {
	return nil
}

// HandleCommand reports neither a response nor an error.
func (*noopServerHandler) HandleCommand(_ context.Context, _ *Command) (*Response, error) {
	return nil, nil
}

// Respond does nothing and reports success.
func (*noopServerHandler) Respond(_ context.Context, _ *Response, _ chan []byte) error {
	return nil
}
// testServerHandler is a server handler stub that writes one log line per
// step to logger, so tests can assert which handler methods were reached.
type testServerHandler struct {
	logger *logrus.Logger
}

// CommandAnnounce waits for one frame on in, logs it and returns a Command
// carrying an announcement built for messages.CmdUndef. It returns nil, nil
// when ctx is done before a frame arrives.
func (h *testServerHandler) CommandAnnounce(ctx context.Context, in chan []byte) (*Command, error) {
	var frame []byte

	select {
	case frame = <-in:
	case <-ctx.Done():
		return nil, nil
	}

	h.logger.Infof("announce frame %s", string(frame))

	return &Command{Announce: messages.BuildCommandAnnounce(messages.CmdUndef)}, nil
}

// CommandData waits for one frame on in, logs it and stores the raw frame as
// the payload of command. It returns nil without touching command when ctx is
// done before a frame arrives.
func (h *testServerHandler) CommandData(ctx context.Context, in chan []byte, command *Command) error {
	var frame []byte

	select {
	case frame = <-in:
	case <-ctx.Done():
		return nil
	}

	h.logger.Infof("command frame %s", string(frame))
	command.Command = frame

	return nil
}

// HandleCommand logs the call and builds a Response whose announcement reuses
// the ID of the command's announcement and whose payload quotes the command
// payload.
func (h *testServerHandler) HandleCommand(_ context.Context, command *Command) (*Response, error) {
	h.logger.Info("handle command")

	announce := messages.BuildResponseAnnounce(messages.RespUndef, command.Announce.ID)
	payload := fmt.Sprintf("response for command %s", command.Command)

	return &Response{Announce: announce, Response: payload}, nil
}

// Respond logs the call and writes a single frame, "test-response-" followed
// by the response announcement ID, to out. It gives up silently when ctx is
// done before the frame is accepted.
func (h *testServerHandler) Respond(ctx context.Context, response *Response, out chan []byte) error {
	h.logger.Info("send response")

	frame := append([]byte("test-response-"), response.Announce.ID...)

	select {
	case out <- frame:
	case <-ctx.Done():
	}

	return nil
}
// commandAnnounceErrServerHandler is a testServerHandler whose CommandAnnounce
// fails for every frame it receives.
type commandAnnounceErrServerHandler struct{ testServerHandler }

// CommandAnnounce consumes one frame from in and returns an error quoting it.
// It returns nil, nil when ctx is done first.
func (h *commandAnnounceErrServerHandler) CommandAnnounce(ctx context.Context, in chan []byte) (*Command, error) {
	var err error

	select {
	case frame := <-in:
		err = fmt.Errorf("failed to handle announce %s", frame)
	case <-ctx.Done():
	}

	return nil, err
}
// commandAnnounceNilServerHandler is a testServerHandler whose CommandAnnounce
// never produces a command.
type commandAnnounceNilServerHandler struct{ testServerHandler }

// CommandAnnounce waits until either a frame arrives on in (which is
// discarded) or ctx is done, and returns nil, nil in both cases.
func (h *commandAnnounceNilServerHandler) CommandAnnounce(ctx context.Context, in chan []byte) (*Command, error) {
	select {
	case <-in:
	case <-ctx.Done():
	}

	return nil, nil
}
// commandDataErrServerHandler is a testServerHandler whose CommandData fails
// for every frame it receives.
type commandDataErrServerHandler struct{ testServerHandler }

// CommandData consumes one frame from in and returns an error quoting it. It
// returns nil when ctx is done first.
func (h *commandDataErrServerHandler) CommandData(ctx context.Context, in chan []byte, _ *Command) error {
	var err error

	select {
	case frame := <-in:
		err = fmt.Errorf("failed to handle command data %s", frame)
	case <-ctx.Done():
	}

	return err
}
// commandDataNilAnnouncementServerHandler is a testServerHandler whose
// CommandAnnounce yields a Command without an announcement.
type commandDataNilAnnouncementServerHandler struct{ testServerHandler }

// CommandAnnounce discards one frame from in and returns a zero-valued
// Command. It returns nil, nil when ctx is done first.
func (h *commandDataNilAnnouncementServerHandler) CommandAnnounce(
	ctx context.Context,
	in chan []byte,
) (*Command, error) {
	var cmd *Command

	select {
	case <-in:
		cmd = new(Command)
	case <-ctx.Done():
	}

	return cmd, nil
}
// handleCommandErrServerHandler is a testServerHandler whose HandleCommand
// always fails.
type handleCommandErrServerHandler struct{ testServerHandler }

// HandleCommand returns no response and an error that quotes command.
func (h *handleCommandErrServerHandler) HandleCommand(_ context.Context, command *Command) (*Response, error) {
	err := fmt.Errorf("failed to handle command %s", command)

	return nil, err
}
// handleCommandNilCommandServerHandler is a testServerHandler whose
// CommandData drops the incoming frame instead of storing it on the command.
type handleCommandNilCommandServerHandler struct{ testServerHandler }

// CommandData waits until either a frame arrives on in (which is discarded)
// or ctx is done, and returns nil in both cases. The command is left untouched.
func (h *handleCommandNilCommandServerHandler) CommandData(ctx context.Context, in chan []byte, _ *Command) error {
	select {
	case <-in:
	case <-ctx.Done():
	}

	return nil
}
// respondErrServerHandler is a testServerHandler whose Respond always fails.
type respondErrServerHandler struct{ testServerHandler }

// Respond writes nothing and returns an error that quotes the response.
func (h *respondErrServerHandler) Respond(_ context.Context, r *Response, _ chan []byte) error {
	err := fmt.Errorf("failed to respond %s", r)

	return err
}
// respondNilResponseServerHandler is a testServerHandler whose HandleCommand
// produces no response.
type respondNilResponseServerHandler struct{ testServerHandler }

// HandleCommand ignores its arguments and returns nil, nil.
func (*respondNilResponseServerHandler) HandleCommand(_ context.Context, _ *Command) (*Response, error) {
	return nil, nil
}
// TestNewServer checks that NewServer returns a non-nil *ServerProtocol.
func TestNewServer(t *testing.T) {
	var (
		in  = make(chan []byte)
		out = make(chan []byte)
	)

	nullLogger, _ := test.NewNullLogger()

	got := NewServer(new(noopServerHandler), in, out, nullLogger)

	assert.NotNil(t, got)
	assert.IsType(t, (*ServerProtocol)(nil), got)
}
// TestServerProtocol_Handle runs ServerProtocol.Handle against handler stubs
// that succeed, fail, or return nil at each step, and asserts the returned
// error, the frame written to the output channel and the messages logged by
// the protocol and by the handler.
//
// Every subtest is skipped in short mode because each run waits 10ms before
// cancelling the context. Both log hooks are shared between subtests and are
// reset by a cleanup function after each one.
func TestServerProtocol_Handle(t *testing.T) {
	protoLog, pHook := test.NewNullLogger()
	protoLog.SetLevel(logrus.TraceLevel)

	handlerLog, hHook := test.NewNullLogger()

	// enter builds the debug entry the protocol is expected to log when it
	// starts handling state s.
	enter := func(s protocolState) expectedLogs {
		return expectedLogs{
			level:   logrus.DebugLevel,
			message: fmt.Sprintf("handling protocol state %s", s),
		}
	}

	// prepare registers the hook reset, applies the short-mode skip and
	// returns fresh unbuffered input and output channels for one subtest.
	prepare := func(t *testing.T) (chan []byte, chan []byte) {
		t.Helper()

		t.Cleanup(func() {
			pHook.Reset()
			hHook.Reset()
		})

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		return make(chan []byte), make(chan []byte)
	}

	// run executes handle in a goroutine, cancels its context after 10ms and
	// returns the first frame read from out (nil if none) and the error
	// returned by handle.
	//
	// frames are the frames the peer sends on in: the first is the command
	// announcement, the optional second is the command data. Without frames
	// no peer goroutines are started at all. All peer traffic is guarded by
	// the context, so a protocol that stops reading cannot block the test.
	run := func(handle func(context.Context) error, in, out chan []byte, frames ...string) ([]byte, error) {
		var (
			err      error
			response []byte
			wg       sync.WaitGroup
		)

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		wg.Add(1)

		go func() {
			defer wg.Done()

			err = handle(ctx)
		}()

		if len(frames) > 0 {
			announce, command := frames[0], ""
			if len(frames) > 1 {
				command = frames[1]
			}

			wg.Add(2)

			go func() {
				defer wg.Done()

				response = readServerResponse(ctx, out)
			}()

			go func() {
				defer wg.Done()

				sendServerCommand(ctx, in, announce, command)
			}()
		}

		time.Sleep(10 * time.Millisecond)
		cancel()
		wg.Wait()

		return response, err
	}

	t.Run("initialization", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog)

		_, err := run(p.Handle, in, out)

		assert.NoError(t, err)
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
		})
		assertLogs(t, hHook, []expectedLogs{})
	})

	t.Run("happy-path", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement", "test command")

		assert.NoError(t, err)
		assert.NotNil(t, response)
		assert.NotEmpty(t, response)
		assert.True(t, strings.HasPrefix(string(response), "test-response"))
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			enter(cmdData),
			enter(handleCommand),
			enter(respond),
			enter(cmdAnnounce),
		})
		assertLogs(t, hHook, []expectedLogs{
			{logrus.InfoLevel, "announce frame dropped announcement"},
			{logrus.InfoLevel, "command frame test command"},
			{logrus.InfoLevel, "handle command"},
			{logrus.InfoLevel, "send response"},
		})
	})

	t.Run("command-announce-nil", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&commandAnnounceNilServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement")

		assert.ErrorIs(t, err, errCommandAnnounceExpected)
		assert.Nil(t, response)
		assertLogs(t, hHook, []expectedLogs{})
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			enter(cmdData),
		})
	})

	t.Run("command-announce-error", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&commandAnnounceErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement")

		assert.NoError(t, err)
		assert.Nil(t, response)
		assertLogs(t, hHook, []expectedLogs{})
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			{logrus.ErrorLevel, "could not handle command announce"},
			enter(cmdAnnounce),
		})
	})

	t.Run("command-data-nil-announce", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&commandDataNilAnnouncementServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement")

		assert.ErrorIs(t, err, errCommandAnnounceExpected)
		assert.Nil(t, response)
		assertLogs(t, hHook, []expectedLogs{})
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			enter(cmdData),
		})
	})

	t.Run("command-data-error", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&commandDataErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement", "command fails")

		assert.NoError(t, err)
		assert.Nil(t, response)
		assertLogs(t, hHook, []expectedLogs{
			{logrus.InfoLevel, "announce frame dropped announcement"},
		})
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			enter(cmdData),
			{logrus.ErrorLevel, "could not handle command data"},
			enter(cmdAnnounce),
		})
	})

	t.Run("handle-command-nil-command", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&handleCommandNilCommandServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement", "dropped command")

		assert.ErrorIs(t, err, errCommandDataExpected)
		assert.Nil(t, response)
		assertLogs(t, hHook, []expectedLogs{
			{logrus.InfoLevel, "announce frame dropped announcement"},
		})
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			enter(cmdData),
			enter(handleCommand),
		})
	})

	t.Run("handle-command-error", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&handleCommandErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement", "command fails")

		assert.NoError(t, err)
		assert.Nil(t, response)
		assertLogs(t, hHook, []expectedLogs{
			{logrus.InfoLevel, "announce frame dropped announcement"},
			{logrus.InfoLevel, "command frame command fails"},
		})
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			enter(cmdData),
			enter(handleCommand),
			{logrus.ErrorLevel, "could not handle command"},
			enter(cmdAnnounce),
		})
	})

	t.Run("respond-nil-response", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&respondNilResponseServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement", "dropped command")

		assert.ErrorIs(t, err, errResponseExpected)
		assert.Nil(t, response)
		assertLogs(t, hHook, []expectedLogs{
			{logrus.InfoLevel, "announce frame dropped announcement"},
			{logrus.InfoLevel, "command frame dropped command"},
		})
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			enter(cmdData),
			enter(handleCommand),
			enter(respond),
		})
	})

	t.Run("respond-error", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&respondErrServerHandler{testServerHandler{logger: handlerLog}}, in, out, protoLog)

		response, err := run(p.Handle, in, out, "dropped announcement", "command data")

		assert.NoError(t, err)
		assert.Nil(t, response)
		assertLogs(t, hHook, []expectedLogs{
			{logrus.InfoLevel, "announce frame dropped announcement"},
			{logrus.InfoLevel, "command frame command data"},
			{logrus.InfoLevel, "handle command"},
		})
		assertLogs(t, pHook, []expectedLogs{
			enter(cmdAnnounce),
			enter(cmdData),
			enter(handleCommand),
			enter(respond),
			{logrus.ErrorLevel, "could not respond"},
			enter(cmdAnnounce),
		})
	})

	t.Run("unknown-protocol-state", func(t *testing.T) {
		in, out := prepare(t)
		p := NewServer(&testServerHandler{logger: handlerLog}, in, out, protoLog)
		// Force a state value that none of the defined states use.
		p.state = protocolState(100)

		_, err := run(p.Handle, in, out)

		assert.ErrorContains(t, err, "unknown protocol state")
		assertLogs(t, pHook, []expectedLogs{
			{logrus.DebugLevel, "handling protocol state unknown 100"},
		})
		assertLogs(t, hHook, []expectedLogs{})
	})
}
// sendServerCommand plays the peer of a server protocol: it writes announce
// to in and, unless command is empty, writes command afterwards. It returns
// early, without sending the remaining frames, as soon as ctx is done.
func sendServerCommand(ctx context.Context, in chan []byte, announce string, command string) {
	frames := make([][]byte, 0, 2)
	frames = append(frames, []byte(announce))

	if command != "" {
		frames = append(frames, []byte(command))
	}

	for _, frame := range frames {
		select {
		case in <- frame:
		case <-ctx.Done():
			return
		}
	}
}
// readServerResponse returns the first frame received on out, or nil when ctx
// is done before any frame arrives.
func readServerResponse(ctx context.Context, out chan []byte) []byte {
	var frame []byte

	select {
	case frame = <-out:
	case <-ctx.Done():
	}

	return frame
}
// noopClientHandler is a client handler stub: every method ignores its
// arguments and returns zero values.
type noopClientHandler struct{}

// Send does nothing and reports success.
func (*noopClientHandler) Send(_ context.Context, _ *Command, _ chan []byte) error {
	return nil
}

// ResponseAnnounce reports neither a response nor an error.
func (*noopClientHandler) ResponseAnnounce(_ context.Context, _ chan []byte) (*Response, error) {
	return nil, nil
}

// ResponseData does nothing and reports success.
func (*noopClientHandler) ResponseData(_ context.Context, _ chan []byte, _ *Response) error {
	return nil
}

// HandleResponse does nothing and reports success.
func (*noopClientHandler) HandleResponse(_ context.Context, _ *Response) error {
	return nil
}
// testClientHandler is a client handler stub that writes one log line per
// step to logger, so tests can assert which handler methods were reached.
type testClientHandler struct{ logger *logrus.Logger }

// Send logs the announcement code of command, then writes two frames to out:
// the string form of the announcement followed by the command payload, which
// must be a string. It stops silently, returning nil, as soon as ctx is done.
func (h *testClientHandler) Send(ctx context.Context, command *Command, out chan []byte) error {
	h.logger.Infof("send command %s", command.Announce.Code)

	// emit reports whether frame was accepted before ctx was done.
	emit := func(frame []byte) bool {
		select {
		case out <- frame:
			return true
		case <-ctx.Done():
			return false
		}
	}

	if !emit([]byte(command.Announce.String())) {
		return nil
	}

	emit([]byte(command.Command.(string))) //nolint:forcetypeassert

	return nil
}

// ResponseAnnounce waits for one frame on in, logs it and returns a Response
// whose announcement is built for messages.RespUndef with the frame text as
// ID. It returns nil, nil when ctx is done before a frame arrives.
func (h *testClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*Response, error) {
	var frame []byte

	select {
	case frame = <-in:
	case <-ctx.Done():
		return nil, nil
	}

	h.logger.Infof("received response announce %s", frame)

	return &Response{Announce: messages.BuildResponseAnnounce(messages.RespUndef, string(frame))}, nil
}

// ResponseData waits for one frame on in, logs it and stores the raw frame as
// the payload of response. It returns nil without touching response when ctx
// is done before a frame arrives.
func (h *testClientHandler) ResponseData(ctx context.Context, in chan []byte, response *Response) error {
	var frame []byte

	select {
	case frame = <-in:
	case <-ctx.Done():
		return nil
	}

	h.logger.Infof("received response data %s", string(frame))
	response.Response = frame

	return nil
}

// HandleResponse logs the announcement ID of response and reports success.
func (h *testClientHandler) HandleResponse(_ context.Context, response *Response) error {
	h.logger.Infof("handle response %s", response.Announce.ID)

	return nil
}
// commandAnnounceErrClientHandler is a testClientHandler whose Send always
// fails.
type commandAnnounceErrClientHandler struct{ testClientHandler }

// Send writes nothing and returns a fixed error.
func (*commandAnnounceErrClientHandler) Send(_ context.Context, _ *Command, _ chan []byte) error {
	err := errors.New("failed sending command")

	return err
}
func TestNewClient(t *testing.T) {
logger, _ := test.NewNullLogger()
in := make(chan []byte)
out := make(chan []byte)
commands := make(chan *Command, 10)
client := NewClient(&noopClientHandler{}, commands, in, out, logger)
assert.NotNil(t, client)
assert.IsType(t, (*ClientProtocol)(nil), client)
}
// responseAnnounceErrClientHandler is a testClientHandler whose
// ResponseAnnounce fails for every frame it receives.
type responseAnnounceErrClientHandler struct{ testClientHandler }

// ResponseAnnounce discards one frame from in and returns a fixed error. It
// returns nil, nil when ctx is done first.
func (h *responseAnnounceErrClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*Response, error) {
	var err error

	select {
	case <-in:
		err = errors.New("failed receiving response announce")
	case <-ctx.Done():
	}

	return nil, err
}
// responseAnnounceNilClientHandler is a testClientHandler whose
// ResponseAnnounce never produces a response.
type responseAnnounceNilClientHandler struct{ testClientHandler }

// ResponseAnnounce waits until either a frame arrives on in (which is
// discarded) or ctx is done, and returns nil, nil in both cases.
func (h *responseAnnounceNilClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*Response, error) {
	select {
	case <-in:
	case <-ctx.Done():
	}

	return nil, nil
}
// responseDataErrClientHandler is a testClientHandler whose ResponseData
// fails for every frame it receives.
type responseDataErrClientHandler struct{ testClientHandler }

// ResponseData discards one frame from in and returns a fixed error. It
// returns nil when ctx is done first.
func (h *responseDataErrClientHandler) ResponseData(ctx context.Context, in chan []byte, _ *Response) error {
	var err error

	select {
	case <-in:
		err = errors.New("failed to handle response data")
	case <-ctx.Done():
	}

	return err
}
// responseDataTimeoutErrClientHandler is a testClientHandler whose
// ResponseData reports ErrResponseDataTimeoutExpired for every frame it
// receives.
type responseDataTimeoutErrClientHandler struct{ testClientHandler }

// ResponseData discards one frame from in and returns
// ErrResponseDataTimeoutExpired. It returns nil when ctx is done first.
func (h *responseDataTimeoutErrClientHandler) ResponseData(ctx context.Context, in chan []byte, _ *Response) error {
	var err error

	select {
	case <-in:
		err = ErrResponseDataTimeoutExpired
	case <-ctx.Done():
	}

	return err
}
// responseDataNilClientHandler is a testClientHandler whose ResponseData
// leaves the response without a payload.
type responseDataNilClientHandler struct{ testClientHandler }

// ResponseData discards one frame from in and sets the payload of response to
// nil. It returns nil without touching response when ctx is done first.
func (h *responseDataNilClientHandler) ResponseData(ctx context.Context, in chan []byte, response *Response) error {
	select {
	case <-in:
	case <-ctx.Done():
		return nil
	}

	response.Response = nil

	return nil
}
// handleResponseErrClientHandler is a testClientHandler whose HandleResponse
// always fails.
type handleResponseErrClientHandler struct{ testClientHandler }

// HandleResponse ignores its arguments and returns a fixed error.
func (*handleResponseErrClientHandler) HandleResponse(_ context.Context, _ *Response) error {
	err := errors.New("failed to handle response")

	return err
}
func TestClientProtocol_Handle(t *testing.T) { //nolint:cyclop
protoLog, pHook := test.NewNullLogger()
protoLog.SetLevel(logrus.TraceLevel)
handlerLog, hHook := test.NewNullLogger()
t.Run("initialize", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
sent []byte
)
c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
defer wg.Done()
select {
case <-ctx.Done():
return
case sent = <-out:
return
}
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.NoError(t, err)
assert.Nil(t, sent)
assertLogs(t, hHook, []expectedLogs{})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
})
})
t.Run("happy-case", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
announce, command []byte
)
c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(4)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
prepareClientCommand(ctx, commands, "test command")
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.NoError(t, err)
assert.NotNil(t, announce)
assert.NotNil(t, command)
assertLogs(t, hHook, []expectedLogs{
{logrus.InfoLevel, "send command UNDEFINED"},
{logrus.InfoLevel, "received response announce test response announce"},
{logrus.InfoLevel, "received response data test response data"},
{logrus.InfoLevel, "handle response test response announce"},
})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
})
})
t.Run("command-announce-nil", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
close(commands)
var (
err error
announce, command []byte
)
c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.Error(t, err)
assert.Nil(t, announce)
assert.Nil(t, command)
assertLogs(t, hHook, []expectedLogs{})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
})
})
t.Run("command-announce-error", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
announce, command []byte
)
c := NewClient(&commandAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(4)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
prepareClientCommand(ctx, commands, "test command")
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.NoError(t, err)
assert.Nil(t, announce)
assert.Nil(t, command)
assertLogs(t, hHook, []expectedLogs{})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
{logrus.ErrorLevel, "could not send command announce"},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
})
})
t.Run("response-announce-error", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
announce, command []byte
)
c := NewClient(&responseAnnounceErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(4)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
prepareClientCommand(ctx, commands, "test command")
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.NoError(t, err)
assert.NotNil(t, announce)
assert.NotNil(t, command)
assertLogs(t, hHook, []expectedLogs{
{logrus.InfoLevel, "send command UNDEFINED"},
})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)},
{logrus.ErrorLevel, "could not handle response announce"},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
})
})
t.Run("response-data-nil", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
announce, command []byte
)
c := NewClient(&responseAnnounceNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(4)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
prepareClientCommand(ctx, commands, "test command")
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.ErrorIs(t, err, errResponseExpected)
assert.NotNil(t, announce)
assert.NotNil(t, command)
assertLogs(t, hHook, []expectedLogs{
{logrus.InfoLevel, "send command UNDEFINED"},
})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)},
})
})
t.Run("response-data-error", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
announce, command []byte
)
c := NewClient(&responseDataErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(4)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
prepareClientCommand(ctx, commands, "test command")
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.NoError(t, err)
assert.NotNil(t, announce)
assert.NotNil(t, command)
assertLogs(t, hHook, []expectedLogs{
{logrus.InfoLevel, "send command UNDEFINED"},
{logrus.InfoLevel, "received response announce test response announce"},
})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)},
{logrus.ErrorLevel, "could not handle response data"},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)},
})
})
t.Run("response-data-timeout-error", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
announce, command []byte
)
c := NewClient(&responseDataTimeoutErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(4)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
prepareClientCommand(ctx, commands, "test command")
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.NoError(t, err)
assert.NotNil(t, announce)
assert.NotNil(t, command)
assertLogs(t, hHook, []expectedLogs{
{logrus.InfoLevel, "send command UNDEFINED"},
{logrus.InfoLevel, "received response announce test response announce"},
})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)},
{logrus.ErrorLevel, "could not handle response data"},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
})
})
t.Run("handle-response-nil", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
announce, command []byte
)
c := NewClient(&responseDataNilClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(4)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
prepareClientCommand(ctx, commands, "test command")
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.ErrorIs(t, err, errResponseExpected)
assert.NotNil(t, announce)
assert.NotNil(t, command)
assertLogs(t, hHook, []expectedLogs{
{logrus.InfoLevel, "send command UNDEFINED"},
{logrus.InfoLevel, "received response announce test response announce"},
})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)},
})
})
t.Run("handle-response-error", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
announce, command []byte
)
c := NewClient(&handleResponseErrClientHandler{testClientHandler{handlerLog}}, commands, in, out, protoLog)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(4)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
announce, command = sendClientCommand(ctx, out)
wg.Done()
}()
go func() {
prepareClientCommand(ctx, commands, "test command")
wg.Done()
}()
go func() {
receiveClientAnnounce(ctx, in, "test response announce", "test response data")
wg.Done()
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.NoError(t, err)
assert.NotNil(t, announce)
assert.NotNil(t, command)
assertLogs(t, hHook, []expectedLogs{
{logrus.InfoLevel, "send command UNDEFINED"},
{logrus.InfoLevel, "received response announce test response announce"},
{logrus.InfoLevel, "received response data test response data"},
})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respAnnounce)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", respData)},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", handleResponse)},
{logrus.ErrorLevel, "could not handle response"},
{logrus.DebugLevel, fmt.Sprintf("handling protocol state %s", cmdAnnounce)},
})
})
t.Run("unknown-protocol-state", func(t *testing.T) {
t.Cleanup(func() {
pHook.Reset()
hHook.Reset()
})
if testing.Short() {
t.Skip("skipping test in short mode")
}
in, out := make(chan []byte), make(chan []byte)
commands := make(chan *Command, 10)
var (
err error
sent []byte
)
c := NewClient(&testClientHandler{handlerLog}, commands, in, out, protoLog)
c.state = protocolState(100)
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
err = c.Handle(ctx)
wg.Done()
}()
go func() {
defer wg.Done()
select {
case <-ctx.Done():
return
case sent = <-out:
return
}
}()
time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
assert.ErrorContains(t, err, "unknown protocol state")
assert.Nil(t, sent)
assertLogs(t, hHook, []expectedLogs{})
assertLogs(t, pHook, []expectedLogs{
{logrus.DebugLevel, "handling protocol state unknown 100"},
})
})
}
// prepareClientCommand offers a single CmdUndef command carrying the given
// command payload on the commands channel. It returns once the command has
// been handed over or ctx is done, whichever happens first.
func prepareClientCommand(ctx context.Context, commands chan *Command, command string) {
	cmd := &Command{
		Announce: messages.BuildCommandAnnounce(messages.CmdUndef),
		Command:  command,
	}

	select {
	case commands <- cmd:
	case <-ctx.Done():
	}
}
// receiveClientAnnounce pushes a two part message into the in channel: first
// the announce bytes, then the data bytes. As soon as ctx is done while a
// part is waiting to be sent, the function returns and the remaining parts
// are dropped.
func receiveClientAnnounce(ctx context.Context, in chan []byte, announce, data string) {
	for _, part := range []string{announce, data} {
		select {
		case in <- []byte(part):
		case <-ctx.Done():
			return
		}
	}
}
// sendClientCommand takes up to two messages from the out channel and
// returns them as (announce, command). When ctx is done before a message
// arrives, the function returns early and the messages not received so far
// are reported as nil.
func sendClientCommand(ctx context.Context, out chan []byte) ([]byte, []byte) {
	// parts[0] is the announce message, parts[1] the command message.
	var parts [2][]byte

	for i := range parts {
		select {
		case parts[i] = <-out:
		case <-ctx.Done():
			return parts[0], parts[1]
		}
	}

	return parts[0], parts[1]
}
// TestNewCOBSFramer checks that NewCOBSFramer returns a non-nil *COBSFramer
// that implements the Framer interface.
func TestNewCOBSFramer(t *testing.T) {
	nullLogger, _ := test.NewNullLogger()

	got := NewCOBSFramer(nullLogger)

	// Stop right away on nil, the type assertions below would be meaningless.
	require.NotNil(t, got)

	assert.Implements(t, (*Framer)(nil), got)
	assert.IsType(t, (*COBSFramer)(nil), got)
}
// TestCOBSFramer_ReadFrames exercises COBSFramer.ReadFrames with a failing
// reader, an empty reader, an unterminated chunk of bytes, a delimiter
// terminated chunk that is not a valid frame and a valid COBS encoded frame.
// Besides the returned error every sub test pins the log entries recorded by
// the hook of the logger that was handed to the framer.
//
// All sub tests except "read error" let ReadFrames run for 10ms before
// cancelling its context and are skipped in short mode.
func TestCOBSFramer_ReadFrames(t *testing.T) {
	logger, loggerHook := test.NewNullLogger()
	logger.SetLevel(logrus.TraceLevel)

	framer := NewCOBSFramer(logger)

	// readUntilCancelled runs ReadFrames in a goroutine, gives it 10ms to
	// work on the reader, cancels the context and returns the error that
	// ReadFrames finished with.
	readUntilCancelled := func(reader io.Reader, frames chan []byte) error {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		done := make(chan error, 1)

		go func() {
			done <- framer.ReadFrames(ctx, reader, frames)
		}()

		time.Sleep(10 * time.Millisecond)
		cancel()

		return <-done
	}

	t.Run("read error", func(t *testing.T) {
		t.Cleanup(func() { loggerHook.Reset() })

		readFrames := make(chan []byte)
		testError := errors.New("test error")

		err := framer.ReadFrames(context.Background(), iotest.ErrReader(testError), readFrames)

		assert.ErrorIs(t, err, testError)

		// ReadFrames has already returned at this point, so the receive can
		// only complete if the framer released the channel (presumably by
		// closing it on return).
		frame := <-readFrames
		assert.Nil(t, frame)

		assert.Empty(t, loggerHook.AllEntries())
	})

	t.Run("no bytes", func(t *testing.T) {
		t.Cleanup(func() { loggerHook.Reset() })

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		err := readUntilCancelled(&bytes.Buffer{}, make(chan []byte))

		assert.NoError(t, err)
		assert.Empty(t, loggerHook.AllEntries())
	})

	t.Run("incomplete bytes", func(t *testing.T) {
		t.Cleanup(func() { loggerHook.Reset() })

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		// The input carries no frame delimiter.
		err := readUntilCancelled(strings.NewReader("some bytes"), make(chan []byte))

		assert.NoError(t, err)

		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 2)

		assert.Equal(t, logrus.TraceLevel, logEntries[0].Level)
		assert.Equal(t, "read 10 raw bytes", logEntries[0].Message)
		assert.Equal(t, logrus.TraceLevel, logEntries[1].Level)
		assert.Equal(t, "read buffer is now 10 bytes long", logEntries[1].Message)
	})

	t.Run("invalid bytes", func(t *testing.T) {
		t.Cleanup(func() { loggerHook.Reset() })

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		// Plain text followed by the delimiter: terminated, but expected to
		// be rejected as a frame.
		reader := bytes.NewBuffer([]byte("some bytes"))
		require.NoError(t, reader.WriteByte(CobsDelimiter))

		err := readUntilCancelled(reader, make(chan []byte))

		assert.NoError(t, err)

		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 3)

		assert.Equal(t, logrus.TraceLevel, logEntries[0].Level)
		assert.Equal(t, "read 11 raw bytes", logEntries[0].Message)
		assert.Equal(t, logrus.WarnLevel, logEntries[1].Level)
		assert.Equal(t, "skipping invalid frame of 11 bytes", logEntries[1].Message)
		assert.Equal(t, logrus.TraceLevel, logEntries[2].Level)
		assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message)
	})

	t.Run("valid frame", func(t *testing.T) {
		t.Cleanup(func() { loggerHook.Reset() })

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		readFrames := make(chan []byte)
		reader := bytes.NewBuffer(cobs.Encode([]byte("some bytes"), framer.config))

		// readFrames is unbuffered, so a concurrent receiver is needed to
		// take the frame off the channel while ReadFrames is running.
		received := make(chan []byte, 1)

		go func() {
			received <- <-readFrames
		}()

		err := readUntilCancelled(reader, readFrames)
		frame := <-received

		assert.NoError(t, err)

		if assert.NotNil(t, frame) {
			assert.Equal(t, []byte("some bytes"), frame)
		}

		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 3)

		assert.Equal(t, logrus.TraceLevel, logEntries[0].Level)
		assert.Contains(t, logEntries[0].Message, "raw bytes")
		assert.Equal(t, logrus.TraceLevel, logEntries[1].Level)
		assert.Equal(t, "frame decoded to length 10", logEntries[1].Message)
		assert.Equal(t, logrus.TraceLevel, logEntries[2].Level)
		assert.Equal(t, "read buffer is now 0 bytes long", logEntries[2].Message)
	})
}
// errB0rk3d is the error returned by every brokenWriter.Write call.
var errB0rk3d = errors.New("you b0rk3d it")

// brokenWriter is a writer stub for tests whose Write method always fails.
type brokenWriter struct{}

// Write ignores the given bytes and reports zero bytes written together with
// errB0rk3d.
func (brokenWriter) Write(_ []byte) (int, error) {
	return 0, errB0rk3d
}
// TestCOBSFramer_WriteFrames exercises COBSFramer.WriteFrames with an already
// closed input channel, a writer that always fails and a working writer.
// Besides the returned error every sub test pins the log entries recorded by
// the hook of the logger that was handed to the framer.
//
// Every sub test lets WriteFrames run for 10ms before cancelling its context
// and is skipped in short mode.
func TestCOBSFramer_WriteFrames(t *testing.T) {
	logger, loggerHook := test.NewNullLogger()
	logger.SetLevel(logrus.TraceLevel)

	framer := NewCOBSFramer(logger)

	// writeUntilCancelled runs WriteFrames in a goroutine, feeds the given
	// messages into the frames channel, waits 10ms, cancels the context and
	// returns the error that WriteFrames finished with.
	writeUntilCancelled := func(w io.Writer, frames chan []byte, messages ...[]byte) error {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		done := make(chan error, 1)

		go func() {
			done <- framer.WriteFrames(ctx, w, frames)
		}()

		for _, message := range messages {
			frames <- message
		}

		time.Sleep(10 * time.Millisecond)
		cancel()

		return <-done
	}

	t.Run("closed channel", func(t *testing.T) {
		t.Cleanup(func() { loggerHook.Reset() })

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		in := make(chan []byte)
		close(in)

		err := writeUntilCancelled(io.Discard, in)

		assert.NoError(t, err)

		// require instead of assert: indexing an unexpectedly empty slice
		// below would panic instead of reporting a clean test failure.
		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 1)

		assert.Equal(t, logrus.DebugLevel, logEntries[0].Level)
		assert.Equal(t, "channel closed", logEntries[0].Message)
	})

	t.Run("closed writer", func(t *testing.T) {
		t.Cleanup(func() { loggerHook.Reset() })

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		err := writeUntilCancelled(&brokenWriter{}, make(chan []byte), []byte("test message"))

		// ErrorIs also fails for a nil error, no separate nil check needed.
		assert.ErrorIs(t, err, errB0rk3d)
		assert.Empty(t, loggerHook.AllEntries())
	})

	t.Run("valid frame", func(t *testing.T) {
		t.Cleanup(func() { loggerHook.Reset() })

		if testing.Short() {
			t.Skip("skipping test in short mode")
		}

		out := &bytes.Buffer{}

		err := writeUntilCancelled(out, make(chan []byte), []byte("test message"))

		assert.NoError(t, err)

		logEntries := loggerHook.AllEntries()
		require.Len(t, logEntries, 1)

		assert.Equal(t, logrus.TraceLevel, logEntries[0].Level)

		// The written bytes must verify as COBS data and decode back to the
		// message that was sent.
		frame := out.Bytes()
		assert.NoError(t, cobs.Verify(frame, framer.config))
		assert.Equal(t, []byte("test message"), cobs.Decode(frame, framer.config))
	})
}