Refactor server handler

- rename protocols.Handler to ServerHandler
- rename ServerHandler methods to better express their purpose
- pass command and response as parameters
- simplify state machine and handle errors in serial/seriallink.go
- implement command read timeout
- remove currentCommand and currentResponse fields from MsgPackHandler
main
Jan Dittberner 2 years ago
parent 9905d748d9
commit f429d3da45

@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"time"
"github.com/shamaton/msgpackgen/msgpack" "github.com/shamaton/msgpackgen/msgpack"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -32,185 +33,186 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/messages"
) )
// MsgPackHandler is a Handler implementation for the msgpack serialization format. const readCommandTimeOut = 5 * time.Second
var errReadCommandTimeout = errors.New("read command timeout expired")
// MsgPackHandler is a ServerHandler implementation for the msgpack serialization format.
type MsgPackHandler struct { type MsgPackHandler struct {
logger *logrus.Logger logger *logrus.Logger
healthHandler *health.Handler healthHandler *health.Handler
fetchCRLHandler *revoking.FetchCRLHandler fetchCRLHandler *revoking.FetchCRLHandler
currentCommand *protocol.Command
currentResponse *protocol.Response
lock sync.Mutex lock sync.Mutex
} }
func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) error { func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command, error) {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()
frame := <-frames
var ann messages.CommandAnnounce var ann messages.CommandAnnounce
if err := msgpack.Unmarshal(frame, &ann); err != nil { if err := msgpack.Unmarshal(frame, &ann); err != nil {
return fmt.Errorf("could not unmarshal command announcement: %w", err) return nil, fmt.Errorf("could not unmarshal command announcement: %w", err)
} }
m.logger.WithField("announcement", &ann).Info("received command announcement") m.logger.WithField("announcement", &ann).Info("received command announcement")
m.currentCommand = &protocol.Command{Announce: &ann} return &protocol.Command{Announce: &ann}, nil
return nil
} }
func (m *MsgPackHandler) HandleCommand(frame []byte) error { func (m *MsgPackHandler) CommandData(frames chan []byte, command *protocol.Command) error {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()
err := m.parseCommand(frame) select {
if err != nil { case frame := <-frames:
m.currentResponse = m.buildErrorResponse(err.Error()) err := m.parseCommand(frame, command)
if err != nil {
m.logCommandResponse() return err
}
return nil return nil
case <-time.After(readCommandTimeOut):
return errReadCommandTimeout
} }
}
err = m.handleCommand() func (m *MsgPackHandler) HandleCommand(command *protocol.Command) (*protocol.Response, error) {
m.lock.Lock()
defer m.lock.Unlock()
var (
response *protocol.Response
err error
)
response, err = m.handleCommand(command)
if err != nil { if err != nil {
m.logger.WithError(err).Error("command handling failed") m.logger.WithError(err).Error("command handling failed")
return err response = m.buildErrorResponse(command.Announce.ID, "command handling failed")
} }
m.logCommandResponse() m.logCommandResponse(command, response)
m.currentCommand = nil return response, nil
return nil
} }
func (m *MsgPackHandler) logCommandResponse() { func (m *MsgPackHandler) logCommandResponse(command *protocol.Command, response *protocol.Response) {
m.logger.WithField("command", m.currentCommand.Announce).Info("handled command") m.logger.WithField("command", command.Announce).Info("handled command")
m.logger.WithField( m.logger.WithField("command", command).WithField("response", response).Debug("command and response")
"command",
m.currentCommand,
).WithField(
"response",
m.currentResponse,
).Debug("command and response")
} }
func (m *MsgPackHandler) ResponseAnnounce() ([]byte, error) { func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) error {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()
announceData, err := msgpack.Marshal(m.currentResponse.Announce) announce, err := msgpack.Marshal(response)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not marshal response announcement: %w", err) return fmt.Errorf("could not marshal response announcement: %w", err)
} }
m.logger.WithField("announcement", m.currentResponse.Announce).Debug("write response announcement") m.logger.WithField("length", len(announce)).Debug("write response announcement")
return announceData, nil
}
func (m *MsgPackHandler) ResponseData() ([]byte, error) { out <- announce
m.lock.Lock()
defer m.lock.Unlock()
responseData, err := msgpack.Marshal(m.currentResponse.Response) data, err := msgpack.Marshal(response.Response)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not marshal response: %w", err) return fmt.Errorf("could not marshal response: %w", err)
} }
m.logger.WithField("response", m.currentResponse.Response).Debug("write response") m.logger.WithField("length", len(data)).Debug("write response")
return responseData, nil out <- announce
return nil
} }
func (m *MsgPackHandler) parseHealthCommand(frame []byte) error { func (m *MsgPackHandler) parseHealthCommand(frame []byte) (*messages.HealthCommand, error) {
var command messages.HealthCommand var command messages.HealthCommand
if err := msgpack.Unmarshal(frame, &command); err != nil { if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed") m.logger.WithError(err).Error("unmarshal failed")
return errors.New("could not unmarshal health command") return nil, errors.New("could not unmarshal health command")
} }
m.currentCommand.Command = &command return &command, nil
return nil
} }
func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) error { func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) (*messages.FetchCRLCommand, error) {
var command messages.FetchCRLCommand var command messages.FetchCRLCommand
if err := msgpack.Unmarshal(frame, &command); err != nil { if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed") m.logger.WithError(err).Error("unmarshal failed")
return errors.New("could not unmarshal fetch crl command") return nil, errors.New("could not unmarshal fetch crl command")
} }
m.currentCommand.Command = &command return &command, nil
return nil
}
func (m *MsgPackHandler) currentID() string {
return m.currentCommand.Announce.ID
} }
func (m *MsgPackHandler) handleCommand() error { func (m *MsgPackHandler) handleCommand(command *protocol.Command) (*protocol.Response, error) {
var ( var (
err error
responseData interface{}
responseCode messages.ResponseCode responseCode messages.ResponseCode
responseData interface{}
) )
switch m.currentCommand.Command.(type) { switch cmd := command.Command.(type) {
case *messages.HealthCommand: case *messages.HealthCommand:
response, err := m.handleHealthCommand() response, err := m.handleHealthCommand()
if err != nil { if err != nil {
return err return nil, err
} }
responseCode, responseData = messages.RespHealth, response responseCode, responseData = messages.RespHealth, response
case *messages.FetchCRLCommand: case *messages.FetchCRLCommand:
response, err := m.handleFetchCRLCommand() response, err := m.handleFetchCRLCommand(cmd)
if err != nil { if err != nil {
return err return nil, err
} }
responseCode, responseData = messages.RespFetchCRL, response responseCode, responseData = messages.RespFetchCRL, response
default: default:
return fmt.Errorf("unhandled command %s", m.currentCommand.Announce) return nil, fmt.Errorf("unhandled command %s", command.Announce)
} }
if err != nil { return &protocol.Response{
return fmt.Errorf("error from command handler: %w", err) Announce: messages.BuildResponseAnnounce(responseCode, command.Announce.ID),
}
m.currentResponse = &protocol.Response{
Announce: messages.BuildResponseAnnounce(responseCode, m.currentID()),
Response: responseData, Response: responseData,
} }, nil
return nil
} }
func (m *MsgPackHandler) buildErrorResponse(errMsg string) *protocol.Response { func (m *MsgPackHandler) buildErrorResponse(commandID string, errMsg string) *protocol.Response {
return &protocol.Response{ return &protocol.Response{
Announce: messages.BuildResponseAnnounce(messages.RespError, m.currentID()), Announce: messages.BuildResponseAnnounce(messages.RespError, commandID),
Response: &messages.ErrorResponse{Message: errMsg}, Response: &messages.ErrorResponse{Message: errMsg},
} }
} }
func (m *MsgPackHandler) parseCommand(frame []byte) error { func (m *MsgPackHandler) parseCommand(frame []byte, command *protocol.Command) error {
switch m.currentCommand.Announce.Code { switch command.Announce.Code {
case messages.CmdHealth: case messages.CmdHealth:
return m.parseHealthCommand(frame) healthCommand, err := m.parseHealthCommand(frame)
if err != nil {
return err
}
command.Command = healthCommand
case messages.CmdFetchCRL: case messages.CmdFetchCRL:
return m.parseFetchCRLCommand(frame) fetchCRLCommand, err := m.parseFetchCRLCommand(frame)
if err != nil {
return err
}
command.Command = fetchCRLCommand
default: default:
return fmt.Errorf("unhandled command code %s", m.currentCommand.Announce.Code) return fmt.Errorf("unhandled command code %s", command.Announce.Code)
} }
return nil
} }
func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error) { func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error) {
@ -235,27 +237,20 @@ func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error)
return response, nil return response, nil
} }
func (m *MsgPackHandler) handleFetchCRLCommand() (*messages.FetchCRLResponse, error) { func (m *MsgPackHandler) handleFetchCRLCommand(command *messages.FetchCRLCommand) (*messages.FetchCRLResponse, error) {
fetchCRLPayload, ok := m.currentCommand.Command.(*messages.FetchCRLCommand) res, err := m.fetchCRLHandler.FetchCRL(command.IssuerID)
if !ok {
return nil, fmt.Errorf("could not use payload as FetchCRLPayload")
}
res, err := m.fetchCRLHandler.FetchCRL(fetchCRLPayload.IssuerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not fetch CRL: %w", err) return nil, fmt.Errorf("could not fetch CRL: %w", err)
} }
response := &messages.FetchCRLResponse{ return &messages.FetchCRLResponse{
IsDelta: false, IsDelta: false,
CRLNumber: res.Number, CRLNumber: res.Number,
CRLData: res.CRLData, CRLData: res.CRLData,
} }, nil
return response, nil
} }
func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.Handler, error) { func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.ServerHandler, error) {
messages.RegisterGeneratedResolver() messages.RegisterGeneratedResolver()
h := &MsgPackHandler{ h := &MsgPackHandler{

@ -20,8 +20,8 @@ package serial
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tarm/serial" "github.com/tarm/serial"
@ -35,22 +35,15 @@ type protocolState int8
const ( const (
cmdAnnounce protocolState = iota cmdAnnounce protocolState = iota
cmdData cmdData
respAnnounce handleCommand
respData respond
) )
var validTransitions = map[protocolState]protocolState{
cmdAnnounce: cmdData,
cmdData: respAnnounce,
respAnnounce: respData,
respData: cmdAnnounce,
}
var protocolStateNames = map[protocolState]string{ var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE", cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA", cmdData: "CMD DATA",
respAnnounce: "RESP ANNOUNCE", handleCommand: "RESP ANNOUNCE",
respData: "RESP DATA", respond: "RESP DATA",
} }
func (p protocolState) String() string { func (p protocolState) String() string {
@ -62,13 +55,12 @@ func (p protocolState) String() string {
} }
type Handler struct { type Handler struct {
protocolHandler protocol.Handler protocolHandler protocol.ServerHandler
protocolState protocolState protocolState protocolState
framer protocol.Framer framer protocol.Framer
config *serial.Config config *serial.Config
port *serial.Port port *serial.Port
logger *logrus.Logger logger *logrus.Logger
lock sync.Mutex
framesIn chan []byte framesIn chan []byte
framesOut chan []byte framesOut chan []byte
} }
@ -95,152 +87,129 @@ func (h *Handler) Close() error {
func (h *Handler) Run(ctx context.Context) error { func (h *Handler) Run(ctx context.Context) error {
h.protocolState = cmdAnnounce h.protocolState = cmdAnnounce
errors := make(chan error) protocolErrors, framerErrors := make(chan error), make(chan error)
go func() { go func() {
err := h.framer.ReadFrames(h.port, h.framesIn) err := h.framer.ReadFrames(h.port, h.framesIn)
errors <- err framerErrors <- err
}() }()
go func() { go func() {
err := h.framer.WriteFrames(h.port, h.framesOut) err := h.framer.WriteFrames(h.port, h.framesOut)
errors <- err framerErrors <- err
}()
go func() {
err := h.handleProtocolState()
protocolErrors <- err
}() }()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case err := <-framerErrors:
case err := <-errors:
if err != nil { if err != nil {
return fmt.Errorf("error from handler loop: %w", err) return fmt.Errorf("error from framer: %w", err)
} }
return nil return nil
case err := <-protocolErrors:
default: if err != nil {
if err := h.handleProtocolState(); err != nil { return fmt.Errorf("error from protocol handler: %w", err)
return err
} }
return nil
} }
} }
} }
func (h *Handler) nextState() error { var errCommandExpected = errors.New("command must not be nil")
next, ok := validTransitions[h.protocolState] var errResponseExpected = errors.New("response must not be nil")
if !ok {
return fmt.Errorf("illegal protocol state %s", h.protocolState)
}
h.protocolState = next
return nil
}
func (h *Handler) handleProtocolState() error { func (h *Handler) handleProtocolState() error {
h.logger.Tracef("handling protocol state %s", h.protocolState) var (
command *protocol.Command
h.lock.Lock() response *protocol.Response
defer h.lock.Unlock() err error
)
switch h.protocolState {
case cmdAnnounce:
if err := h.handleCmdAnnounce(); err != nil {
return err
}
case cmdData:
if err := h.handleCmdData(); err != nil {
return err
}
case respAnnounce:
if err := h.handleRespAnnounce(); err != nil {
return err
}
case respData:
if err := h.handleRespData(); err != nil {
return err
}
default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
}
return nil for {
} h.logger.Debugf("handling protocol state %s", h.protocolState)
func (h *Handler) handleCmdAnnounce() error {
h.logger.Trace("waiting for command announce")
frame := <-h.framesIn
if frame == nil { switch h.protocolState {
return nil case cmdAnnounce:
} command, err = h.protocolHandler.CommandAnnounce(h.framesIn)
if err != nil {
h.logger.WithError(err).Error("could not handle command announce")
if err := h.protocolHandler.HandleCommandAnnounce(frame); err != nil { break
return fmt.Errorf("command announce handling failed: %w", err) }
}
if err := h.nextState(); err != nil { h.protocolState = cmdData
return err case cmdData:
} if command == nil {
return errCommandExpected
}
return nil err = h.protocolHandler.CommandData(h.framesIn, command)
} if err != nil {
h.logger.WithError(err).Error("could not handle command data")
func (h *Handler) handleCmdData() error { h.protocolState = cmdAnnounce
h.logger.Trace("waiting for command data")
frame := <-h.framesIn break
}
if frame == nil { h.protocolState = handleCommand
return nil case handleCommand:
} if command == nil {
return errCommandExpected
}
if err := h.protocolHandler.HandleCommand(frame); err != nil { response, err = h.protocolHandler.HandleCommand(command)
return fmt.Errorf("command handler failed: %w", err) if err != nil {
} h.logger.WithError(err).Error("could not handle command")
if err := h.nextState(); err != nil { h.protocolState = cmdAnnounce
return err
}
return nil break
} }
func (h *Handler) handleRespAnnounce() error { command = nil
frame, err := h.protocolHandler.ResponseAnnounce()
if err != nil {
return fmt.Errorf("could not get response announcement: %w", err)
}
h.framesOut <- frame h.protocolState = respond
case respond:
if response == nil {
return errResponseExpected
}
if err := h.nextState(); err != nil { err = h.protocolHandler.Respond(response, h.framesOut)
return err if err != nil {
} h.logger.WithError(err).Error("could not respond")
return nil h.protocolState = cmdAnnounce
}
func (h *Handler) handleRespData() error { break
frame, err := h.protocolHandler.ResponseData() }
if err != nil {
return fmt.Errorf("could not get response data: %w", err)
}
h.framesOut <- frame response = nil
if err := h.nextState(); err != nil { h.protocolState = cmdAnnounce
return err default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
}
} }
return nil
} }
func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Handler) (*Handler, error) { func New(
cfg *config.Serial,
logger *logrus.Logger,
protocolHandler protocol.ServerHandler,
) (*Handler, error) {
h := &Handler{ h := &Handler{
protocolHandler: protocolHandler, protocolHandler: protocolHandler,
logger: logger, logger: logger,

@ -51,16 +51,16 @@ func (r *Response) String() string {
return fmt.Sprintf("Rsp[announce={%s}, data={%s}]", r.Announce, r.Response) return fmt.Sprintf("Rsp[announce={%s}, data={%s}]", r.Announce, r.Response)
} }
// Handler is responsible for parsing incoming frames and calling commands // ServerHandler is responsible for parsing incoming frames and calling commands
type Handler interface { type ServerHandler interface {
// HandleCommandAnnounce handles the initial announcement of a command. // CommandAnnounce handles the initial announcement of a command.
HandleCommandAnnounce([]byte) error CommandAnnounce(chan []byte) (*Command, error)
// HandleCommand handles the command data. // CommandData handles the command data.
HandleCommand([]byte) error CommandData(chan []byte, *Command) error
// ResponseAnnounce generates the announcement for a response. // HandleCommand executes the command, generating a response.
ResponseAnnounce() ([]byte, error) HandleCommand(*Command) (*Response, error)
// ResponseData generates the response data. // Respond generates the response for a command.
ResponseData() ([]byte, error) Respond(*Response, chan []byte) error
} }
// Framer handles bytes on the wire by adding or removing framing information. // Framer handles bytes on the wire by adding or removing framing information.

Loading…
Cancel
Save