Refactor server handler

- rename protocols.Handler to ServerHandler
- rename ServerHandler methods to better express their purpose
- pass command and response as parameters
- simplify state machine and handle errors in serial/seriallink.go
- implement command read timeout
- remove currentCommand and currentResponse fields from MsgPackHandler
main
Jan Dittberner 1 year ago
parent 9905d748d9
commit f429d3da45

@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"sync"
"time"
"github.com/shamaton/msgpackgen/msgpack"
"github.com/sirupsen/logrus"
@ -32,185 +33,186 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages"
)
// MsgPackHandler is a Handler implementation for the msgpack serialization format.
const readCommandTimeOut = 5 * time.Second
var errReadCommandTimeout = errors.New("read command timeout expired")
// MsgPackHandler is a ServerHandler implementation for the msgpack serialization format.
type MsgPackHandler struct {
logger *logrus.Logger
healthHandler *health.Handler
fetchCRLHandler *revoking.FetchCRLHandler
currentCommand *protocol.Command
currentResponse *protocol.Response
lock sync.Mutex
}
func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) error {
func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command, error) {
m.lock.Lock()
defer m.lock.Unlock()
frame := <-frames
var ann messages.CommandAnnounce
if err := msgpack.Unmarshal(frame, &ann); err != nil {
return fmt.Errorf("could not unmarshal command announcement: %w", err)
return nil, fmt.Errorf("could not unmarshal command announcement: %w", err)
}
m.logger.WithField("announcement", &ann).Info("received command announcement")
m.currentCommand = &protocol.Command{Announce: &ann}
return nil
return &protocol.Command{Announce: &ann}, nil
}
func (m *MsgPackHandler) HandleCommand(frame []byte) error {
func (m *MsgPackHandler) CommandData(frames chan []byte, command *protocol.Command) error {
m.lock.Lock()
defer m.lock.Unlock()
err := m.parseCommand(frame)
if err != nil {
m.currentResponse = m.buildErrorResponse(err.Error())
m.logCommandResponse()
select {
case frame := <-frames:
err := m.parseCommand(frame, command)
if err != nil {
return err
}
return nil
case <-time.After(readCommandTimeOut):
return errReadCommandTimeout
}
}
err = m.handleCommand()
func (m *MsgPackHandler) HandleCommand(command *protocol.Command) (*protocol.Response, error) {
m.lock.Lock()
defer m.lock.Unlock()
var (
response *protocol.Response
err error
)
response, err = m.handleCommand(command)
if err != nil {
m.logger.WithError(err).Error("command handling failed")
return err
response = m.buildErrorResponse(command.Announce.ID, "command handling failed")
}
m.logCommandResponse()
m.logCommandResponse(command, response)
m.currentCommand = nil
return nil
return response, nil
}
func (m *MsgPackHandler) logCommandResponse() {
m.logger.WithField("command", m.currentCommand.Announce).Info("handled command")
m.logger.WithField(
"command",
m.currentCommand,
).WithField(
"response",
m.currentResponse,
).Debug("command and response")
func (m *MsgPackHandler) logCommandResponse(command *protocol.Command, response *protocol.Response) {
m.logger.WithField("command", command.Announce).Info("handled command")
m.logger.WithField("command", command).WithField("response", response).Debug("command and response")
}
func (m *MsgPackHandler) ResponseAnnounce() ([]byte, error) {
func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) error {
m.lock.Lock()
defer m.lock.Unlock()
announceData, err := msgpack.Marshal(m.currentResponse.Announce)
announce, err := msgpack.Marshal(response)
if err != nil {
return nil, fmt.Errorf("could not marshal response announcement: %w", err)
return fmt.Errorf("could not marshal response announcement: %w", err)
}
m.logger.WithField("announcement", m.currentResponse.Announce).Debug("write response announcement")
return announceData, nil
}
m.logger.WithField("length", len(announce)).Debug("write response announcement")
func (m *MsgPackHandler) ResponseData() ([]byte, error) {
m.lock.Lock()
defer m.lock.Unlock()
out <- announce
responseData, err := msgpack.Marshal(m.currentResponse.Response)
data, err := msgpack.Marshal(response.Response)
if err != nil {
return nil, fmt.Errorf("could not marshal response: %w", err)
return fmt.Errorf("could not marshal response: %w", err)
}
m.logger.WithField("response", m.currentResponse.Response).Debug("write response")
m.logger.WithField("length", len(data)).Debug("write response")
return responseData, nil
out <- announce
return nil
}
func (m *MsgPackHandler) parseHealthCommand(frame []byte) error {
func (m *MsgPackHandler) parseHealthCommand(frame []byte) (*messages.HealthCommand, error) {
var command messages.HealthCommand
if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed")
return errors.New("could not unmarshal health command")
return nil, errors.New("could not unmarshal health command")
}
m.currentCommand.Command = &command
return nil
return &command, nil
}
func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) error {
func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) (*messages.FetchCRLCommand, error) {
var command messages.FetchCRLCommand
if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed")
return errors.New("could not unmarshal fetch crl command")
return nil, errors.New("could not unmarshal fetch crl command")
}
m.currentCommand.Command = &command
return nil
}
func (m *MsgPackHandler) currentID() string {
return m.currentCommand.Announce.ID
return &command, nil
}
func (m *MsgPackHandler) handleCommand() error {
func (m *MsgPackHandler) handleCommand(command *protocol.Command) (*protocol.Response, error) {
var (
err error
responseData interface{}
responseCode messages.ResponseCode
responseData interface{}
)
switch m.currentCommand.Command.(type) {
switch cmd := command.Command.(type) {
case *messages.HealthCommand:
response, err := m.handleHealthCommand()
if err != nil {
return err
return nil, err
}
responseCode, responseData = messages.RespHealth, response
case *messages.FetchCRLCommand:
response, err := m.handleFetchCRLCommand()
response, err := m.handleFetchCRLCommand(cmd)
if err != nil {
return err
return nil, err
}
responseCode, responseData = messages.RespFetchCRL, response
default:
return fmt.Errorf("unhandled command %s", m.currentCommand.Announce)
return nil, fmt.Errorf("unhandled command %s", command.Announce)
}
if err != nil {
return fmt.Errorf("error from command handler: %w", err)
}
m.currentResponse = &protocol.Response{
Announce: messages.BuildResponseAnnounce(responseCode, m.currentID()),
return &protocol.Response{
Announce: messages.BuildResponseAnnounce(responseCode, command.Announce.ID),
Response: responseData,
}
return nil
}, nil
}
func (m *MsgPackHandler) buildErrorResponse(errMsg string) *protocol.Response {
func (m *MsgPackHandler) buildErrorResponse(commandID string, errMsg string) *protocol.Response {
return &protocol.Response{
Announce: messages.BuildResponseAnnounce(messages.RespError, m.currentID()),
Announce: messages.BuildResponseAnnounce(messages.RespError, commandID),
Response: &messages.ErrorResponse{Message: errMsg},
}
}
func (m *MsgPackHandler) parseCommand(frame []byte) error {
switch m.currentCommand.Announce.Code {
func (m *MsgPackHandler) parseCommand(frame []byte, command *protocol.Command) error {
switch command.Announce.Code {
case messages.CmdHealth:
return m.parseHealthCommand(frame)
healthCommand, err := m.parseHealthCommand(frame)
if err != nil {
return err
}
command.Command = healthCommand
case messages.CmdFetchCRL:
return m.parseFetchCRLCommand(frame)
fetchCRLCommand, err := m.parseFetchCRLCommand(frame)
if err != nil {
return err
}
command.Command = fetchCRLCommand
default:
return fmt.Errorf("unhandled command code %s", m.currentCommand.Announce.Code)
return fmt.Errorf("unhandled command code %s", command.Announce.Code)
}
return nil
}
func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error) {
@ -235,27 +237,20 @@ func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error)
return response, nil
}
func (m *MsgPackHandler) handleFetchCRLCommand() (*messages.FetchCRLResponse, error) {
fetchCRLPayload, ok := m.currentCommand.Command.(*messages.FetchCRLCommand)
if !ok {
return nil, fmt.Errorf("could not use payload as FetchCRLPayload")
}
res, err := m.fetchCRLHandler.FetchCRL(fetchCRLPayload.IssuerID)
func (m *MsgPackHandler) handleFetchCRLCommand(command *messages.FetchCRLCommand) (*messages.FetchCRLResponse, error) {
res, err := m.fetchCRLHandler.FetchCRL(command.IssuerID)
if err != nil {
return nil, fmt.Errorf("could not fetch CRL: %w", err)
}
response := &messages.FetchCRLResponse{
return &messages.FetchCRLResponse{
IsDelta: false,
CRLNumber: res.Number,
CRLData: res.CRLData,
}
return response, nil
}, nil
}
func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.Handler, error) {
func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.ServerHandler, error) {
messages.RegisterGeneratedResolver()
h := &MsgPackHandler{

@ -20,8 +20,8 @@ package serial
import (
"context"
"errors"
"fmt"
"sync"
"github.com/sirupsen/logrus"
"github.com/tarm/serial"
@ -35,22 +35,15 @@ type protocolState int8
const (
cmdAnnounce protocolState = iota
cmdData
respAnnounce
respData
handleCommand
respond
)
var validTransitions = map[protocolState]protocolState{
cmdAnnounce: cmdData,
cmdData: respAnnounce,
respAnnounce: respData,
respData: cmdAnnounce,
}
var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
respAnnounce: "RESP ANNOUNCE",
respData: "RESP DATA",
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
handleCommand: "RESP ANNOUNCE",
respond: "RESP DATA",
}
func (p protocolState) String() string {
@ -62,13 +55,12 @@ func (p protocolState) String() string {
}
type Handler struct {
protocolHandler protocol.Handler
protocolHandler protocol.ServerHandler
protocolState protocolState
framer protocol.Framer
config *serial.Config
port *serial.Port
logger *logrus.Logger
lock sync.Mutex
framesIn chan []byte
framesOut chan []byte
}
@ -95,152 +87,129 @@ func (h *Handler) Close() error {
func (h *Handler) Run(ctx context.Context) error {
h.protocolState = cmdAnnounce
errors := make(chan error)
protocolErrors, framerErrors := make(chan error), make(chan error)
go func() {
err := h.framer.ReadFrames(h.port, h.framesIn)
errors <- err
framerErrors <- err
}()
go func() {
err := h.framer.WriteFrames(h.port, h.framesOut)
errors <- err
framerErrors <- err
}()
go func() {
err := h.handleProtocolState()
protocolErrors <- err
}()
for {
select {
case <-ctx.Done():
return nil
case err := <-errors:
case err := <-framerErrors:
if err != nil {
return fmt.Errorf("error from handler loop: %w", err)
return fmt.Errorf("error from framer: %w", err)
}
return nil
default:
if err := h.handleProtocolState(); err != nil {
return err
case err := <-protocolErrors:
if err != nil {
return fmt.Errorf("error from protocol handler: %w", err)
}
return nil
}
}
}
func (h *Handler) nextState() error {
next, ok := validTransitions[h.protocolState]
if !ok {
return fmt.Errorf("illegal protocol state %s", h.protocolState)
}
h.protocolState = next
return nil
}
var errCommandExpected = errors.New("command must not be nil")
var errResponseExpected = errors.New("response must not be nil")
func (h *Handler) handleProtocolState() error {
h.logger.Tracef("handling protocol state %s", h.protocolState)
h.lock.Lock()
defer h.lock.Unlock()
switch h.protocolState {
case cmdAnnounce:
if err := h.handleCmdAnnounce(); err != nil {
return err
}
case cmdData:
if err := h.handleCmdData(); err != nil {
return err
}
case respAnnounce:
if err := h.handleRespAnnounce(); err != nil {
return err
}
case respData:
if err := h.handleRespData(); err != nil {
return err
}
default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
}
var (
command *protocol.Command
response *protocol.Response
err error
)
return nil
}
func (h *Handler) handleCmdAnnounce() error {
h.logger.Trace("waiting for command announce")
frame := <-h.framesIn
for {
h.logger.Debugf("handling protocol state %s", h.protocolState)
if frame == nil {
return nil
}
switch h.protocolState {
case cmdAnnounce:
command, err = h.protocolHandler.CommandAnnounce(h.framesIn)
if err != nil {
h.logger.WithError(err).Error("could not handle command announce")
if err := h.protocolHandler.HandleCommandAnnounce(frame); err != nil {
return fmt.Errorf("command announce handling failed: %w", err)
}
break
}
if err := h.nextState(); err != nil {
return err
}
h.protocolState = cmdData
case cmdData:
if command == nil {
return errCommandExpected
}
return nil
}
err = h.protocolHandler.CommandData(h.framesIn, command)
if err != nil {
h.logger.WithError(err).Error("could not handle command data")
func (h *Handler) handleCmdData() error {
h.logger.Trace("waiting for command data")
h.protocolState = cmdAnnounce
frame := <-h.framesIn
break
}
if frame == nil {
return nil
}
h.protocolState = handleCommand
case handleCommand:
if command == nil {
return errCommandExpected
}
if err := h.protocolHandler.HandleCommand(frame); err != nil {
return fmt.Errorf("command handler failed: %w", err)
}
response, err = h.protocolHandler.HandleCommand(command)
if err != nil {
h.logger.WithError(err).Error("could not handle command")
if err := h.nextState(); err != nil {
return err
}
h.protocolState = cmdAnnounce
return nil
}
break
}
func (h *Handler) handleRespAnnounce() error {
frame, err := h.protocolHandler.ResponseAnnounce()
if err != nil {
return fmt.Errorf("could not get response announcement: %w", err)
}
command = nil
h.framesOut <- frame
h.protocolState = respond
case respond:
if response == nil {
return errResponseExpected
}
if err := h.nextState(); err != nil {
return err
}
err = h.protocolHandler.Respond(response, h.framesOut)
if err != nil {
h.logger.WithError(err).Error("could not respond")
return nil
}
h.protocolState = cmdAnnounce
func (h *Handler) handleRespData() error {
frame, err := h.protocolHandler.ResponseData()
if err != nil {
return fmt.Errorf("could not get response data: %w", err)
}
break
}
h.framesOut <- frame
response = nil
if err := h.nextState(); err != nil {
return err
h.protocolState = cmdAnnounce
default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
}
}
return nil
}
func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Handler) (*Handler, error) {
func New(
cfg *config.Serial,
logger *logrus.Logger,
protocolHandler protocol.ServerHandler,
) (*Handler, error) {
h := &Handler{
protocolHandler: protocolHandler,
logger: logger,

@ -51,16 +51,16 @@ func (r *Response) String() string {
return fmt.Sprintf("Rsp[announce={%s}, data={%s}]", r.Announce, r.Response)
}
// Handler is responsible for parsing incoming frames and calling commands
type Handler interface {
// HandleCommandAnnounce handles the initial announcement of a command.
HandleCommandAnnounce([]byte) error
// HandleCommand handles the command data.
HandleCommand([]byte) error
// ResponseAnnounce generates the announcement for a response.
ResponseAnnounce() ([]byte, error)
// ResponseData generates the response data.
ResponseData() ([]byte, error)
// ServerHandler is responsible for parsing incoming frames and calling commands
type ServerHandler interface {
// CommandAnnounce handles the initial announcement of a command.
CommandAnnounce(chan []byte) (*Command, error)
// CommandData handles the command data.
CommandData(chan []byte, *Command) error
// HandleCommand executes the command, generating a response.
HandleCommand(*Command) (*Response, error)
// Respond generates the response for a command.
Respond(*Response, chan []byte) error
}
// Framer handles bytes on the wire by adding or removing framing information.

Loading…
Cancel
Save