Refactor public API tests for protocol

- move tests for public API to protocol_test package
- add tests for context handling of COBSFramer
main
Jan Dittberner 2 years ago
parent 7852c4d3df
commit 51ca9cc69d

@ -32,7 +32,7 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/messages"
) )
const CobsDelimiter = 0x00 const COBSDelimiter = 0x00
type Command struct { type Command struct {
Announce *messages.CommandAnnounce Announce *messages.CommandAnnounce
@ -73,9 +73,9 @@ type ClientHandler interface {
var ( var (
errCommandExpected = errors.New("command must not be nil") errCommandExpected = errors.New("command must not be nil")
errCommandAnnounceExpected = errors.New("command must have an announcement") ErrCommandAnnounceExpected = errors.New("command must have an announcement")
errCommandDataExpected = errors.New("command must have data") ErrCommandDataExpected = errors.New("command must have data")
errResponseExpected = errors.New("response must not be nil") ErrResponseExpected = errors.New("response must not be nil")
ErrResponseAnnounceTimeoutExpired = errors.New("response announce timeout expired") ErrResponseAnnounceTimeoutExpired = errors.New("response announce timeout expired")
ErrResponseDataTimeoutExpired = errors.New("response data timeout expired") ErrResponseDataTimeoutExpired = errors.New("response data timeout expired")
@ -133,7 +133,7 @@ func (p *ServerProtocol) Handle(ctx context.Context) error {
return nil return nil
default: default:
p.logger.Debugf("handling protocol state %s", p.state) p.logger.WithField("state", p.state).Debug("handling protocol state")
switch p.state { switch p.state {
case cmdAnnounce: case cmdAnnounce:
@ -175,7 +175,7 @@ func (p *ServerProtocol) commandAnnounce(ctx context.Context) *Command {
func (p *ServerProtocol) commandData(ctx context.Context, command *Command) error { func (p *ServerProtocol) commandData(ctx context.Context, command *Command) error {
if command == nil || command.Announce == nil { if command == nil || command.Announce == nil {
return errCommandAnnounceExpected return ErrCommandAnnounceExpected
} }
err := p.handler.CommandData(ctx, p.in, command) err := p.handler.CommandData(ctx, p.in, command)
@ -194,7 +194,7 @@ func (p *ServerProtocol) commandData(ctx context.Context, command *Command) erro
func (p *ServerProtocol) handleCommand(ctx context.Context, command *Command) (*Response, error) { func (p *ServerProtocol) handleCommand(ctx context.Context, command *Command) (*Response, error) {
if command == nil || command.Announce == nil || command.Command == nil { if command == nil || command.Announce == nil || command.Command == nil {
return nil, errCommandDataExpected return nil, ErrCommandDataExpected
} }
response, err := p.handler.HandleCommand(ctx, command) response, err := p.handler.HandleCommand(ctx, command)
@ -213,7 +213,7 @@ func (p *ServerProtocol) handleCommand(ctx context.Context, command *Command) (*
func (p *ServerProtocol) respond(ctx context.Context, response *Response) error { func (p *ServerProtocol) respond(ctx context.Context, response *Response) error {
if response == nil { if response == nil {
return errResponseExpected return ErrResponseExpected
} }
err := p.handler.Respond(ctx, response, p.out) err := p.handler.Respond(ctx, response, p.out)
@ -260,7 +260,7 @@ func (p *ClientProtocol) Handle(ctx context.Context) error {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
default: default:
p.logger.Debugf("handling protocol state %s", p.state) p.logger.WithField("state", p.state).Debug("handling protocol state")
switch p.state { switch p.state {
case cmdAnnounce: case cmdAnnounce:
@ -331,7 +331,7 @@ func (p *ClientProtocol) respAnnounce(ctx context.Context) *Response {
func (p *ClientProtocol) respData(ctx context.Context, response *Response) error { func (p *ClientProtocol) respData(ctx context.Context, response *Response) error {
if response == nil || response.Announce == nil { if response == nil || response.Announce == nil {
return errResponseExpected return ErrResponseExpected
} }
err := p.handler.ResponseData(ctx, p.in, response) err := p.handler.ResponseData(ctx, p.in, response)
@ -354,7 +354,7 @@ func (p *ClientProtocol) respData(ctx context.Context, response *Response) error
func (p *ClientProtocol) handleResponse(ctx context.Context, response *Response) error { func (p *ClientProtocol) handleResponse(ctx context.Context, response *Response) error {
if response == nil || response.Announce == nil || response.Response == nil { if response == nil || response.Announce == nil || response.Response == nil {
return errResponseExpected return ErrResponseExpected
} }
err := p.handler.HandleResponse(ctx, response) err := p.handler.HandleResponse(ctx, response)
@ -404,8 +404,10 @@ type COBSFramer struct {
encoder cobs.Encoder encoder cobs.Encoder
} }
var COBSConfig = cobs.Config{SpecialByte: COBSDelimiter, Delimiter: true, EndingSave: true}
func NewCOBSFramer(logger *logrus.Logger) (*COBSFramer, error) { func NewCOBSFramer(logger *logrus.Logger) (*COBSFramer, error) {
encoder, err := cobs.NewEncoder(cobs.Config{SpecialByte: CobsDelimiter, Delimiter: true, EndingSave: true}) encoder, err := cobs.NewEncoder(COBSConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not setup encoder: %w", err) return nil, fmt.Errorf("could not setup encoder: %w", err)
} }
@ -447,7 +449,7 @@ func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan
buffer.Write(raw) buffer.Write(raw)
for { for {
data, err = buffer.ReadBytes(CobsDelimiter) data, err = buffer.ReadBytes(COBSDelimiter)
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
buffer.Write(data) buffer.Write(data)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save