diff --git a/internal/handler/msgpack.go b/internal/handler/msgpack.go index 80d4a8a..2ad12ec 100644 --- a/internal/handler/msgpack.go +++ b/internal/handler/msgpack.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "sync" + "time" "github.com/shamaton/msgpackgen/msgpack" "github.com/sirupsen/logrus" @@ -32,185 +33,186 @@ import ( "git.cacert.org/cacert-gosigner/pkg/messages" ) -// MsgPackHandler is a Handler implementation for the msgpack serialization format. +const readCommandTimeOut = 5 * time.Second + +var errReadCommandTimeout = errors.New("read command timeout expired") + +// MsgPackHandler is a ServerHandler implementation for the msgpack serialization format. type MsgPackHandler struct { logger *logrus.Logger healthHandler *health.Handler fetchCRLHandler *revoking.FetchCRLHandler - currentCommand *protocol.Command - currentResponse *protocol.Response lock sync.Mutex } -func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) error { +func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command, error) { m.lock.Lock() defer m.lock.Unlock() + frame := <-frames + var ann messages.CommandAnnounce if err := msgpack.Unmarshal(frame, &ann); err != nil { - return fmt.Errorf("could not unmarshal command announcement: %w", err) + return nil, fmt.Errorf("could not unmarshal command announcement: %w", err) } m.logger.WithField("announcement", &ann).Info("received command announcement") - m.currentCommand = &protocol.Command{Announce: &ann} - - return nil + return &protocol.Command{Announce: &ann}, nil } -func (m *MsgPackHandler) HandleCommand(frame []byte) error { +func (m *MsgPackHandler) CommandData(frames chan []byte, command *protocol.Command) error { m.lock.Lock() defer m.lock.Unlock() - err := m.parseCommand(frame) - if err != nil { - m.currentResponse = m.buildErrorResponse(err.Error()) - - m.logCommandResponse() + select { + case frame := <-frames: + err := m.parseCommand(frame, command) + if err != nil { + return err + } return nil + case <-time.After(readCommandTimeOut): + return errReadCommandTimeout } +} - err = m.handleCommand() +func (m *MsgPackHandler) HandleCommand(command *protocol.Command) (*protocol.Response, error) { + m.lock.Lock() + defer m.lock.Unlock() + + var ( + response *protocol.Response + err error + ) + + response, err = m.handleCommand(command) if err != nil { m.logger.WithError(err).Error("command handling failed") - return err + response = m.buildErrorResponse(command.Announce.ID, "command handling failed") } - m.logCommandResponse() + m.logCommandResponse(command, response) - m.currentCommand = nil - - return nil + return response, nil } -func (m *MsgPackHandler) logCommandResponse() { - m.logger.WithField("command", m.currentCommand.Announce).Info("handled command") - m.logger.WithField( - "command", - m.currentCommand, - ).WithField( - "response", - m.currentResponse, - ).Debug("command and response") +func (m *MsgPackHandler) logCommandResponse(command *protocol.Command, response *protocol.Response) { + m.logger.WithField("command", command.Announce).Info("handled command") + m.logger.WithField("command", command).WithField("response", response).Debug("command and response") } -func (m *MsgPackHandler) ResponseAnnounce() ([]byte, error) { +func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) error { m.lock.Lock() defer m.lock.Unlock() - announceData, err := msgpack.Marshal(m.currentResponse.Announce) + announce, err := msgpack.Marshal(response) if err != nil { - return nil, fmt.Errorf("could not marshal response announcement: %w", err) + return fmt.Errorf("could not marshal response announcement: %w", err) } - m.logger.WithField("announcement", m.currentResponse.Announce).Debug("write response announcement") - - return announceData, nil -} + m.logger.WithField("length", len(announce)).Debug("write response announcement") -func (m *MsgPackHandler) ResponseData() ([]byte, error) { - m.lock.Lock() - defer m.lock.Unlock() + out <- announce - responseData, err := msgpack.Marshal(m.currentResponse.Response) + data, err := msgpack.Marshal(response.Response) if err != nil { - return nil, fmt.Errorf("could not marshal response: %w", err) + return fmt.Errorf("could not marshal response: %w", err) } - m.logger.WithField("response", m.currentResponse.Response).Debug("write response") + m.logger.WithField("length", len(data)).Debug("write response") - return responseData, nil + out <- announce + + return nil } -func (m *MsgPackHandler) parseHealthCommand(frame []byte) error { +func (m *MsgPackHandler) parseHealthCommand(frame []byte) (*messages.HealthCommand, error) { var command messages.HealthCommand if err := msgpack.Unmarshal(frame, &command); err != nil { m.logger.WithError(err).Error("unmarshal failed") - return errors.New("could not unmarshal health command") + return nil, errors.New("could not unmarshal health command") } - m.currentCommand.Command = &command - - return nil + return &command, nil } -func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) error { +func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) (*messages.FetchCRLCommand, error) { var command messages.FetchCRLCommand if err := msgpack.Unmarshal(frame, &command); err != nil { m.logger.WithError(err).Error("unmarshal failed") - return errors.New("could not unmarshal fetch crl command") + return nil, errors.New("could not unmarshal fetch crl command") } - m.currentCommand.Command = &command - - return nil -} - -func (m *MsgPackHandler) currentID() string { - return m.currentCommand.Announce.ID + return &command, nil } -func (m *MsgPackHandler) handleCommand() error { +func (m *MsgPackHandler) handleCommand(command *protocol.Command) (*protocol.Response, error) { var ( - err error - responseData interface{} responseCode messages.ResponseCode + responseData interface{} ) - switch m.currentCommand.Command.(type) { + switch cmd := command.Command.(type) { case *messages.HealthCommand: response, err := m.handleHealthCommand() if err != nil { - return err + return nil, err } responseCode, responseData = messages.RespHealth, response case *messages.FetchCRLCommand: - response, err := m.handleFetchCRLCommand() + response, err := m.handleFetchCRLCommand(cmd) if err != nil { - return err + return nil, err } responseCode, responseData = messages.RespFetchCRL, response default: - return fmt.Errorf("unhandled command %s", m.currentCommand.Announce) + return nil, fmt.Errorf("unhandled command %s", command.Announce) } - if err != nil { - return fmt.Errorf("error from command handler: %w", err) - } - - m.currentResponse = &protocol.Response{ - Announce: messages.BuildResponseAnnounce(responseCode, m.currentID()), + return &protocol.Response{ + Announce: messages.BuildResponseAnnounce(responseCode, command.Announce.ID), Response: responseData, - } - - return nil + }, nil } -func (m *MsgPackHandler) buildErrorResponse(errMsg string) *protocol.Response { +func (m *MsgPackHandler) buildErrorResponse(commandID string, errMsg string) *protocol.Response { return &protocol.Response{ - Announce: messages.BuildResponseAnnounce(messages.RespError, m.currentID()), + Announce: messages.BuildResponseAnnounce(messages.RespError, commandID), Response: &messages.ErrorResponse{Message: errMsg}, } } -func (m *MsgPackHandler) parseCommand(frame []byte) error { - switch m.currentCommand.Announce.Code { +func (m *MsgPackHandler) parseCommand(frame []byte, command *protocol.Command) error { + switch command.Announce.Code { case messages.CmdHealth: - return m.parseHealthCommand(frame) + healthCommand, err := m.parseHealthCommand(frame) + if err != nil { + return err + } + + command.Command = healthCommand case messages.CmdFetchCRL: - return m.parseFetchCRLCommand(frame) + fetchCRLCommand, err := m.parseFetchCRLCommand(frame) + if err != nil { + return err + } + + command.Command = fetchCRLCommand default: - return fmt.Errorf("unhandled command code %s", m.currentCommand.Announce.Code) + return fmt.Errorf("unhandled command code %s", command.Announce.Code) } + + return nil } func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error) { @@ -235,27 +237,20 @@ func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error) return response, nil } -func (m *MsgPackHandler) handleFetchCRLCommand() (*messages.FetchCRLResponse, error) { - fetchCRLPayload, ok := m.currentCommand.Command.(*messages.FetchCRLCommand) - if !ok { - return nil, fmt.Errorf("could not use payload as FetchCRLPayload") - } - - res, err := m.fetchCRLHandler.FetchCRL(fetchCRLPayload.IssuerID) +func (m *MsgPackHandler) handleFetchCRLCommand(command *messages.FetchCRLCommand) (*messages.FetchCRLResponse, error) { + res, err := m.fetchCRLHandler.FetchCRL(command.IssuerID) if err != nil { return nil, fmt.Errorf("could not fetch CRL: %w", err) } - response := &messages.FetchCRLResponse{ + return &messages.FetchCRLResponse{ IsDelta: false, CRLNumber: res.Number, CRLData: res.CRLData, - } - - return response, nil + }, nil } -func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.Handler, error) { +func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.ServerHandler, error) { messages.RegisterGeneratedResolver() h := &MsgPackHandler{ diff --git a/internal/serial/seriallink.go b/internal/serial/seriallink.go index 41cc9a7..8f6557e 100644 --- a/internal/serial/seriallink.go +++ b/internal/serial/seriallink.go @@ -20,8 +20,8 @@ package serial import ( "context" + "errors" "fmt" - "sync" "github.com/sirupsen/logrus" "github.com/tarm/serial" @@ -35,22 +35,15 @@ type protocolState int8 const ( cmdAnnounce protocolState = iota cmdData - respAnnounce - respData + handleCommand + respond ) -var validTransitions = map[protocolState]protocolState{ - cmdAnnounce: cmdData, - cmdData: respAnnounce, - respAnnounce: respData, - respData: cmdAnnounce, -} - var protocolStateNames = map[protocolState]string{ - cmdAnnounce: "CMD ANNOUNCE", - cmdData: "CMD DATA", - respAnnounce: "RESP ANNOUNCE", - respData: "RESP DATA", + cmdAnnounce: "CMD ANNOUNCE", + cmdData: "CMD DATA", + handleCommand: "RESP ANNOUNCE", + respond: "RESP DATA", } func (p protocolState) String() string { @@ -62,13 +55,12 @@ func (p protocolState) String() string { } type Handler struct { - protocolHandler protocol.Handler + protocolHandler protocol.ServerHandler protocolState protocolState framer protocol.Framer config *serial.Config port *serial.Port logger *logrus.Logger - lock sync.Mutex framesIn chan []byte framesOut chan []byte } @@ -95,152 +87,129 @@ func (h *Handler) Close() error { func (h *Handler) Run(ctx context.Context) error { h.protocolState = cmdAnnounce - errors := make(chan error) + protocolErrors, framerErrors := make(chan error), make(chan error) go func() { err := h.framer.ReadFrames(h.port, h.framesIn) - errors <- err + framerErrors <- err }() go func() { err := h.framer.WriteFrames(h.port, h.framesOut) - errors <- err + framerErrors <- err + }() + + go func() { + err := h.handleProtocolState() + + protocolErrors <- err }() for { select { case <-ctx.Done(): return nil - - case err := <-errors: + case err := <-framerErrors: if err != nil { - return fmt.Errorf("error from handler loop: %w", err) + return fmt.Errorf("error from framer: %w", err) } return nil - - default: - if err := h.handleProtocolState(); err != nil { - return err + case err := <-protocolErrors: + if err != nil { + return fmt.Errorf("error from protocol handler: %w", err) } + + return nil } } } -func (h *Handler) nextState() error { - next, ok := validTransitions[h.protocolState] - if !ok { - return fmt.Errorf("illegal protocol state %s", h.protocolState) - } - - h.protocolState = next - - return nil -} +var errCommandExpected = errors.New("command must not be nil") +var errResponseExpected = errors.New("response must not be nil") func (h *Handler) handleProtocolState() error { - h.logger.Tracef("handling protocol state %s", h.protocolState) - - h.lock.Lock() - defer h.lock.Unlock() - - switch h.protocolState { - case cmdAnnounce: - if err := h.handleCmdAnnounce(); err != nil { - return err - } - case cmdData: - if err := h.handleCmdData(); err != nil { - return err - } - case respAnnounce: - if err := h.handleRespAnnounce(); err != nil { - return err - } - case respData: - if err := h.handleRespData(); err != nil { - return err - } - default: - return fmt.Errorf("unknown protocol state %s", h.protocolState) - } + var ( + command *protocol.Command + response *protocol.Response + err error + ) - return nil -} - -func (h *Handler) handleCmdAnnounce() error { - h.logger.Trace("waiting for command announce") - - frame := <-h.framesIn + for { + h.logger.Debugf("handling protocol state %s", h.protocolState) - if frame == nil { - return nil - } + switch h.protocolState { + case cmdAnnounce: + command, err = h.protocolHandler.CommandAnnounce(h.framesIn) + if err != nil { + h.logger.WithError(err).Error("could not handle command announce") - if err := h.protocolHandler.HandleCommandAnnounce(frame); err != nil { - return fmt.Errorf("command announce handling failed: %w", err) - } + break + } - if err := h.nextState(); err != nil { - return err - } + h.protocolState = cmdData + case cmdData: + if command == nil { + return errCommandExpected + } - return nil -} + err = h.protocolHandler.CommandData(h.framesIn, command) + if err != nil { + h.logger.WithError(err).Error("could not handle command data") -func (h *Handler) handleCmdData() error { - h.logger.Trace("waiting for command data") + h.protocolState = cmdAnnounce - frame := <-h.framesIn + break + } - if frame == nil { - return nil - } + h.protocolState = handleCommand + case handleCommand: + if command == nil { + return errCommandExpected + } - if err := h.protocolHandler.HandleCommand(frame); err != nil { - return fmt.Errorf("command handler failed: %w", err) - } + response, err = h.protocolHandler.HandleCommand(command) + if err != nil { + h.logger.WithError(err).Error("could not handle command") - if err := h.nextState(); err != nil { - return err - } + h.protocolState = cmdAnnounce - return nil -} + break + } -func (h *Handler) handleRespAnnounce() error { - frame, err := h.protocolHandler.ResponseAnnounce() - if err != nil { - return fmt.Errorf("could not get response announcement: %w", err) - } + command = nil - h.framesOut <- frame + h.protocolState = respond + case respond: + if response == nil { + return errResponseExpected + } - if err := h.nextState(); err != nil { - return err - } + err = h.protocolHandler.Respond(response, h.framesOut) + if err != nil { + h.logger.WithError(err).Error("could not respond") - return nil -} + h.protocolState = cmdAnnounce -func (h *Handler) handleRespData() error { - frame, err := h.protocolHandler.ResponseData() - if err != nil { - return fmt.Errorf("could not get response data: %w", err) - } + break + } - h.framesOut <- frame + response = nil - if err := h.nextState(); err != nil { - return err + h.protocolState = cmdAnnounce + default: + return fmt.Errorf("unknown protocol state %s", h.protocolState) + } } - - return nil } -func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Handler) (*Handler, error) { +func New( + cfg *config.Serial, + logger *logrus.Logger, + protocolHandler protocol.ServerHandler, +) (*Handler, error) { h := &Handler{ protocolHandler: protocolHandler, logger: logger, diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go index 58dbda6..f2b6ba8 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/protocol/protocol.go @@ -51,16 +51,16 @@ func (r *Response) String() string { return fmt.Sprintf("Rsp[announce={%s}, data={%s}]", r.Announce, r.Response) } -// Handler is responsible for parsing incoming frames and calling commands -type Handler interface { - // HandleCommandAnnounce handles the initial announcement of a command. - HandleCommandAnnounce([]byte) error - // HandleCommand handles the command data. - HandleCommand([]byte) error - // ResponseAnnounce generates the announcement for a response. - ResponseAnnounce() ([]byte, error) - // ResponseData generates the response data. - ResponseData() ([]byte, error) +// ServerHandler is responsible for parsing incoming frames and calling commands +type ServerHandler interface { + // CommandAnnounce handles the initial announcement of a command. + CommandAnnounce(chan []byte) (*Command, error) + // CommandData handles the command data. + CommandData(chan []byte, *Command) error + // HandleCommand executes the command, generating a response. + HandleCommand(*Command) (*Response, error) + // Respond generates the response for a command. + Respond(*Response, chan []byte) error } // Framer handles bytes on the wire by adding or removing framing information.