Refactor client protocol
- define protocols.ClientHandler interface as base for client implementations - implement protocols.ClientHandler in clientsim's ClientHandler type - move protocol state handling into protocols.ServerProtocol and protocols.ClientProtocol - move protocolState type into protocols.go - reduce clientsim's TestCommandGenerator responsibility to test command generation
This commit is contained in:
parent
f429d3da45
commit
af40662c7d
5 changed files with 388 additions and 379 deletions
|
@ -25,7 +25,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shamaton/msgpackgen/msgpack"
|
||||
|
@ -36,124 +35,9 @@ import (
|
|||
"git.cacert.org/cacert-gosigner/pkg/messages"
|
||||
)
|
||||
|
||||
type protocolState int8
|
||||
|
||||
const (
|
||||
cmdAnnounce protocolState = iota
|
||||
cmdData
|
||||
respAnnounce
|
||||
respData
|
||||
)
|
||||
|
||||
var validTransitions = map[protocolState]protocolState{
|
||||
cmdAnnounce: cmdData,
|
||||
cmdData: respAnnounce,
|
||||
respAnnounce: respData,
|
||||
respData: cmdAnnounce,
|
||||
}
|
||||
|
||||
var protocolStateNames = map[protocolState]string{
|
||||
cmdAnnounce: "CMD ANNOUNCE",
|
||||
cmdData: "CMD DATA",
|
||||
respAnnounce: "RESP ANNOUNCE",
|
||||
respData: "RESP DATA",
|
||||
}
|
||||
|
||||
func (p protocolState) String() string {
|
||||
if name, ok := protocolStateNames[p]; ok {
|
||||
return name
|
||||
}
|
||||
|
||||
return fmt.Sprintf("unknown %d", p)
|
||||
}
|
||||
|
||||
type TestCommandGenerator struct {
|
||||
logger *logrus.Logger
|
||||
currentCommand *protocol.Command
|
||||
currentResponse *protocol.Response
|
||||
commands chan *protocol.Command
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (g *TestCommandGenerator) CmdAnnouncement() ([]byte, error) {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
|
||||
g.currentCommand = <-g.commands
|
||||
|
||||
announceData, err := msgpack.Marshal(g.currentCommand.Announce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not marshal command annoucement: %w", err)
|
||||
}
|
||||
|
||||
g.logger.WithField("announcement", &g.currentCommand.Announce).Info("write command announcement")
|
||||
|
||||
return announceData, nil
|
||||
}
|
||||
|
||||
func (g *TestCommandGenerator) CmdData() ([]byte, error) {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
|
||||
cmdData, err := msgpack.Marshal(g.currentCommand.Command)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not marshal command data: %w", err)
|
||||
}
|
||||
|
||||
g.logger.WithField("command", &g.currentCommand.Command).Info("write command data")
|
||||
|
||||
return cmdData, nil
|
||||
}
|
||||
|
||||
func (g *TestCommandGenerator) HandleResponseAnnounce(frame []byte) error {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
|
||||
var ann messages.ResponseAnnounce
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &ann); err != nil {
|
||||
return fmt.Errorf("could not unmarshal response announcement")
|
||||
}
|
||||
|
||||
g.logger.WithField("announcement", &ann).Info("received response announcement")
|
||||
|
||||
g.currentResponse = &protocol.Response{Announce: &ann}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *TestCommandGenerator) HandleResponse(frame []byte) error {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
|
||||
switch g.currentResponse.Announce.Code {
|
||||
case messages.RespHealth:
|
||||
var response messages.HealthResponse
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &response); err != nil {
|
||||
return fmt.Errorf("unmarshal failed: %w", err)
|
||||
}
|
||||
|
||||
g.currentResponse.Response = &response
|
||||
case messages.RespFetchCRL:
|
||||
var response messages.FetchCRLResponse
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &response); err != nil {
|
||||
return fmt.Errorf("unmarshal failed: %w", err)
|
||||
}
|
||||
|
||||
g.currentResponse.Response = &response
|
||||
}
|
||||
|
||||
g.logger.WithField(
|
||||
"command",
|
||||
g.currentCommand,
|
||||
).WithField(
|
||||
"response",
|
||||
g.currentResponse,
|
||||
).Info("handled health response")
|
||||
|
||||
return nil
|
||||
logger *logrus.Logger
|
||||
commands chan *protocol.Command
|
||||
}
|
||||
|
||||
func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
|
||||
|
@ -200,7 +84,11 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
|
|||
case <-ctx.Done():
|
||||
_ = healthTimer.Stop()
|
||||
|
||||
g.logger.Info("stopping health check loop")
|
||||
g.logger.Info("stopped health check loop")
|
||||
|
||||
_ = crlTimer.Stop()
|
||||
|
||||
g.logger.Info("stopped CRL fetch loop")
|
||||
|
||||
return nil
|
||||
case <-healthTimer.C:
|
||||
|
@ -225,187 +113,169 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
|
|||
}
|
||||
|
||||
type clientSimulator struct {
|
||||
protocolState protocolState
|
||||
logger *logrus.Logger
|
||||
lock sync.Mutex
|
||||
clientHandler protocol.ClientHandler
|
||||
framesIn chan []byte
|
||||
framesOut chan []byte
|
||||
framer protocol.Framer
|
||||
commandGenerator *TestCommandGenerator
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
func (c *clientSimulator) writeCmdAnnouncement() error {
|
||||
frame, err := c.commandGenerator.CmdAnnouncement()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get command annoucement: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Trace("writing command announcement")
|
||||
|
||||
c.framesOut <- frame
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) writeCommand() error {
|
||||
frame, err := c.commandGenerator.CmdData()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get command data: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Trace("writing command data")
|
||||
|
||||
c.framesOut <- frame
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const responseAnnounceTimeout = 30 * time.Second
|
||||
const responseDataTimeout = 2 * time.Second
|
||||
|
||||
func (c *clientSimulator) handleResponseAnnounce() error {
|
||||
c.logger.Trace("waiting for response announce")
|
||||
|
||||
select {
|
||||
case frame := <-c.framesIn:
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil {
|
||||
return fmt.Errorf("response announce handling failed: %w", err)
|
||||
}
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-time.After(responseAnnounceTimeout):
|
||||
c.logger.Warn("response announce timeout expired")
|
||||
|
||||
c.protocolState = cmdAnnounce
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) handleResponseData() error {
|
||||
c.logger.Trace("waiting for response data")
|
||||
|
||||
select {
|
||||
case frame := <-c.framesIn:
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.commandGenerator.HandleResponse(frame); err != nil {
|
||||
return fmt.Errorf("response handler failed: %w", err)
|
||||
}
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
case <-time.After(responseDataTimeout):
|
||||
c.logger.Warn("response data timeout expired")
|
||||
|
||||
c.protocolState = cmdAnnounce
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
const (
|
||||
responseAnnounceTimeout = 30 * time.Second
|
||||
responseDataTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
func (c *clientSimulator) Run(ctx context.Context) error {
|
||||
c.protocolState = cmdAnnounce
|
||||
errors := make(chan error)
|
||||
framerErrors := make(chan error)
|
||||
protocolErrors := make(chan error)
|
||||
generatorErrors := make(chan error)
|
||||
|
||||
go func() {
|
||||
err := c.framer.ReadFrames(os.Stdin, c.framesIn)
|
||||
|
||||
errors <- err
|
||||
framerErrors <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := c.framer.WriteFrames(os.Stdout, c.framesOut)
|
||||
|
||||
errors <- err
|
||||
framerErrors <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
clientProtocol := protocol.NewClient(c.clientHandler, c.commandGenerator.commands, c.framesIn, c.framesOut, c.logger)
|
||||
|
||||
err := clientProtocol.Handle()
|
||||
|
||||
protocolErrors <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := c.commandGenerator.GenerateCommands(ctx)
|
||||
|
||||
errors <- err
|
||||
generatorErrors <- err
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case err := <-errors:
|
||||
case err := <-framerErrors:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from handler loop: %w", err)
|
||||
return fmt.Errorf("error from framer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
default:
|
||||
if err := c.handleProtocolState(); err != nil {
|
||||
return err
|
||||
case err := <-generatorErrors:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from command generator: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
case err := <-protocolErrors:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from protocol handler: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientSimulator) handleProtocolState() error {
|
||||
c.logger.Tracef("handling protocol state %s", c.protocolState)
|
||||
type ClientHandler struct {
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
func (c ClientHandler) Send(command *protocol.Command, out chan []byte) error {
|
||||
var (
|
||||
frame []byte
|
||||
err error
|
||||
)
|
||||
|
||||
switch c.protocolState {
|
||||
case cmdAnnounce:
|
||||
if err := c.writeCmdAnnouncement(); err != nil {
|
||||
return err
|
||||
frame, err = msgpack.Marshal(command.Announce)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal command annoucement: %w", err)
|
||||
}
|
||||
|
||||
c.logger.WithField("announcement", command.Announce).Info("write command announcement")
|
||||
|
||||
c.logger.Trace("writing command announcement")
|
||||
|
||||
out <- frame
|
||||
|
||||
frame, err = msgpack.Marshal(command.Command)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal command data: %w", err)
|
||||
}
|
||||
|
||||
c.logger.WithField("command", command.Command).Info("write command data")
|
||||
|
||||
out <- frame
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c ClientHandler) ResponseAnnounce(in chan []byte) (*protocol.Response, error) {
|
||||
response := &protocol.Response{}
|
||||
|
||||
var announce messages.ResponseAnnounce
|
||||
|
||||
select {
|
||||
case frame := <-in:
|
||||
if err := msgpack.Unmarshal(frame, &announce); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response announcement: %w", err)
|
||||
}
|
||||
case cmdData:
|
||||
if err := c.writeCommand(); err != nil {
|
||||
return err
|
||||
|
||||
response.Announce = &announce
|
||||
|
||||
c.logger.WithField("announcement", response.Announce).Debug("received response announcement")
|
||||
|
||||
return response, nil
|
||||
case <-time.After(responseAnnounceTimeout):
|
||||
return nil, protocol.ErrResponseAnnounceTimeoutExpired
|
||||
}
|
||||
}
|
||||
|
||||
func (c ClientHandler) ResponseData(in chan []byte, response *protocol.Response) error {
|
||||
select {
|
||||
case frame := <-in:
|
||||
switch response.Announce.Code {
|
||||
case messages.RespHealth:
|
||||
var resp messages.HealthResponse
|
||||
if err := msgpack.Unmarshal(frame, &resp); err != nil {
|
||||
return fmt.Errorf("could not unmarshal health response data: %w", err)
|
||||
}
|
||||
|
||||
response.Response = &resp
|
||||
case messages.RespFetchCRL:
|
||||
var resp messages.FetchCRLResponse
|
||||
if err := msgpack.Unmarshal(frame, &resp); err != nil {
|
||||
return fmt.Errorf("could not unmarshal fetch CRL response data: %w", err)
|
||||
}
|
||||
|
||||
response.Response = &resp
|
||||
default:
|
||||
return fmt.Errorf("unhandled response code %s", response.Announce.Code)
|
||||
}
|
||||
case respAnnounce:
|
||||
if err := c.handleResponseAnnounce(); err != nil {
|
||||
return err
|
||||
}
|
||||
case respData:
|
||||
if err := c.handleResponseData(); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol state %s", c.protocolState)
|
||||
case <-time.After(responseDataTimeout):
|
||||
return protocol.ErrResponseDataTimeoutExpired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) nextState() error {
|
||||
next, ok := validTransitions[c.protocolState]
|
||||
if !ok {
|
||||
return fmt.Errorf("illegal protocol state %s", c.protocolState)
|
||||
}
|
||||
|
||||
c.protocolState = next
|
||||
func (c ClientHandler) HandleResponse(response *protocol.Response) error {
|
||||
c.logger.WithField("response", response.Announce).Info("handled response")
|
||||
c.logger.WithField("response", response).Debug("full response")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newClientHandler(logger *logrus.Logger) *ClientHandler {
|
||||
return &ClientHandler{logger: logger}
|
||||
}
|
||||
|
||||
func main() {
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(os.Stderr)
|
||||
|
@ -418,10 +288,11 @@ func main() {
|
|||
logger: logger,
|
||||
commands: make(chan *protocol.Command),
|
||||
},
|
||||
logger: logger,
|
||||
framesIn: make(chan []byte),
|
||||
framesOut: make(chan []byte),
|
||||
framer: protocol.NewCOBSFramer(logger),
|
||||
logger: logger,
|
||||
framesIn: make(chan []byte),
|
||||
framesOut: make(chan []byte),
|
||||
framer: protocol.NewCOBSFramer(logger),
|
||||
clientHandler: newClientHandler(logger),
|
||||
}
|
||||
|
||||
err := sim.Run(context.Background())
|
||||
|
|
|
@ -57,7 +57,11 @@ func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command,
|
|||
return nil, fmt.Errorf("could not unmarshal command announcement: %w", err)
|
||||
}
|
||||
|
||||
m.logger.WithField("announcement", &ann).Info("received command announcement")
|
||||
if ann.Code == messages.CmdUndef {
|
||||
return nil, fmt.Errorf("received undefined command announcement: %s", ann)
|
||||
}
|
||||
|
||||
m.logger.WithField("announcement", &ann).Debug("received command announcement")
|
||||
|
||||
return &protocol.Command{Announce: &ann}, nil
|
||||
}
|
||||
|
@ -109,7 +113,7 @@ func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) e
|
|||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
announce, err := msgpack.Marshal(response)
|
||||
announce, err := msgpack.Marshal(response.Announce)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal response announcement: %w", err)
|
||||
}
|
||||
|
@ -125,7 +129,7 @@ func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) e
|
|||
|
||||
m.logger.WithField("length", len(data)).Debug("write response")
|
||||
|
||||
out <- announce
|
||||
out <- data
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ package serial
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -30,39 +29,14 @@ import (
|
|||
"git.cacert.org/cacert-gosigner/pkg/protocol"
|
||||
)
|
||||
|
||||
type protocolState int8
|
||||
|
||||
const (
|
||||
cmdAnnounce protocolState = iota
|
||||
cmdData
|
||||
handleCommand
|
||||
respond
|
||||
)
|
||||
|
||||
var protocolStateNames = map[protocolState]string{
|
||||
cmdAnnounce: "CMD ANNOUNCE",
|
||||
cmdData: "CMD DATA",
|
||||
handleCommand: "RESP ANNOUNCE",
|
||||
respond: "RESP DATA",
|
||||
}
|
||||
|
||||
func (p protocolState) String() string {
|
||||
if name, ok := protocolStateNames[p]; ok {
|
||||
return name
|
||||
}
|
||||
|
||||
return fmt.Sprintf("unknown %d", p)
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
protocolHandler protocol.ServerHandler
|
||||
protocolState protocolState
|
||||
framer protocol.Framer
|
||||
config *serial.Config
|
||||
port *serial.Port
|
||||
logger *logrus.Logger
|
||||
framesIn chan []byte
|
||||
framesOut chan []byte
|
||||
serverHandler protocol.ServerHandler
|
||||
framer protocol.Framer
|
||||
config *serial.Config
|
||||
port *serial.Port
|
||||
logger *logrus.Logger
|
||||
framesIn chan []byte
|
||||
framesOut chan []byte
|
||||
}
|
||||
|
||||
func (h *Handler) setupConnection() error {
|
||||
|
@ -86,7 +60,6 @@ func (h *Handler) Close() error {
|
|||
}
|
||||
|
||||
func (h *Handler) Run(ctx context.Context) error {
|
||||
h.protocolState = cmdAnnounce
|
||||
protocolErrors, framerErrors := make(chan error), make(chan error)
|
||||
|
||||
go func() {
|
||||
|
@ -102,7 +75,9 @@ func (h *Handler) Run(ctx context.Context) error {
|
|||
}()
|
||||
|
||||
go func() {
|
||||
err := h.handleProtocolState()
|
||||
serverProtocol := protocol.NewServer(h.serverHandler, h.framesIn, h.framesOut, h.logger)
|
||||
|
||||
err := serverProtocol.Handle()
|
||||
|
||||
protocolErrors <- err
|
||||
}()
|
||||
|
@ -127,95 +102,17 @@ func (h *Handler) Run(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
var errCommandExpected = errors.New("command must not be nil")
|
||||
var errResponseExpected = errors.New("response must not be nil")
|
||||
|
||||
func (h *Handler) handleProtocolState() error {
|
||||
var (
|
||||
command *protocol.Command
|
||||
response *protocol.Response
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
h.logger.Debugf("handling protocol state %s", h.protocolState)
|
||||
|
||||
switch h.protocolState {
|
||||
case cmdAnnounce:
|
||||
command, err = h.protocolHandler.CommandAnnounce(h.framesIn)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("could not handle command announce")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
h.protocolState = cmdData
|
||||
case cmdData:
|
||||
if command == nil {
|
||||
return errCommandExpected
|
||||
}
|
||||
|
||||
err = h.protocolHandler.CommandData(h.framesIn, command)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("could not handle command data")
|
||||
|
||||
h.protocolState = cmdAnnounce
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
h.protocolState = handleCommand
|
||||
case handleCommand:
|
||||
if command == nil {
|
||||
return errCommandExpected
|
||||
}
|
||||
|
||||
response, err = h.protocolHandler.HandleCommand(command)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("could not handle command")
|
||||
|
||||
h.protocolState = cmdAnnounce
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
command = nil
|
||||
|
||||
h.protocolState = respond
|
||||
case respond:
|
||||
if response == nil {
|
||||
return errResponseExpected
|
||||
}
|
||||
|
||||
err = h.protocolHandler.Respond(response, h.framesOut)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("could not respond")
|
||||
|
||||
h.protocolState = cmdAnnounce
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
response = nil
|
||||
|
||||
h.protocolState = cmdAnnounce
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol state %s", h.protocolState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func New(
|
||||
cfg *config.Serial,
|
||||
logger *logrus.Logger,
|
||||
protocolHandler protocol.ServerHandler,
|
||||
) (*Handler, error) {
|
||||
h := &Handler{
|
||||
protocolHandler: protocolHandler,
|
||||
logger: logger,
|
||||
framesIn: make(chan []byte),
|
||||
framesOut: make(chan []byte),
|
||||
framer: protocol.NewCOBSFramer(logger),
|
||||
serverHandler: protocolHandler,
|
||||
logger: logger,
|
||||
framesIn: make(chan []byte),
|
||||
framesOut: make(chan []byte),
|
||||
framer: protocol.NewCOBSFramer(logger),
|
||||
}
|
||||
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}
|
||||
|
||||
|
|
|
@ -37,11 +37,13 @@ import (
|
|||
type CommandCode int8
|
||||
|
||||
const (
|
||||
CmdHealth CommandCode = iota
|
||||
CmdUndef CommandCode = iota
|
||||
CmdHealth
|
||||
CmdFetchCRL
|
||||
)
|
||||
|
||||
var commandNames = map[CommandCode]string{
|
||||
CmdUndef: "UNDEFINED",
|
||||
CmdHealth: "HEALTH",
|
||||
CmdFetchCRL: "FETCH CRL",
|
||||
}
|
||||
|
@ -57,13 +59,15 @@ func (c CommandCode) String() string {
|
|||
type ResponseCode int8
|
||||
|
||||
const (
|
||||
RespError ResponseCode = -1
|
||||
RespHealth ResponseCode = iota
|
||||
RespError ResponseCode = -1
|
||||
RespUndef ResponseCode = iota
|
||||
RespHealth
|
||||
RespFetchCRL
|
||||
)
|
||||
|
||||
var responseNames = map[ResponseCode]string{
|
||||
RespError: "ERROR",
|
||||
RespUndef: "UNDEFINED",
|
||||
RespHealth: "HEALTH",
|
||||
RespFetchCRL: "FETCH CRL",
|
||||
}
|
||||
|
|
|
@ -63,6 +63,139 @@ type ServerHandler interface {
|
|||
Respond(*Response, chan []byte) error
|
||||
}
|
||||
|
||||
type ClientHandler interface {
|
||||
Send(*Command, chan []byte) error
|
||||
ResponseAnnounce(chan []byte) (*Response, error)
|
||||
ResponseData(chan []byte, *Response) error
|
||||
HandleResponse(*Response) error
|
||||
}
|
||||
|
||||
var (
|
||||
errCommandExpected = errors.New("command must not be nil")
|
||||
errResponseExpected = errors.New("response must not be nil")
|
||||
|
||||
ErrResponseAnnounceTimeoutExpired = errors.New("response announce timeout expired")
|
||||
ErrResponseDataTimeoutExpired = errors.New("response data timeout expired")
|
||||
)
|
||||
|
||||
type protocolState int8
|
||||
|
||||
const (
|
||||
cmdAnnounce protocolState = iota
|
||||
cmdData
|
||||
handleCommand
|
||||
respond
|
||||
respAnnounce
|
||||
respData
|
||||
handleResponse
|
||||
)
|
||||
|
||||
var protocolStateNames = map[protocolState]string{
|
||||
cmdAnnounce: "CMD ANNOUNCE",
|
||||
cmdData: "CMD DATA",
|
||||
handleCommand: "HANDLE CMD",
|
||||
respond: "RESPOND",
|
||||
respAnnounce: "RESP ANNOUNCE",
|
||||
respData: "RESP DATA",
|
||||
handleResponse: "HANDLE RESP",
|
||||
}
|
||||
|
||||
func (p protocolState) String() string {
|
||||
if name, ok := protocolStateNames[p]; ok {
|
||||
return name
|
||||
}
|
||||
|
||||
return fmt.Sprintf("unknown %d", p)
|
||||
}
|
||||
|
||||
type ServerProtocol struct {
|
||||
handler ServerHandler
|
||||
in, out chan []byte
|
||||
logger *logrus.Logger
|
||||
state protocolState
|
||||
}
|
||||
|
||||
func (p *ServerProtocol) Handle() error {
|
||||
var (
|
||||
command *Command
|
||||
response *Response
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
p.logger.Debugf("handling protocol state %s", p.state)
|
||||
|
||||
switch p.state {
|
||||
case cmdAnnounce:
|
||||
command, err = p.handler.CommandAnnounce(p.in)
|
||||
if err != nil {
|
||||
p.logger.WithError(err).Error("could not handle command announce")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
p.state = cmdData
|
||||
case cmdData:
|
||||
if command == nil {
|
||||
return errCommandExpected
|
||||
}
|
||||
|
||||
err = p.handler.CommandData(p.in, command)
|
||||
if err != nil {
|
||||
p.logger.WithError(err).Error("could not handle command data")
|
||||
|
||||
p.state = cmdAnnounce
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
p.state = handleCommand
|
||||
case handleCommand:
|
||||
if command == nil {
|
||||
return errCommandExpected
|
||||
}
|
||||
|
||||
response, err = p.handler.HandleCommand(command)
|
||||
if err != nil {
|
||||
p.logger.WithError(err).Error("could not handle command")
|
||||
|
||||
p.state = cmdAnnounce
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
p.state = respond
|
||||
case respond:
|
||||
if response == nil {
|
||||
return errResponseExpected
|
||||
}
|
||||
|
||||
err = p.handler.Respond(response, p.out)
|
||||
if err != nil {
|
||||
p.logger.WithError(err).Error("could not respond")
|
||||
|
||||
p.state = cmdAnnounce
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
p.state = cmdAnnounce
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol state %s", p.state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewServer(handler ServerHandler, in, out chan []byte, logger *logrus.Logger) *ServerProtocol {
|
||||
return &ServerProtocol{
|
||||
handler: handler,
|
||||
in: in,
|
||||
out: out,
|
||||
logger: logger,
|
||||
state: cmdAnnounce,
|
||||
}
|
||||
}
|
||||
|
||||
// Framer handles bytes on the wire by adding or removing framing information.
|
||||
type Framer interface {
|
||||
// ReadFrames reads data frames and publishes unframed data to the channel.
|
||||
|
@ -71,6 +204,106 @@ type Framer interface {
|
|||
WriteFrames(io.Writer, chan []byte) error
|
||||
}
|
||||
|
||||
type ClientProtocol struct {
|
||||
handler ClientHandler
|
||||
commands chan *Command
|
||||
in, out chan []byte
|
||||
logger *logrus.Logger
|
||||
state protocolState
|
||||
}
|
||||
|
||||
func (p *ClientProtocol) Handle() error {
|
||||
var (
|
||||
command *Command
|
||||
response *Response
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
p.logger.Debugf("handling protocol state %s", p.state)
|
||||
|
||||
switch p.state {
|
||||
case cmdAnnounce:
|
||||
command = <-p.commands
|
||||
if command == nil {
|
||||
return errCommandExpected
|
||||
}
|
||||
|
||||
err = p.handler.Send(command, p.out)
|
||||
if err != nil {
|
||||
p.logger.WithError(err).Error("could not send command announce")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
p.state = respAnnounce
|
||||
case respAnnounce:
|
||||
response, err = p.handler.ResponseAnnounce(p.in)
|
||||
if err != nil {
|
||||
p.logger.WithError(err).Error("could not handle response announce")
|
||||
|
||||
p.state = cmdAnnounce
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
p.state = respData
|
||||
case respData:
|
||||
if response == nil {
|
||||
return errResponseExpected
|
||||
}
|
||||
|
||||
err = p.handler.ResponseData(p.in, response)
|
||||
if err != nil {
|
||||
p.logger.WithError(err).Error("could not handle response data")
|
||||
|
||||
if errors.Is(err, ErrResponseDataTimeoutExpired) {
|
||||
p.state = cmdAnnounce
|
||||
} else {
|
||||
p.state = respAnnounce
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
p.state = handleResponse
|
||||
case handleResponse:
|
||||
if response == nil {
|
||||
return errResponseExpected
|
||||
}
|
||||
|
||||
err = p.handler.HandleResponse(response)
|
||||
if err != nil {
|
||||
p.logger.WithError(err).Error("could not handle response")
|
||||
|
||||
p.state = cmdAnnounce
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
p.state = cmdAnnounce
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol state %s", p.state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewClient(
|
||||
handler ClientHandler,
|
||||
commands chan *Command,
|
||||
in, out chan []byte,
|
||||
logger *logrus.Logger,
|
||||
) *ClientProtocol {
|
||||
return &ClientProtocol{
|
||||
handler: handler,
|
||||
commands: commands,
|
||||
in: in,
|
||||
out: out,
|
||||
logger: logger,
|
||||
state: cmdAnnounce,
|
||||
}
|
||||
}
|
||||
|
||||
const bufferSize = 1024
|
||||
const readInterval = 50 * time.Millisecond
|
||||
|
||||
|
@ -127,7 +360,7 @@ func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error {
|
|||
if err = cobs.Verify(data, c.config); err != nil {
|
||||
c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data))
|
||||
|
||||
break
|
||||
continue
|
||||
}
|
||||
|
||||
frame = cobs.Decode(data, c.config)
|
||||
|
|
Loading…
Reference in a new issue