Compare commits

...

3 commits

Author SHA1 Message Date
f429d3da45 Refactor server handler
- rename protocols.Handler to ServerHandler
- rename ServerHandler methods to better express their purpose
- pass command and response as parameters
- simplify state machine and handle errors in serial/seriallink.go
- implement command read timeout
- remove currentCommand and currentResponse fields from MsgPackHandler
2022-11-29 11:45:59 +01:00
9905d748d9 Improve signer robustness
- let client simulator send some garbage bytes before starting real commands
- handle EOF during reads
2022-11-29 10:29:09 +01:00
e5dcf7afa9 Refactor COBS wire protocol
Wire protocol handling has been moved to protocol.Framer and its
implementation protocol.COBSFramer
2022-11-29 09:57:23 +01:00
5 changed files with 428 additions and 499 deletions

View file

@ -20,15 +20,14 @@ limitations under the License.
package main
import (
"bytes"
"context"
"crypto/rand"
"fmt"
"io"
"os"
"sync"
"time"
"github.com/justincpresley/go-cobs"
"github.com/shamaton/msgpackgen/msgpack"
"github.com/sirupsen/logrus"
@ -37,8 +36,6 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages"
)
var cobsConfig = cobs.Config{SpecialByte: protocol.CobsDelimiter, Delimiter: true, EndingSave: true}
type protocolState int8
const (
@ -160,16 +157,27 @@ func (g *TestCommandGenerator) HandleResponse(frame []byte) error {
}
func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
const (
healthInterval = 5 * time.Second
startPause = 3 * time.Second
)
var (
announce *messages.CommandAnnounce
err error
)
// write some leading garbage to test signer robustness
_, _ = io.CopyN(os.Stdout, rand.Reader, 50) //nolint:gomnd
announce, err = messages.BuildCommandAnnounce(messages.CmdHealth)
if err != nil {
return fmt.Errorf("build command announce failed: %w", err)
}
g.commands <- &protocol.Command{Announce: announce, Command: &messages.HealthCommand{}}
const (
healthInterval = 5 * time.Second
crlInterval = 15 * time.Minute
startPause = 3 * time.Second
)
g.logger.Info("start generating commands")
time.Sleep(startPause)
@ -184,17 +192,18 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
Command: &messages.FetchCRLCommand{IssuerID: "sub-ecc_person_2022"},
}
timer := time.NewTimer(healthInterval)
healthTimer := time.NewTimer(healthInterval)
crlTimer := time.NewTimer(crlInterval)
for {
select {
case <-ctx.Done():
_ = timer.Stop()
_ = healthTimer.Stop()
g.logger.Info("stopping health check loop")
return nil
case <-timer.C:
case <-healthTimer.C:
announce, err = messages.BuildCommandAnnounce(messages.CmdHealth)
if err != nil {
return fmt.Errorf("build command announce failed: %w", err)
@ -204,9 +213,14 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
Announce: announce,
Command: &messages.HealthCommand{},
}
}
timer.Reset(healthInterval)
healthTimer.Reset(healthInterval)
case <-crlTimer.C:
g.commands <- &protocol.Command{
Announce: announce,
Command: &messages.FetchCRLCommand{IssuerID: "sub-ecc_person_2022"},
}
}
}
}
@ -215,106 +229,11 @@ type clientSimulator struct {
logger *logrus.Logger
lock sync.Mutex
framesIn chan []byte
framesOut chan []byte
framer protocol.Framer
commandGenerator *TestCommandGenerator
}
func (c *clientSimulator) readFrames() error {
const readInterval = 50 * time.Millisecond
var frame []byte
buffer := &bytes.Buffer{}
delimiter := []byte{cobsConfig.SpecialByte}
for {
readBytes, err := c.readFromStdin()
if err != nil {
c.logger.WithError(err).Error("stdin read error")
close(c.framesIn)
return err
}
if len(readBytes) == 0 {
time.Sleep(readInterval)
continue
}
c.logger.Tracef("read %d bytes", len(readBytes))
buffer.Write(readBytes)
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
rest := buffer.Bytes()
if !bytes.Contains(rest, delimiter) {
c.logger.Tracef("read data does not contain the delimiter %x", delimiter)
continue
}
for bytes.Contains(rest, delimiter) {
parts := bytes.SplitAfterN(rest, delimiter, 2)
frame, rest = parts[0], parts[1]
c.logger.Tracef("frame of length %d", len(frame))
if len(frame) == 0 {
continue
}
err = cobs.Verify(frame, cobsConfig)
if err != nil {
return fmt.Errorf("frame verification failed: %w", err)
}
decoded := cobs.Decode(frame, cobsConfig)
c.logger.Tracef("frame decoded to length %d", len(decoded))
c.framesIn <- decoded
c.logger.Tracef("%d bytes remaining", len(rest))
}
buffer.Truncate(0)
buffer.Write(rest)
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (c *clientSimulator) writeFrame(frame []byte) error {
encoded := cobs.Encode(frame, cobsConfig)
if _, err := io.Copy(os.Stdout, bytes.NewBuffer(encoded)); err != nil {
return fmt.Errorf("could not write data: %w", err)
}
return nil
}
func (c *clientSimulator) readFromStdin() ([]byte, error) {
const bufferSize = 1024
buf := make([]byte, bufferSize)
c.logger.Trace("waiting for input")
count, err := os.Stdin.Read(buf)
if err != nil {
return nil, fmt.Errorf("reading input failed: %w", err)
}
c.logger.Tracef("read %d bytes from stdin", count)
return buf[:count], nil
}
func (c *clientSimulator) writeCmdAnnouncement() error {
frame, err := c.commandGenerator.CmdAnnouncement()
if err != nil {
@ -323,9 +242,7 @@ func (c *clientSimulator) writeCmdAnnouncement() error {
c.logger.Trace("writing command announcement")
if err := c.writeFrame(frame); err != nil {
return err
}
c.framesOut <- frame
if err := c.nextState(); err != nil {
return err
@ -342,9 +259,7 @@ func (c *clientSimulator) writeCommand() error {
c.logger.Trace("writing command data")
if err := c.writeFrame(frame); err != nil {
return err
}
c.framesOut <- frame
if err := c.nextState(); err != nil {
return err
@ -353,44 +268,61 @@ func (c *clientSimulator) writeCommand() error {
return nil
}
const responseAnnounceTimeout = 30 * time.Second
const responseDataTimeout = 2 * time.Second
func (c *clientSimulator) handleResponseAnnounce() error {
c.logger.Trace("waiting for response announce")
frame := <-c.framesIn
select {
case frame := <-c.framesIn:
if frame == nil {
return nil
}
if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil {
return fmt.Errorf("response announce handling failed: %w", err)
}
if err := c.nextState(); err != nil {
return err
}
case <-time.After(responseAnnounceTimeout):
c.logger.Warn("response announce timeout expired")
c.protocolState = cmdAnnounce
if frame == nil {
return nil
}
if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil {
return fmt.Errorf("response announce handling failed: %w", err)
}
if err := c.nextState(); err != nil {
return err
}
return nil
}
func (c *clientSimulator) handleResponseData() error {
c.logger.Trace("waiting for response data")
frame := <-c.framesIn
select {
case frame := <-c.framesIn:
if frame == nil {
return nil
}
if err := c.commandGenerator.HandleResponse(frame); err != nil {
return fmt.Errorf("response handler failed: %w", err)
}
if err := c.nextState(); err != nil {
return err
}
return nil
case <-time.After(responseDataTimeout):
c.logger.Warn("response data timeout expired")
c.protocolState = cmdAnnounce
if frame == nil {
return nil
}
if err := c.commandGenerator.HandleResponse(frame); err != nil {
return fmt.Errorf("response handler failed: %w", err)
}
if err := c.nextState(); err != nil {
return err
}
return nil
}
func (c *clientSimulator) Run(ctx context.Context) error {
@ -398,7 +330,13 @@ func (c *clientSimulator) Run(ctx context.Context) error {
errors := make(chan error)
go func() {
err := c.readFrames()
err := c.framer.ReadFrames(os.Stdin, c.framesIn)
errors <- err
}()
go func() {
err := c.framer.WriteFrames(os.Stdout, c.framesOut)
errors <- err
}()
@ -480,8 +418,10 @@ func main() {
logger: logger,
commands: make(chan *protocol.Command),
},
logger: logger,
framesIn: make(chan []byte),
logger: logger,
framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
}
err := sim.Run(context.Background())

View file

@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"sync"
"time"
"github.com/shamaton/msgpackgen/msgpack"
"github.com/sirupsen/logrus"
@ -32,185 +33,186 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages"
)
// MsgPackHandler is a Handler implementation for the msgpack serialization format.
const readCommandTimeOut = 5 * time.Second
var errReadCommandTimeout = errors.New("read command timeout expired")
// MsgPackHandler is a ServerHandler implementation for the msgpack serialization format.
type MsgPackHandler struct {
logger *logrus.Logger
healthHandler *health.Handler
fetchCRLHandler *revoking.FetchCRLHandler
currentCommand *protocol.Command
currentResponse *protocol.Response
lock sync.Mutex
}
func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) error {
func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command, error) {
m.lock.Lock()
defer m.lock.Unlock()
frame := <-frames
var ann messages.CommandAnnounce
if err := msgpack.Unmarshal(frame, &ann); err != nil {
return fmt.Errorf("could not unmarshal command announcement: %w", err)
return nil, fmt.Errorf("could not unmarshal command announcement: %w", err)
}
m.logger.WithField("announcement", &ann).Info("received command announcement")
m.currentCommand = &protocol.Command{Announce: &ann}
return nil
return &protocol.Command{Announce: &ann}, nil
}
func (m *MsgPackHandler) HandleCommand(frame []byte) error {
func (m *MsgPackHandler) CommandData(frames chan []byte, command *protocol.Command) error {
m.lock.Lock()
defer m.lock.Unlock()
err := m.parseCommand(frame)
if err != nil {
m.currentResponse = m.buildErrorResponse(err.Error())
m.logCommandResponse()
select {
case frame := <-frames:
err := m.parseCommand(frame, command)
if err != nil {
return err
}
return nil
case <-time.After(readCommandTimeOut):
return errReadCommandTimeout
}
}
err = m.handleCommand()
func (m *MsgPackHandler) HandleCommand(command *protocol.Command) (*protocol.Response, error) {
m.lock.Lock()
defer m.lock.Unlock()
var (
response *protocol.Response
err error
)
response, err = m.handleCommand(command)
if err != nil {
m.logger.WithError(err).Error("command handling failed")
return err
response = m.buildErrorResponse(command.Announce.ID, "command handling failed")
}
m.logCommandResponse()
m.logCommandResponse(command, response)
m.currentCommand = nil
return response, nil
}
func (m *MsgPackHandler) logCommandResponse(command *protocol.Command, response *protocol.Response) {
m.logger.WithField("command", command.Announce).Info("handled command")
m.logger.WithField("command", command).WithField("response", response).Debug("command and response")
}
func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) error {
m.lock.Lock()
defer m.lock.Unlock()
announce, err := msgpack.Marshal(response)
if err != nil {
return fmt.Errorf("could not marshal response announcement: %w", err)
}
m.logger.WithField("length", len(announce)).Debug("write response announcement")
out <- announce
data, err := msgpack.Marshal(response.Response)
if err != nil {
return fmt.Errorf("could not marshal response: %w", err)
}
m.logger.WithField("length", len(data)).Debug("write response")
out <- announce
return nil
}
func (m *MsgPackHandler) logCommandResponse() {
m.logger.WithField("command", m.currentCommand.Announce).Info("handled command")
m.logger.WithField(
"command",
m.currentCommand,
).WithField(
"response",
m.currentResponse,
).Debug("command and response")
}
func (m *MsgPackHandler) ResponseAnnounce() ([]byte, error) {
m.lock.Lock()
defer m.lock.Unlock()
announceData, err := msgpack.Marshal(m.currentResponse.Announce)
if err != nil {
return nil, fmt.Errorf("could not marshal response announcement: %w", err)
}
m.logger.WithField("announcement", m.currentResponse.Announce).Debug("write response announcement")
return announceData, nil
}
func (m *MsgPackHandler) ResponseData() ([]byte, error) {
m.lock.Lock()
defer m.lock.Unlock()
responseData, err := msgpack.Marshal(m.currentResponse.Response)
if err != nil {
return nil, fmt.Errorf("could not marshal response: %w", err)
}
m.logger.WithField("response", m.currentResponse.Response).Debug("write response")
return responseData, nil
}
func (m *MsgPackHandler) parseHealthCommand(frame []byte) error {
func (m *MsgPackHandler) parseHealthCommand(frame []byte) (*messages.HealthCommand, error) {
var command messages.HealthCommand
if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed")
return errors.New("could not unmarshal health command")
return nil, errors.New("could not unmarshal health command")
}
m.currentCommand.Command = &command
return nil
return &command, nil
}
func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) error {
func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) (*messages.FetchCRLCommand, error) {
var command messages.FetchCRLCommand
if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed")
return errors.New("could not unmarshal fetch crl command")
return nil, errors.New("could not unmarshal fetch crl command")
}
m.currentCommand.Command = &command
return nil
return &command, nil
}
func (m *MsgPackHandler) currentID() string {
return m.currentCommand.Announce.ID
}
func (m *MsgPackHandler) handleCommand() error {
func (m *MsgPackHandler) handleCommand(command *protocol.Command) (*protocol.Response, error) {
var (
err error
responseData interface{}
responseCode messages.ResponseCode
responseData interface{}
)
switch m.currentCommand.Command.(type) {
switch cmd := command.Command.(type) {
case *messages.HealthCommand:
response, err := m.handleHealthCommand()
if err != nil {
return err
return nil, err
}
responseCode, responseData = messages.RespHealth, response
case *messages.FetchCRLCommand:
response, err := m.handleFetchCRLCommand()
response, err := m.handleFetchCRLCommand(cmd)
if err != nil {
return err
return nil, err
}
responseCode, responseData = messages.RespFetchCRL, response
default:
return fmt.Errorf("unhandled command %s", m.currentCommand.Announce)
return nil, fmt.Errorf("unhandled command %s", command.Announce)
}
if err != nil {
return fmt.Errorf("error from command handler: %w", err)
}
m.currentResponse = &protocol.Response{
Announce: messages.BuildResponseAnnounce(responseCode, m.currentID()),
return &protocol.Response{
Announce: messages.BuildResponseAnnounce(responseCode, command.Announce.ID),
Response: responseData,
}
return nil
}, nil
}
func (m *MsgPackHandler) buildErrorResponse(errMsg string) *protocol.Response {
func (m *MsgPackHandler) buildErrorResponse(commandID string, errMsg string) *protocol.Response {
return &protocol.Response{
Announce: messages.BuildResponseAnnounce(messages.RespError, m.currentID()),
Announce: messages.BuildResponseAnnounce(messages.RespError, commandID),
Response: &messages.ErrorResponse{Message: errMsg},
}
}
func (m *MsgPackHandler) parseCommand(frame []byte) error {
switch m.currentCommand.Announce.Code {
func (m *MsgPackHandler) parseCommand(frame []byte, command *protocol.Command) error {
switch command.Announce.Code {
case messages.CmdHealth:
return m.parseHealthCommand(frame)
healthCommand, err := m.parseHealthCommand(frame)
if err != nil {
return err
}
command.Command = healthCommand
case messages.CmdFetchCRL:
return m.parseFetchCRLCommand(frame)
fetchCRLCommand, err := m.parseFetchCRLCommand(frame)
if err != nil {
return err
}
command.Command = fetchCRLCommand
default:
return fmt.Errorf("unhandled command code %s", m.currentCommand.Announce.Code)
return fmt.Errorf("unhandled command code %s", command.Announce.Code)
}
return nil
}
func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error) {
@ -235,27 +237,20 @@ func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error)
return response, nil
}
func (m *MsgPackHandler) handleFetchCRLCommand() (*messages.FetchCRLResponse, error) {
fetchCRLPayload, ok := m.currentCommand.Command.(*messages.FetchCRLCommand)
if !ok {
return nil, fmt.Errorf("could not use payload as FetchCRLPayload")
}
res, err := m.fetchCRLHandler.FetchCRL(fetchCRLPayload.IssuerID)
func (m *MsgPackHandler) handleFetchCRLCommand(command *messages.FetchCRLCommand) (*messages.FetchCRLResponse, error) {
res, err := m.fetchCRLHandler.FetchCRL(command.IssuerID)
if err != nil {
return nil, fmt.Errorf("could not fetch CRL: %w", err)
}
response := &messages.FetchCRLResponse{
return &messages.FetchCRLResponse{
IsDelta: false,
CRLNumber: res.Number,
CRLData: res.CRLData,
}
return response, nil
}, nil
}
func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.Handler, error) {
func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.ServerHandler, error) {
messages.RegisterGeneratedResolver()
h := &MsgPackHandler{

View file

@ -15,18 +15,14 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
// Package seriallink provides a handler for the serial connection of the signer machine.
// Package serial provides a handler for the serial connection of the signer machine.
package serial
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/justincpresley/go-cobs"
"github.com/sirupsen/logrus"
"github.com/tarm/serial"
@ -39,22 +35,15 @@ type protocolState int8
const (
cmdAnnounce protocolState = iota
cmdData
respAnnounce
respData
handleCommand
respond
)
var validTransitions = map[protocolState]protocolState{
cmdAnnounce: cmdData,
cmdData: respAnnounce,
respAnnounce: respData,
respData: cmdAnnounce,
}
var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
respAnnounce: "RESP ANNOUNCE",
respData: "RESP DATA",
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
handleCommand: "RESP ANNOUNCE",
respond: "RESP DATA",
}
func (p protocolState) String() string {
@ -66,13 +55,14 @@ func (p protocolState) String() string {
}
type Handler struct {
protocolHandler protocol.Handler
protocolHandler protocol.ServerHandler
protocolState protocolState
framer protocol.Framer
config *serial.Config
port *serial.Port
logger *logrus.Logger
lock sync.Mutex
framesIn chan []byte
framesOut chan []byte
}
func (h *Handler) setupConnection() error {
@ -95,256 +85,137 @@ func (h *Handler) Close() error {
return nil
}
var cobsConfig = cobs.Config{SpecialByte: protocol.CobsDelimiter, Delimiter: true, EndingSave: true}
func (h *Handler) Run(ctx context.Context) error {
h.protocolState = cmdAnnounce
errors := make(chan error)
protocolErrors, framerErrors := make(chan error), make(chan error)
go func() {
err := h.readFrames()
err := h.framer.ReadFrames(h.port, h.framesIn)
errors <- err
framerErrors <- err
}()
go func() {
err := h.framer.WriteFrames(h.port, h.framesOut)
framerErrors <- err
}()
go func() {
err := h.handleProtocolState()
protocolErrors <- err
}()
for {
select {
case <-ctx.Done():
return nil
case err := <-errors:
case err := <-framerErrors:
if err != nil {
return fmt.Errorf("error from handler loop: %w", err)
return fmt.Errorf("error from framer: %w", err)
}
return nil
default:
if err := h.handleProtocolState(); err != nil {
return err
case err := <-protocolErrors:
if err != nil {
return fmt.Errorf("error from protocol handler: %w", err)
}
return nil
}
}
}
func (h *Handler) readFrames() error {
const (
readInterval = 50 * time.Millisecond
)
var frame []byte
buffer := &bytes.Buffer{}
delimiter := []byte{cobsConfig.SpecialByte}
for {
readBytes, err := h.readFromPort()
if err != nil {
close(h.framesIn)
return err
}
if len(readBytes) == 0 {
time.Sleep(readInterval)
continue
}
h.logger.Tracef("read %d bytes", len(readBytes))
buffer.Write(readBytes)
h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
rest := buffer.Bytes()
if !bytes.Contains(rest, delimiter) {
continue
}
for bytes.Contains(rest, delimiter) {
parts := bytes.SplitAfterN(rest, delimiter, 2)
frame, rest = parts[0], parts[1]
h.logger.Tracef("frame of length %d", len(frame))
if len(frame) == 0 {
continue
}
if err := cobs.Verify(frame, cobsConfig); err != nil {
close(h.framesIn)
return fmt.Errorf("could not verify COBS frame: %w", err)
}
decoded := cobs.Decode(frame, cobsConfig)
h.logger.Tracef("frame decoded to length %d", len(decoded))
h.framesIn <- decoded
}
buffer.Truncate(0)
buffer.Write(rest)
h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (h *Handler) writeFrame(frame []byte) error {
encoded := cobs.Encode(frame, cobsConfig)
return h.writeToPort(encoded)
}
func (h *Handler) nextState() error {
next, ok := validTransitions[h.protocolState]
if !ok {
return fmt.Errorf("illegal protocol state %s", h.protocolState)
}
h.protocolState = next
return nil
}
var errCommandExpected = errors.New("command must not be nil")
var errResponseExpected = errors.New("response must not be nil")
func (h *Handler) handleProtocolState() error {
h.logger.Tracef("handling protocol state %s", h.protocolState)
var (
command *protocol.Command
response *protocol.Response
err error
)
h.lock.Lock()
defer h.lock.Unlock()
for {
h.logger.Debugf("handling protocol state %s", h.protocolState)
switch h.protocolState {
case cmdAnnounce:
if err := h.handleCmdAnnounce(); err != nil {
return err
switch h.protocolState {
case cmdAnnounce:
command, err = h.protocolHandler.CommandAnnounce(h.framesIn)
if err != nil {
h.logger.WithError(err).Error("could not handle command announce")
break
}
h.protocolState = cmdData
case cmdData:
if command == nil {
return errCommandExpected
}
err = h.protocolHandler.CommandData(h.framesIn, command)
if err != nil {
h.logger.WithError(err).Error("could not handle command data")
h.protocolState = cmdAnnounce
break
}
h.protocolState = handleCommand
case handleCommand:
if command == nil {
return errCommandExpected
}
response, err = h.protocolHandler.HandleCommand(command)
if err != nil {
h.logger.WithError(err).Error("could not handle command")
h.protocolState = cmdAnnounce
break
}
command = nil
h.protocolState = respond
case respond:
if response == nil {
return errResponseExpected
}
err = h.protocolHandler.Respond(response, h.framesOut)
if err != nil {
h.logger.WithError(err).Error("could not respond")
h.protocolState = cmdAnnounce
break
}
response = nil
h.protocolState = cmdAnnounce
default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
}
case cmdData:
if err := h.handleCmdData(); err != nil {
return err
}
case respAnnounce:
if err := h.handleRespAnnounce(); err != nil {
return err
}
case respData:
if err := h.handleRespData(); err != nil {
return err
}
default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
}
return nil
}
func (h *Handler) writeToPort(data []byte) error {
reader := bytes.NewReader(data)
n, err := io.Copy(h.port, reader)
if err != nil {
return fmt.Errorf("could not write data: %w", err)
}
h.logger.Tracef("wrote %d bytes", n)
return nil
}
func (h *Handler) readFromPort() ([]byte, error) {
const bufferSize = 1024
buf := make([]byte, bufferSize)
count, err := h.port.Read(buf)
if err != nil {
return nil, fmt.Errorf("could not read from serial port: %w", err)
}
return buf[:count], nil
}
func (h *Handler) handleCmdAnnounce() error {
h.logger.Trace("waiting for command announce")
frame := <-h.framesIn
if frame == nil {
return nil
}
if err := h.protocolHandler.HandleCommandAnnounce(frame); err != nil {
return fmt.Errorf("command announce handling failed: %w", err)
}
if err := h.nextState(); err != nil {
return err
}
return nil
}
func (h *Handler) handleCmdData() error {
h.logger.Trace("waiting for command data")
frame := <-h.framesIn
if frame == nil {
return nil
}
if err := h.protocolHandler.HandleCommand(frame); err != nil {
return fmt.Errorf("command handler failed: %w", err)
}
if err := h.nextState(); err != nil {
return err
}
return nil
}
func (h *Handler) handleRespAnnounce() error {
frame, err := h.protocolHandler.ResponseAnnounce()
if err != nil {
return fmt.Errorf("could not get response announcement: %w", err)
}
if err := h.writeFrame(frame); err != nil {
return err
}
if err := h.nextState(); err != nil {
return err
}
return nil
}
func (h *Handler) handleRespData() error {
frame, err := h.protocolHandler.ResponseData()
if err != nil {
return fmt.Errorf("could not get response data: %w", err)
}
if err := h.writeFrame(frame); err != nil {
return err
}
if err := h.nextState(); err != nil {
return err
}
return nil
}
func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Handler) (*Handler, error) {
func New(
cfg *config.Serial,
logger *logrus.Logger,
protocolHandler protocol.ServerHandler,
) (*Handler, error) {
h := &Handler{
protocolHandler: protocolHandler,
logger: logger,
framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
}
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}

View file

@ -43,7 +43,7 @@ const (
var commandNames = map[CommandCode]string{
CmdHealth: "HEALTH",
CmdFetchCRL: "FETCH URL",
CmdFetchCRL: "FETCH CRL",
}
func (c CommandCode) String() string {

View file

@ -19,7 +19,14 @@ limitations under the License.
package protocol
import (
"bytes"
"errors"
"fmt"
"io"
"time"
"github.com/justincpresley/go-cobs"
"github.com/sirupsen/logrus"
"git.cacert.org/cacert-gosigner/pkg/messages"
)
@ -44,14 +51,130 @@ func (r *Response) String() string {
return fmt.Sprintf("Rsp[announce={%s}, data={%s}]", r.Announce, r.Response)
}
// Handler is responsible for parsing incoming frames and calling commands
type Handler interface {
// HandleCommandAnnounce handles the initial announcement of a command.
HandleCommandAnnounce([]byte) error
// HandleCommand handles the command data.
HandleCommand([]byte) error
// ResponseAnnounce generates the announcement for a response.
ResponseAnnounce() ([]byte, error)
// ResponseData generates the response data.
ResponseData() ([]byte, error)
// ServerHandler is responsible for parsing incoming frames and calling commands
type ServerHandler interface {
// CommandAnnounce handles the initial announcement of a command.
CommandAnnounce(chan []byte) (*Command, error)
// CommandData handles the command data.
CommandData(chan []byte, *Command) error
// HandleCommand executes the command, generating a response.
HandleCommand(*Command) (*Response, error)
// Respond generates the response for a command.
Respond(*Response, chan []byte) error
}
// Framer handles bytes on the wire by adding or removing framing information.
type Framer interface {
// ReadFrames reads data frames and publishes unframed data to the channel.
ReadFrames(io.Reader, chan []byte) error
// WriteFrames takes data from the channel and writes framed data to the writer.
WriteFrames(io.Writer, chan []byte) error
}
const bufferSize = 1024
const readInterval = 50 * time.Millisecond
type COBSFramer struct {
config cobs.Config
logger *logrus.Logger
}
func NewCOBSFramer(logger *logrus.Logger) *COBSFramer {
return &COBSFramer{
config: cobs.Config{SpecialByte: CobsDelimiter, Delimiter: true, EndingSave: true},
logger: logger,
}
}
func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error {
var (
err error
raw, data, frame []byte
)
buffer := &bytes.Buffer{}
for {
raw, err = c.readRaw(reader)
if err != nil {
close(frameChan)
return err
}
if len(raw) == 0 {
time.Sleep(readInterval)
continue
}
c.logger.Tracef("read %d raw bytes", len(raw))
buffer.Write(raw)
for {
data, err = buffer.ReadBytes(c.config.SpecialByte)
if err != nil {
if errors.Is(err, io.EOF) {
buffer.Write(data)
break
}
return fmt.Errorf("could not read from buffer: %w", err)
}
if err = cobs.Verify(data, c.config); err != nil {
c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data))
break
}
frame = cobs.Decode(data, c.config)
c.logger.Tracef("frame decoded to length %d", len(frame))
frameChan <- frame
}
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) {
buf := make([]byte, bufferSize)
count, err := reader.Read(buf)
if err != nil {
if errors.Is(err, io.EOF) {
return []byte{}, nil
}
return nil, fmt.Errorf("could not read data: %w", err)
}
raw := buf[:count]
return raw, nil
}
func (c *COBSFramer) WriteFrames(writer io.Writer, frameChan chan []byte) error {
for {
frame := <-frameChan
if frame == nil {
c.logger.Debug("channel closed")
return nil
}
encoded := cobs.Encode(frame, c.config)
n, err := io.Copy(writer, bytes.NewReader(encoded))
if err != nil {
return fmt.Errorf("cold not write data: %w", err)
}
c.logger.Tracef("wrote %d bytes", n)
}
}