From af40662c7d61995c7769800112c98f74197f6cd2 Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Tue, 29 Nov 2022 14:05:10 +0100 Subject: [PATCH] Refactor client protocol - define protocols.ClientHandler interface as base for client implementations - implement protocols.ClientHandler in clientsim's ClientHandler type - move protocol state handling into protocols.ServerProtocol and protocols.ClientProtocol - move protocolState type into protocols.go - reduce clientsim's TestCommandGenerator responsibility to test command generation --- cmd/clientsim/main.go | 379 +++++++++++----------------------- internal/handler/msgpack.go | 10 +- internal/serial/seriallink.go | 133 ++---------- pkg/messages/messages.go | 10 +- pkg/protocol/protocol.go | 235 ++++++++++++++++++++- 5 files changed, 388 insertions(+), 379 deletions(-) diff --git a/cmd/clientsim/main.go b/cmd/clientsim/main.go index 1bab83c..a446ce8 100644 --- a/cmd/clientsim/main.go +++ b/cmd/clientsim/main.go @@ -25,7 +25,6 @@ import ( "fmt" "io" "os" - "sync" "time" "github.com/shamaton/msgpackgen/msgpack" @@ -36,124 +35,9 @@ import ( "git.cacert.org/cacert-gosigner/pkg/messages" ) -type protocolState int8 - -const ( - cmdAnnounce protocolState = iota - cmdData - respAnnounce - respData -) - -var validTransitions = map[protocolState]protocolState{ - cmdAnnounce: cmdData, - cmdData: respAnnounce, - respAnnounce: respData, - respData: cmdAnnounce, -} - -var protocolStateNames = map[protocolState]string{ - cmdAnnounce: "CMD ANNOUNCE", - cmdData: "CMD DATA", - respAnnounce: "RESP ANNOUNCE", - respData: "RESP DATA", -} - -func (p protocolState) String() string { - if name, ok := protocolStateNames[p]; ok { - return name - } - - return fmt.Sprintf("unknown %d", p) -} - type TestCommandGenerator struct { - logger *logrus.Logger - currentCommand *protocol.Command - currentResponse *protocol.Response - commands chan *protocol.Command - lock sync.Mutex -} - -func (g *TestCommandGenerator) CmdAnnouncement() ([]byte, error) { - g.lock.Lock() - defer g.lock.Unlock() - - g.currentCommand = <-g.commands - - announceData, err := msgpack.Marshal(g.currentCommand.Announce) - if err != nil { - return nil, fmt.Errorf("could not marshal command annoucement: %w", err) - } - - g.logger.WithField("announcement", &g.currentCommand.Announce).Info("write command announcement") - - return announceData, nil -} - -func (g *TestCommandGenerator) CmdData() ([]byte, error) { - g.lock.Lock() - defer g.lock.Unlock() - - cmdData, err := msgpack.Marshal(g.currentCommand.Command) - if err != nil { - return nil, fmt.Errorf("could not marshal command data: %w", err) - } - - g.logger.WithField("command", &g.currentCommand.Command).Info("write command data") - - return cmdData, nil -} - -func (g *TestCommandGenerator) HandleResponseAnnounce(frame []byte) error { - g.lock.Lock() - defer g.lock.Unlock() - - var ann messages.ResponseAnnounce - - if err := msgpack.Unmarshal(frame, &ann); err != nil { - return fmt.Errorf("could not unmarshal response announcement") - } - - g.logger.WithField("announcement", &ann).Info("received response announcement") - - g.currentResponse = &protocol.Response{Announce: &ann} - - return nil -} - -func (g *TestCommandGenerator) HandleResponse(frame []byte) error { - g.lock.Lock() - defer g.lock.Unlock() - - switch g.currentResponse.Announce.Code { - case messages.RespHealth: - var response messages.HealthResponse - - if err := msgpack.Unmarshal(frame, &response); err != nil { - return fmt.Errorf("unmarshal failed: %w", err) - } - - g.currentResponse.Response = &response - case messages.RespFetchCRL: - var response messages.FetchCRLResponse - - if err := msgpack.Unmarshal(frame, &response); err != nil { - return fmt.Errorf("unmarshal failed: %w", err) - } - - g.currentResponse.Response = &response - } - - g.logger.WithField( - "command", - g.currentCommand, - ).WithField( - "response", - g.currentResponse, - ).Info("handled health response") - - return nil + logger *logrus.Logger + commands chan *protocol.Command } func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error { @@ -200,7 +84,11 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error { case <-ctx.Done(): _ = healthTimer.Stop() - g.logger.Info("stopping health check loop") + g.logger.Info("stopped health check loop") + + _ = crlTimer.Stop() + + g.logger.Info("stopped CRL fetch loop") return nil case <-healthTimer.C: @@ -225,187 +113,169 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error { } type clientSimulator struct { - protocolState protocolState - logger *logrus.Logger - lock sync.Mutex + clientHandler protocol.ClientHandler framesIn chan []byte framesOut chan []byte framer protocol.Framer commandGenerator *TestCommandGenerator + logger *logrus.Logger } -func (c *clientSimulator) writeCmdAnnouncement() error { - frame, err := c.commandGenerator.CmdAnnouncement() - if err != nil { - return fmt.Errorf("could not get command annoucement: %w", err) - } - - c.logger.Trace("writing command announcement") - - c.framesOut <- frame - - if err := c.nextState(); err != nil { - return err - } - - return nil -} - -func (c *clientSimulator) writeCommand() error { - frame, err := c.commandGenerator.CmdData() - if err != nil { - return fmt.Errorf("could not get command data: %w", err) - } - - c.logger.Trace("writing command data") - - c.framesOut <- frame - - if err := c.nextState(); err != nil { - return err - } - - return nil -} - -const responseAnnounceTimeout = 30 * time.Second -const responseDataTimeout = 2 * time.Second - -func (c *clientSimulator) handleResponseAnnounce() error { - c.logger.Trace("waiting for response announce") - - select { - case frame := <-c.framesIn: - if frame == nil { - return nil - } - - if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil { - return fmt.Errorf("response announce handling failed: %w", err) - } - - if err := c.nextState(); err != nil { - return err - } - case <-time.After(responseAnnounceTimeout): - c.logger.Warn("response announce timeout expired") - - c.protocolState = cmdAnnounce - - return nil - } - - return nil -} - -func (c *clientSimulator) handleResponseData() error { - c.logger.Trace("waiting for response data") - - select { - case frame := <-c.framesIn: - if frame == nil { - return nil - } - - if err := c.commandGenerator.HandleResponse(frame); err != nil { - return fmt.Errorf("response handler failed: %w", err) - } - - if err := c.nextState(); err != nil { - return err - } - - return nil - case <-time.After(responseDataTimeout): - c.logger.Warn("response data timeout expired") - - c.protocolState = cmdAnnounce - - return nil - } -} +const ( + responseAnnounceTimeout = 30 * time.Second + responseDataTimeout = 2 * time.Second +) func (c *clientSimulator) Run(ctx context.Context) error { - c.protocolState = cmdAnnounce - errors := make(chan error) + framerErrors := make(chan error) + protocolErrors := make(chan error) + generatorErrors := make(chan error) go func() { err := c.framer.ReadFrames(os.Stdin, c.framesIn) - errors <- err + framerErrors <- err }() go func() { err := c.framer.WriteFrames(os.Stdout, c.framesOut) - errors <- err + framerErrors <- err + }() + + go func() { + clientProtocol := protocol.NewClient(c.clientHandler, c.commandGenerator.commands, c.framesIn, c.framesOut, c.logger) + + err := clientProtocol.Handle() + + protocolErrors <- err }() go func() { err := c.commandGenerator.GenerateCommands(ctx) - errors <- err + generatorErrors <- err }() for { select { case <-ctx.Done(): return nil - case err := <-errors: + case err := <-framerErrors: if err != nil { - return fmt.Errorf("error from handler loop: %w", err) + return fmt.Errorf("error from framer: %w", err) } return nil - default: - if err := c.handleProtocolState(); err != nil { - return err + case err := <-generatorErrors: + if err != nil { + return fmt.Errorf("error from command generator: %w", err) } + + return nil + case err := <-protocolErrors: + if err != nil { + return fmt.Errorf("error from protocol handler: %w", err) + } + + return nil } } } -func (c *clientSimulator) handleProtocolState() error { - c.logger.Tracef("handling protocol state %s", c.protocolState) +type ClientHandler struct { + logger *logrus.Logger +} - c.lock.Lock() - defer c.lock.Unlock() +func (c ClientHandler) Send(command *protocol.Command, out chan []byte) error { + var ( + frame []byte + err error + ) - switch c.protocolState { - case cmdAnnounce: - if err := c.writeCmdAnnouncement(); err != nil { - return err + frame, err = msgpack.Marshal(command.Announce) + if err != nil { + return fmt.Errorf("could not marshal command annoucement: %w", err) + } + + c.logger.WithField("announcement", command.Announce).Info("write command announcement") + + c.logger.Trace("writing command announcement") + + out <- frame + + frame, err = msgpack.Marshal(command.Command) + if err != nil { + return fmt.Errorf("could not marshal command data: %w", err) + } + + c.logger.WithField("command", command.Command).Info("write command data") + + out <- frame + + return nil +} + +func (c ClientHandler) ResponseAnnounce(in chan []byte) (*protocol.Response, error) { + response := &protocol.Response{} + + var announce messages.ResponseAnnounce + + select { + case frame := <-in: + if err := msgpack.Unmarshal(frame, &announce); err != nil { + return nil, fmt.Errorf("could not unmarshal response announcement: %w", err) } - case cmdData: - if err := c.writeCommand(); err != nil { - return err + + response.Announce = &announce + + c.logger.WithField("announcement", response.Announce).Debug("received response announcement") + + return response, nil + case <-time.After(responseAnnounceTimeout): + return nil, protocol.ErrResponseAnnounceTimeoutExpired + } +} + +func (c ClientHandler) ResponseData(in chan []byte, response *protocol.Response) error { + select { + case frame := <-in: + switch response.Announce.Code { + case messages.RespHealth: + var resp messages.HealthResponse + if err := msgpack.Unmarshal(frame, &resp); err != nil { + return fmt.Errorf("could not unmarshal health response data: %w", err) + } + + response.Response = &resp + case messages.RespFetchCRL: + var resp messages.FetchCRLResponse + if err := msgpack.Unmarshal(frame, &resp); err != nil { + return fmt.Errorf("could not unmarshal fetch CRL response data: %w", err) + } + + response.Response = &resp + default: + return fmt.Errorf("unhandled response code %s", response.Announce.Code) } - case respAnnounce: - if err := c.handleResponseAnnounce(); err != nil { - return err - } - case respData: - if err := c.handleResponseData(); err != nil { - return err - } - default: - return fmt.Errorf("unknown protocol state %s", c.protocolState) + case <-time.After(responseDataTimeout): + return protocol.ErrResponseDataTimeoutExpired } return nil } -func (c *clientSimulator) nextState() error { - next, ok := validTransitions[c.protocolState] - if !ok { - return fmt.Errorf("illegal protocol state %s", c.protocolState) - } - - c.protocolState = next +func (c ClientHandler) HandleResponse(response *protocol.Response) error { + c.logger.WithField("response", response.Announce).Info("handled response") + c.logger.WithField("response", response).Debug("full response") return nil } +func newClientHandler(logger *logrus.Logger) *ClientHandler { + return &ClientHandler{logger: logger} +} + func main() { logger := logrus.New() logger.SetOutput(os.Stderr) @@ -418,10 +288,11 @@ func main() { logger: logger, commands: make(chan *protocol.Command), }, - logger: logger, - framesIn: make(chan []byte), - framesOut: make(chan []byte), - framer: protocol.NewCOBSFramer(logger), + logger: logger, + framesIn: make(chan []byte), + framesOut: make(chan []byte), + framer: protocol.NewCOBSFramer(logger), + clientHandler: newClientHandler(logger), } err := sim.Run(context.Background()) diff --git a/internal/handler/msgpack.go b/internal/handler/msgpack.go index 2ad12ec..c2b196c 100644 --- a/internal/handler/msgpack.go +++ b/internal/handler/msgpack.go @@ -57,7 +57,11 @@ func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command, return nil, fmt.Errorf("could not unmarshal command announcement: %w", err) } - m.logger.WithField("announcement", &ann).Info("received command announcement") + if ann.Code == messages.CmdUndef { + return nil, fmt.Errorf("received undefined command announcement: %s", ann) + } + + m.logger.WithField("announcement", &ann).Debug("received command announcement") return &protocol.Command{Announce: &ann}, nil } @@ -109,7 +113,7 @@ func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) e m.lock.Lock() defer m.lock.Unlock() - announce, err := msgpack.Marshal(response) + announce, err := msgpack.Marshal(response.Announce) if err != nil { return fmt.Errorf("could not marshal response announcement: %w", err) } @@ -125,7 +129,7 @@ func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) e m.logger.WithField("length", len(data)).Debug("write response") - out <- announce + out <- data return nil } diff --git a/internal/serial/seriallink.go b/internal/serial/seriallink.go index 8f6557e..9fc0787 100644 --- a/internal/serial/seriallink.go +++ b/internal/serial/seriallink.go @@ -20,7 +20,6 @@ package serial import ( "context" - "errors" "fmt" "github.com/sirupsen/logrus" @@ -30,39 +29,14 @@ import ( "git.cacert.org/cacert-gosigner/pkg/protocol" ) -type protocolState int8 - -const ( - cmdAnnounce protocolState = iota - cmdData - handleCommand - respond -) - -var protocolStateNames = map[protocolState]string{ - cmdAnnounce: "CMD ANNOUNCE", - cmdData: "CMD DATA", - handleCommand: "RESP ANNOUNCE", - respond: "RESP DATA", -} - -func (p protocolState) String() string { - if name, ok := protocolStateNames[p]; ok { - return name - } - - return fmt.Sprintf("unknown %d", p) -} - type Handler struct { - protocolHandler protocol.ServerHandler - protocolState protocolState - framer protocol.Framer - config *serial.Config - port *serial.Port - logger *logrus.Logger - framesIn chan []byte - framesOut chan []byte + serverHandler protocol.ServerHandler + framer protocol.Framer + config *serial.Config + port *serial.Port + logger *logrus.Logger + framesIn chan []byte + framesOut chan []byte } func (h *Handler) setupConnection() error { @@ -86,7 +60,6 @@ func (h *Handler) Close() error { } func (h *Handler) Run(ctx context.Context) error { - h.protocolState = cmdAnnounce protocolErrors, framerErrors := make(chan error), make(chan error) go func() { @@ -102,7 +75,9 @@ func (h *Handler) Run(ctx context.Context) error { }() go func() { - err := h.handleProtocolState() + serverProtocol := protocol.NewServer(h.serverHandler, h.framesIn, h.framesOut, h.logger) + + err := serverProtocol.Handle() protocolErrors <- err }() @@ -127,95 +102,17 @@ func (h *Handler) Run(ctx context.Context) error { } } -var errCommandExpected = errors.New("command must not be nil") -var errResponseExpected = errors.New("response must not be nil") - -func (h *Handler) handleProtocolState() error { - var ( - command *protocol.Command - response *protocol.Response - err error - ) - - for { - h.logger.Debugf("handling protocol state %s", h.protocolState) - - switch h.protocolState { - case cmdAnnounce: - command, err = h.protocolHandler.CommandAnnounce(h.framesIn) - if err != nil { - h.logger.WithError(err).Error("could not handle command announce") - - break - } - - h.protocolState = cmdData - case cmdData: - if command == nil { - return errCommandExpected - } - - err = h.protocolHandler.CommandData(h.framesIn, command) - if err != nil { - h.logger.WithError(err).Error("could not handle command data") - - h.protocolState = cmdAnnounce - - break - } - - h.protocolState = handleCommand - case handleCommand: - if command == nil { - return errCommandExpected - } - - response, err = h.protocolHandler.HandleCommand(command) - if err != nil { - h.logger.WithError(err).Error("could not handle command") - - h.protocolState = cmdAnnounce - - break - } - - command = nil - - h.protocolState = respond - case respond: - if response == nil { - return errResponseExpected - } - - err = h.protocolHandler.Respond(response, h.framesOut) - if err != nil { - h.logger.WithError(err).Error("could not respond") - - h.protocolState = cmdAnnounce - - break - } - - response = nil - - h.protocolState = cmdAnnounce - default: - return fmt.Errorf("unknown protocol state %s", h.protocolState) - } - } -} - func New( cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.ServerHandler, ) (*Handler, error) { h := &Handler{ - protocolHandler: protocolHandler, - logger: logger, - framesIn: make(chan []byte), - framesOut: make(chan []byte), - framer: protocol.NewCOBSFramer(logger), + serverHandler: protocolHandler, + logger: logger, + framesIn: make(chan []byte), + framesOut: make(chan []byte), + framer: protocol.NewCOBSFramer(logger), } h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout} diff --git a/pkg/messages/messages.go b/pkg/messages/messages.go index 7b76751..744e2d1 100644 --- a/pkg/messages/messages.go +++ b/pkg/messages/messages.go @@ -37,11 +37,13 @@ import ( type CommandCode int8 const ( - CmdHealth CommandCode = iota + CmdUndef CommandCode = iota + CmdHealth CmdFetchCRL ) var commandNames = map[CommandCode]string{ + CmdUndef: "UNDEFINED", CmdHealth: "HEALTH", CmdFetchCRL: "FETCH CRL", } @@ -57,13 +59,15 @@ func (c CommandCode) String() string { type ResponseCode int8 const ( - RespError ResponseCode = -1 - RespHealth ResponseCode = iota + RespError ResponseCode = -1 + RespUndef ResponseCode = iota + RespHealth RespFetchCRL ) var responseNames = map[ResponseCode]string{ RespError: "ERROR", + RespUndef: "UNDEFINED", RespHealth: "HEALTH", RespFetchCRL: "FETCH CRL", } diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go index f2b6ba8..50143ff 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/protocol/protocol.go @@ -63,6 +63,139 @@ type ServerHandler interface { Respond(*Response, chan []byte) error } +type ClientHandler interface { + Send(*Command, chan []byte) error + ResponseAnnounce(chan []byte) (*Response, error) + ResponseData(chan []byte, *Response) error + HandleResponse(*Response) error +} + +var ( + errCommandExpected = errors.New("command must not be nil") + errResponseExpected = errors.New("response must not be nil") + + ErrResponseAnnounceTimeoutExpired = errors.New("response announce timeout expired") + ErrResponseDataTimeoutExpired = errors.New("response data timeout expired") +) + +type protocolState int8 + +const ( + cmdAnnounce protocolState = iota + cmdData + handleCommand + respond + respAnnounce + respData + handleResponse +) + +var protocolStateNames = map[protocolState]string{ + cmdAnnounce: "CMD ANNOUNCE", + cmdData: "CMD DATA", + handleCommand: "HANDLE CMD", + respond: "RESPOND", + respAnnounce: "RESP ANNOUNCE", + respData: "RESP DATA", + handleResponse: "HANDLE RESP", +} + +func (p protocolState) String() string { + if name, ok := protocolStateNames[p]; ok { + return name + } + + return fmt.Sprintf("unknown %d", p) +} + +type ServerProtocol struct { + handler ServerHandler + in, out chan []byte + logger *logrus.Logger + state protocolState +} + +func (p *ServerProtocol) Handle() error { + var ( + command *Command + response *Response + err error + ) + + for { + p.logger.Debugf("handling protocol state %s", p.state) + + switch p.state { + case cmdAnnounce: + command, err = p.handler.CommandAnnounce(p.in) + if err != nil { + p.logger.WithError(err).Error("could not handle command announce") + + break + } + + p.state = cmdData + case cmdData: + if command == nil { + return errCommandExpected + } + + err = p.handler.CommandData(p.in, command) + if err != nil { + p.logger.WithError(err).Error("could not handle command data") + + p.state = cmdAnnounce + + break + } + + p.state = handleCommand + case handleCommand: + if command == nil { + return errCommandExpected + } + + response, err = p.handler.HandleCommand(command) + if err != nil { + p.logger.WithError(err).Error("could not handle command") + + p.state = cmdAnnounce + + break + } + + p.state = respond + case respond: + if response == nil { + return errResponseExpected + } + + err = p.handler.Respond(response, p.out) + if err != nil { + p.logger.WithError(err).Error("could not respond") + + p.state = cmdAnnounce + + break + } + + p.state = cmdAnnounce + default: + return fmt.Errorf("unknown protocol state %s", p.state) + } + } +} + +func NewServer(handler ServerHandler, in, out chan []byte, logger *logrus.Logger) *ServerProtocol { + return &ServerProtocol{ + handler: handler, + in: in, + out: out, + logger: logger, + state: cmdAnnounce, + } +} + // Framer handles bytes on the wire by adding or removing framing information. type Framer interface { // ReadFrames reads data frames and publishes unframed data to the channel. @@ -71,6 +204,106 @@ type Framer interface { WriteFrames(io.Writer, chan []byte) error } +type ClientProtocol struct { + handler ClientHandler + commands chan *Command + in, out chan []byte + logger *logrus.Logger + state protocolState +} + +func (p *ClientProtocol) Handle() error { + var ( + command *Command + response *Response + err error + ) + + for { + p.logger.Debugf("handling protocol state %s", p.state) + + switch p.state { + case cmdAnnounce: + command = <-p.commands + if command == nil { + return errCommandExpected + } + + err = p.handler.Send(command, p.out) + if err != nil { + p.logger.WithError(err).Error("could not send command announce") + + break + } + + p.state = respAnnounce + case respAnnounce: + response, err = p.handler.ResponseAnnounce(p.in) + if err != nil { + p.logger.WithError(err).Error("could not handle response announce") + + p.state = cmdAnnounce + + break + } + + p.state = respData + case respData: + if response == nil { + return errResponseExpected + } + + err = p.handler.ResponseData(p.in, response) + if err != nil { + p.logger.WithError(err).Error("could not handle response data") + + if errors.Is(err, ErrResponseDataTimeoutExpired) { + p.state = cmdAnnounce + } else { + p.state = respAnnounce + } + + break + } + + p.state = handleResponse + case handleResponse: + if response == nil { + return errResponseExpected + } + + err = p.handler.HandleResponse(response) + if err != nil { + p.logger.WithError(err).Error("could not handle response") + + p.state = cmdAnnounce + + break + } + + p.state = cmdAnnounce + default: + return fmt.Errorf("unknown protocol state %s", p.state) + } + } +} + +func NewClient( + handler ClientHandler, + commands chan *Command, + in, out chan []byte, + logger *logrus.Logger, +) *ClientProtocol { + return &ClientProtocol{ + handler: handler, + commands: commands, + in: in, + out: out, + logger: logger, + state: cmdAnnounce, + } +} + const bufferSize = 1024 const readInterval = 50 * time.Millisecond @@ -127,7 +360,7 @@ func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error { if err = cobs.Verify(data, c.config); err != nil { c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data)) - break + continue } frame = cobs.Decode(data, c.config)