Refactor client protocol

- define protocols.ClientHandler interface as base for client implementations
- implement protocols.ClientHandler in clientsim's ClientHandler type
- move protocol state handling into protocols.ServerProtocol and
  protocols.ClientProtocol
- move protocolState type into protocols.go
- reduce clientsim's TestCommandGenerator responsibility to test command
  generation
main
Jan Dittberner 2 years ago
parent f429d3da45
commit af40662c7d

@ -25,7 +25,6 @@ import (
"fmt"
"io"
"os"
"sync"
"time"
"github.com/shamaton/msgpackgen/msgpack"
@ -36,124 +35,9 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages"
)
type protocolState int8
const (
cmdAnnounce protocolState = iota
cmdData
respAnnounce
respData
)
var validTransitions = map[protocolState]protocolState{
cmdAnnounce: cmdData,
cmdData: respAnnounce,
respAnnounce: respData,
respData: cmdAnnounce,
}
var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
respAnnounce: "RESP ANNOUNCE",
respData: "RESP DATA",
}
func (p protocolState) String() string {
if name, ok := protocolStateNames[p]; ok {
return name
}
return fmt.Sprintf("unknown %d", p)
}
type TestCommandGenerator struct {
logger *logrus.Logger
currentCommand *protocol.Command
currentResponse *protocol.Response
commands chan *protocol.Command
lock sync.Mutex
}
func (g *TestCommandGenerator) CmdAnnouncement() ([]byte, error) {
g.lock.Lock()
defer g.lock.Unlock()
g.currentCommand = <-g.commands
announceData, err := msgpack.Marshal(g.currentCommand.Announce)
if err != nil {
return nil, fmt.Errorf("could not marshal command annoucement: %w", err)
}
g.logger.WithField("announcement", &g.currentCommand.Announce).Info("write command announcement")
return announceData, nil
}
func (g *TestCommandGenerator) CmdData() ([]byte, error) {
g.lock.Lock()
defer g.lock.Unlock()
cmdData, err := msgpack.Marshal(g.currentCommand.Command)
if err != nil {
return nil, fmt.Errorf("could not marshal command data: %w", err)
}
g.logger.WithField("command", &g.currentCommand.Command).Info("write command data")
return cmdData, nil
}
func (g *TestCommandGenerator) HandleResponseAnnounce(frame []byte) error {
g.lock.Lock()
defer g.lock.Unlock()
var ann messages.ResponseAnnounce
if err := msgpack.Unmarshal(frame, &ann); err != nil {
return fmt.Errorf("could not unmarshal response announcement")
}
g.logger.WithField("announcement", &ann).Info("received response announcement")
g.currentResponse = &protocol.Response{Announce: &ann}
return nil
}
func (g *TestCommandGenerator) HandleResponse(frame []byte) error {
g.lock.Lock()
defer g.lock.Unlock()
switch g.currentResponse.Announce.Code {
case messages.RespHealth:
var response messages.HealthResponse
if err := msgpack.Unmarshal(frame, &response); err != nil {
return fmt.Errorf("unmarshal failed: %w", err)
}
g.currentResponse.Response = &response
case messages.RespFetchCRL:
var response messages.FetchCRLResponse
if err := msgpack.Unmarshal(frame, &response); err != nil {
return fmt.Errorf("unmarshal failed: %w", err)
}
g.currentResponse.Response = &response
}
g.logger.WithField(
"command",
g.currentCommand,
).WithField(
"response",
g.currentResponse,
).Info("handled health response")
return nil
logger *logrus.Logger
commands chan *protocol.Command
}
func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
@ -200,7 +84,11 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
case <-ctx.Done():
_ = healthTimer.Stop()
g.logger.Info("stopping health check loop")
g.logger.Info("stopped health check loop")
_ = crlTimer.Stop()
g.logger.Info("stopped CRL fetch loop")
return nil
case <-healthTimer.C:
@ -225,187 +113,169 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
}
type clientSimulator struct {
protocolState protocolState
logger *logrus.Logger
lock sync.Mutex
clientHandler protocol.ClientHandler
framesIn chan []byte
framesOut chan []byte
framer protocol.Framer
commandGenerator *TestCommandGenerator
logger *logrus.Logger
}
func (c *clientSimulator) writeCmdAnnouncement() error {
frame, err := c.commandGenerator.CmdAnnouncement()
if err != nil {
return fmt.Errorf("could not get command annoucement: %w", err)
}
const (
responseAnnounceTimeout = 30 * time.Second
responseDataTimeout = 2 * time.Second
)
c.logger.Trace("writing command announcement")
func (c *clientSimulator) Run(ctx context.Context) error {
framerErrors := make(chan error)
protocolErrors := make(chan error)
generatorErrors := make(chan error)
c.framesOut <- frame
go func() {
err := c.framer.ReadFrames(os.Stdin, c.framesIn)
if err := c.nextState(); err != nil {
return err
}
framerErrors <- err
}()
return nil
}
go func() {
err := c.framer.WriteFrames(os.Stdout, c.framesOut)
func (c *clientSimulator) writeCommand() error {
frame, err := c.commandGenerator.CmdData()
if err != nil {
return fmt.Errorf("could not get command data: %w", err)
}
framerErrors <- err
}()
c.logger.Trace("writing command data")
go func() {
clientProtocol := protocol.NewClient(c.clientHandler, c.commandGenerator.commands, c.framesIn, c.framesOut, c.logger)
c.framesOut <- frame
err := clientProtocol.Handle()
if err := c.nextState(); err != nil {
return err
}
protocolErrors <- err
}()
return nil
}
go func() {
err := c.commandGenerator.GenerateCommands(ctx)
const responseAnnounceTimeout = 30 * time.Second
const responseDataTimeout = 2 * time.Second
generatorErrors <- err
}()
func (c *clientSimulator) handleResponseAnnounce() error {
c.logger.Trace("waiting for response announce")
for {
select {
case <-ctx.Done():
return nil
case err := <-framerErrors:
if err != nil {
return fmt.Errorf("error from framer: %w", err)
}
select {
case frame := <-c.framesIn:
if frame == nil {
return nil
}
case err := <-generatorErrors:
if err != nil {
return fmt.Errorf("error from command generator: %w", err)
}
if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil {
return fmt.Errorf("response announce handling failed: %w", err)
}
return nil
case err := <-protocolErrors:
if err != nil {
return fmt.Errorf("error from protocol handler: %w", err)
}
if err := c.nextState(); err != nil {
return err
return nil
}
case <-time.After(responseAnnounceTimeout):
c.logger.Warn("response announce timeout expired")
c.protocolState = cmdAnnounce
return nil
}
return nil
}
func (c *clientSimulator) handleResponseData() error {
c.logger.Trace("waiting for response data")
type ClientHandler struct {
logger *logrus.Logger
}
select {
case frame := <-c.framesIn:
if frame == nil {
return nil
}
func (c ClientHandler) Send(command *protocol.Command, out chan []byte) error {
var (
frame []byte
err error
)
if err := c.commandGenerator.HandleResponse(frame); err != nil {
return fmt.Errorf("response handler failed: %w", err)
}
frame, err = msgpack.Marshal(command.Announce)
if err != nil {
return fmt.Errorf("could not marshal command annoucement: %w", err)
}
if err := c.nextState(); err != nil {
return err
}
c.logger.WithField("announcement", command.Announce).Info("write command announcement")
return nil
case <-time.After(responseDataTimeout):
c.logger.Warn("response data timeout expired")
c.logger.Trace("writing command announcement")
c.protocolState = cmdAnnounce
out <- frame
return nil
frame, err = msgpack.Marshal(command.Command)
if err != nil {
return fmt.Errorf("could not marshal command data: %w", err)
}
}
func (c *clientSimulator) Run(ctx context.Context) error {
c.protocolState = cmdAnnounce
errors := make(chan error)
c.logger.WithField("command", command.Command).Info("write command data")
go func() {
err := c.framer.ReadFrames(os.Stdin, c.framesIn)
out <- frame
errors <- err
}()
return nil
}
go func() {
err := c.framer.WriteFrames(os.Stdout, c.framesOut)
func (c ClientHandler) ResponseAnnounce(in chan []byte) (*protocol.Response, error) {
response := &protocol.Response{}
errors <- err
}()
var announce messages.ResponseAnnounce
go func() {
err := c.commandGenerator.GenerateCommands(ctx)
select {
case frame := <-in:
if err := msgpack.Unmarshal(frame, &announce); err != nil {
return nil, fmt.Errorf("could not unmarshal response announcement: %w", err)
}
errors <- err
}()
response.Announce = &announce
for {
select {
case <-ctx.Done():
return nil
case err := <-errors:
if err != nil {
return fmt.Errorf("error from handler loop: %w", err)
}
c.logger.WithField("announcement", response.Announce).Debug("received response announcement")
return nil
default:
if err := c.handleProtocolState(); err != nil {
return err
}
}
return response, nil
case <-time.After(responseAnnounceTimeout):
return nil, protocol.ErrResponseAnnounceTimeoutExpired
}
}
func (c *clientSimulator) handleProtocolState() error {
c.logger.Tracef("handling protocol state %s", c.protocolState)
func (c ClientHandler) ResponseData(in chan []byte, response *protocol.Response) error {
select {
case frame := <-in:
switch response.Announce.Code {
case messages.RespHealth:
var resp messages.HealthResponse
if err := msgpack.Unmarshal(frame, &resp); err != nil {
return fmt.Errorf("could not unmarshal health response data: %w", err)
}
c.lock.Lock()
defer c.lock.Unlock()
response.Response = &resp
case messages.RespFetchCRL:
var resp messages.FetchCRLResponse
if err := msgpack.Unmarshal(frame, &resp); err != nil {
return fmt.Errorf("could not unmarshal fetch CRL response data: %w", err)
}
switch c.protocolState {
case cmdAnnounce:
if err := c.writeCmdAnnouncement(); err != nil {
return err
}
case cmdData:
if err := c.writeCommand(); err != nil {
return err
}
case respAnnounce:
if err := c.handleResponseAnnounce(); err != nil {
return err
}
case respData:
if err := c.handleResponseData(); err != nil {
return err
response.Response = &resp
default:
return fmt.Errorf("unhandled response code %s", response.Announce.Code)
}
default:
return fmt.Errorf("unknown protocol state %s", c.protocolState)
case <-time.After(responseDataTimeout):
return protocol.ErrResponseDataTimeoutExpired
}
return nil
}
func (c *clientSimulator) nextState() error {
next, ok := validTransitions[c.protocolState]
if !ok {
return fmt.Errorf("illegal protocol state %s", c.protocolState)
}
c.protocolState = next
func (c ClientHandler) HandleResponse(response *protocol.Response) error {
c.logger.WithField("response", response.Announce).Info("handled response")
c.logger.WithField("response", response).Debug("full response")
return nil
}
func newClientHandler(logger *logrus.Logger) *ClientHandler {
return &ClientHandler{logger: logger}
}
func main() {
logger := logrus.New()
logger.SetOutput(os.Stderr)
@ -418,10 +288,11 @@ func main() {
logger: logger,
commands: make(chan *protocol.Command),
},
logger: logger,
framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
logger: logger,
framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
clientHandler: newClientHandler(logger),
}
err := sim.Run(context.Background())

@ -57,7 +57,11 @@ func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command,
return nil, fmt.Errorf("could not unmarshal command announcement: %w", err)
}
m.logger.WithField("announcement", &ann).Info("received command announcement")
if ann.Code == messages.CmdUndef {
return nil, fmt.Errorf("received undefined command announcement: %s", ann)
}
m.logger.WithField("announcement", &ann).Debug("received command announcement")
return &protocol.Command{Announce: &ann}, nil
}
@ -109,7 +113,7 @@ func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) e
m.lock.Lock()
defer m.lock.Unlock()
announce, err := msgpack.Marshal(response)
announce, err := msgpack.Marshal(response.Announce)
if err != nil {
return fmt.Errorf("could not marshal response announcement: %w", err)
}
@ -125,7 +129,7 @@ func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) e
m.logger.WithField("length", len(data)).Debug("write response")
out <- announce
out <- data
return nil
}

@ -20,7 +20,6 @@ package serial
import (
"context"
"errors"
"fmt"
"github.com/sirupsen/logrus"
@ -30,39 +29,14 @@ import (
"git.cacert.org/cacert-gosigner/pkg/protocol"
)
type protocolState int8
const (
cmdAnnounce protocolState = iota
cmdData
handleCommand
respond
)
var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
handleCommand: "RESP ANNOUNCE",
respond: "RESP DATA",
}
func (p protocolState) String() string {
if name, ok := protocolStateNames[p]; ok {
return name
}
return fmt.Sprintf("unknown %d", p)
}
type Handler struct {
protocolHandler protocol.ServerHandler
protocolState protocolState
framer protocol.Framer
config *serial.Config
port *serial.Port
logger *logrus.Logger
framesIn chan []byte
framesOut chan []byte
serverHandler protocol.ServerHandler
framer protocol.Framer
config *serial.Config
port *serial.Port
logger *logrus.Logger
framesIn chan []byte
framesOut chan []byte
}
func (h *Handler) setupConnection() error {
@ -86,7 +60,6 @@ func (h *Handler) Close() error {
}
func (h *Handler) Run(ctx context.Context) error {
h.protocolState = cmdAnnounce
protocolErrors, framerErrors := make(chan error), make(chan error)
go func() {
@ -102,7 +75,9 @@ func (h *Handler) Run(ctx context.Context) error {
}()
go func() {
err := h.handleProtocolState()
serverProtocol := protocol.NewServer(h.serverHandler, h.framesIn, h.framesOut, h.logger)
err := serverProtocol.Handle()
protocolErrors <- err
}()
@ -127,95 +102,17 @@ func (h *Handler) Run(ctx context.Context) error {
}
}
var errCommandExpected = errors.New("command must not be nil")
var errResponseExpected = errors.New("response must not be nil")
func (h *Handler) handleProtocolState() error {
var (
command *protocol.Command
response *protocol.Response
err error
)
for {
h.logger.Debugf("handling protocol state %s", h.protocolState)
switch h.protocolState {
case cmdAnnounce:
command, err = h.protocolHandler.CommandAnnounce(h.framesIn)
if err != nil {
h.logger.WithError(err).Error("could not handle command announce")
break
}
h.protocolState = cmdData
case cmdData:
if command == nil {
return errCommandExpected
}
err = h.protocolHandler.CommandData(h.framesIn, command)
if err != nil {
h.logger.WithError(err).Error("could not handle command data")
h.protocolState = cmdAnnounce
break
}
h.protocolState = handleCommand
case handleCommand:
if command == nil {
return errCommandExpected
}
response, err = h.protocolHandler.HandleCommand(command)
if err != nil {
h.logger.WithError(err).Error("could not handle command")
h.protocolState = cmdAnnounce
break
}
command = nil
h.protocolState = respond
case respond:
if response == nil {
return errResponseExpected
}
err = h.protocolHandler.Respond(response, h.framesOut)
if err != nil {
h.logger.WithError(err).Error("could not respond")
h.protocolState = cmdAnnounce
break
}
response = nil
h.protocolState = cmdAnnounce
default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
}
}
}
func New(
cfg *config.Serial,
logger *logrus.Logger,
protocolHandler protocol.ServerHandler,
) (*Handler, error) {
h := &Handler{
protocolHandler: protocolHandler,
logger: logger,
framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
serverHandler: protocolHandler,
logger: logger,
framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
}
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}

@ -37,11 +37,13 @@ import (
type CommandCode int8
const (
CmdHealth CommandCode = iota
CmdUndef CommandCode = iota
CmdHealth
CmdFetchCRL
)
var commandNames = map[CommandCode]string{
CmdUndef: "UNDEFINED",
CmdHealth: "HEALTH",
CmdFetchCRL: "FETCH CRL",
}
@ -57,13 +59,15 @@ func (c CommandCode) String() string {
type ResponseCode int8
const (
RespError ResponseCode = -1
RespHealth ResponseCode = iota
RespError ResponseCode = -1
RespUndef ResponseCode = iota
RespHealth
RespFetchCRL
)
var responseNames = map[ResponseCode]string{
RespError: "ERROR",
RespUndef: "UNDEFINED",
RespHealth: "HEALTH",
RespFetchCRL: "FETCH CRL",
}

@ -63,6 +63,139 @@ type ServerHandler interface {
Respond(*Response, chan []byte) error
}
type ClientHandler interface {
Send(*Command, chan []byte) error
ResponseAnnounce(chan []byte) (*Response, error)
ResponseData(chan []byte, *Response) error
HandleResponse(*Response) error
}
var (
errCommandExpected = errors.New("command must not be nil")
errResponseExpected = errors.New("response must not be nil")
ErrResponseAnnounceTimeoutExpired = errors.New("response announce timeout expired")
ErrResponseDataTimeoutExpired = errors.New("response data timeout expired")
)
type protocolState int8
const (
cmdAnnounce protocolState = iota
cmdData
handleCommand
respond
respAnnounce
respData
handleResponse
)
var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
handleCommand: "HANDLE CMD",
respond: "RESPOND",
respAnnounce: "RESP ANNOUNCE",
respData: "RESP DATA",
handleResponse: "HANDLE RESP",
}
func (p protocolState) String() string {
if name, ok := protocolStateNames[p]; ok {
return name
}
return fmt.Sprintf("unknown %d", p)
}
type ServerProtocol struct {
handler ServerHandler
in, out chan []byte
logger *logrus.Logger
state protocolState
}
func (p *ServerProtocol) Handle() error {
var (
command *Command
response *Response
err error
)
for {
p.logger.Debugf("handling protocol state %s", p.state)
switch p.state {
case cmdAnnounce:
command, err = p.handler.CommandAnnounce(p.in)
if err != nil {
p.logger.WithError(err).Error("could not handle command announce")
break
}
p.state = cmdData
case cmdData:
if command == nil {
return errCommandExpected
}
err = p.handler.CommandData(p.in, command)
if err != nil {
p.logger.WithError(err).Error("could not handle command data")
p.state = cmdAnnounce
break
}
p.state = handleCommand
case handleCommand:
if command == nil {
return errCommandExpected
}
response, err = p.handler.HandleCommand(command)
if err != nil {
p.logger.WithError(err).Error("could not handle command")
p.state = cmdAnnounce
break
}
p.state = respond
case respond:
if response == nil {
return errResponseExpected
}
err = p.handler.Respond(response, p.out)
if err != nil {
p.logger.WithError(err).Error("could not respond")
p.state = cmdAnnounce
break
}
p.state = cmdAnnounce
default:
return fmt.Errorf("unknown protocol state %s", p.state)
}
}
}
func NewServer(handler ServerHandler, in, out chan []byte, logger *logrus.Logger) *ServerProtocol {
return &ServerProtocol{
handler: handler,
in: in,
out: out,
logger: logger,
state: cmdAnnounce,
}
}
// Framer handles bytes on the wire by adding or removing framing information.
type Framer interface {
// ReadFrames reads data frames and publishes unframed data to the channel.
@ -71,6 +204,106 @@ type Framer interface {
WriteFrames(io.Writer, chan []byte) error
}
type ClientProtocol struct {
handler ClientHandler
commands chan *Command
in, out chan []byte
logger *logrus.Logger
state protocolState
}
func (p *ClientProtocol) Handle() error {
var (
command *Command
response *Response
err error
)
for {
p.logger.Debugf("handling protocol state %s", p.state)
switch p.state {
case cmdAnnounce:
command = <-p.commands
if command == nil {
return errCommandExpected
}
err = p.handler.Send(command, p.out)
if err != nil {
p.logger.WithError(err).Error("could not send command announce")
break
}
p.state = respAnnounce
case respAnnounce:
response, err = p.handler.ResponseAnnounce(p.in)
if err != nil {
p.logger.WithError(err).Error("could not handle response announce")
p.state = cmdAnnounce
break
}
p.state = respData
case respData:
if response == nil {
return errResponseExpected
}
err = p.handler.ResponseData(p.in, response)
if err != nil {
p.logger.WithError(err).Error("could not handle response data")
if errors.Is(err, ErrResponseDataTimeoutExpired) {
p.state = cmdAnnounce
} else {
p.state = respAnnounce
}
break
}
p.state = handleResponse
case handleResponse:
if response == nil {
return errResponseExpected
}
err = p.handler.HandleResponse(response)
if err != nil {
p.logger.WithError(err).Error("could not handle response")
p.state = cmdAnnounce
break
}
p.state = cmdAnnounce
default:
return fmt.Errorf("unknown protocol state %s", p.state)
}
}
}
func NewClient(
handler ClientHandler,
commands chan *Command,
in, out chan []byte,
logger *logrus.Logger,
) *ClientProtocol {
return &ClientProtocol{
handler: handler,
commands: commands,
in: in,
out: out,
logger: logger,
state: cmdAnnounce,
}
}
const bufferSize = 1024
const readInterval = 50 * time.Millisecond
@ -127,7 +360,7 @@ func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error {
if err = cobs.Verify(data, c.config); err != nil {
c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data))
break
continue
}
frame = cobs.Decode(data, c.config)

Loading…
Cancel
Save