Compare commits

..

No commits in common. 'f429d3da452d3f517b90be0c049dc1c800afec77' and 'faaadbe5aa6ffce8c2561535a4f476688c708c87' have entirely different histories.

@ -20,14 +20,15 @@ limitations under the License.
package main package main
import ( import (
"bytes"
"context" "context"
"crypto/rand"
"fmt" "fmt"
"io" "io"
"os" "os"
"sync" "sync"
"time" "time"
"github.com/justincpresley/go-cobs"
"github.com/shamaton/msgpackgen/msgpack" "github.com/shamaton/msgpackgen/msgpack"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -36,6 +37,8 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/messages"
) )
var cobsConfig = cobs.Config{SpecialByte: protocol.CobsDelimiter, Delimiter: true, EndingSave: true}
type protocolState int8 type protocolState int8
const ( const (
@ -157,27 +160,16 @@ func (g *TestCommandGenerator) HandleResponse(frame []byte) error {
} }
func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error { func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
var (
announce *messages.CommandAnnounce
err error
)
// write some leading garbage to test signer robustness
_, _ = io.CopyN(os.Stdout, rand.Reader, 50) //nolint:gomnd
announce, err = messages.BuildCommandAnnounce(messages.CmdHealth)
if err != nil {
return fmt.Errorf("build command announce failed: %w", err)
}
g.commands <- &protocol.Command{Announce: announce, Command: &messages.HealthCommand{}}
const ( const (
healthInterval = 5 * time.Second healthInterval = 5 * time.Second
crlInterval = 15 * time.Minute
startPause = 3 * time.Second startPause = 3 * time.Second
) )
var (
announce *messages.CommandAnnounce
err error
)
g.logger.Info("start generating commands") g.logger.Info("start generating commands")
time.Sleep(startPause) time.Sleep(startPause)
@ -192,18 +184,17 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
Command: &messages.FetchCRLCommand{IssuerID: "sub-ecc_person_2022"}, Command: &messages.FetchCRLCommand{IssuerID: "sub-ecc_person_2022"},
} }
healthTimer := time.NewTimer(healthInterval) timer := time.NewTimer(healthInterval)
crlTimer := time.NewTimer(crlInterval)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
_ = healthTimer.Stop() _ = timer.Stop()
g.logger.Info("stopping health check loop") g.logger.Info("stopping health check loop")
return nil return nil
case <-healthTimer.C: case <-timer.C:
announce, err = messages.BuildCommandAnnounce(messages.CmdHealth) announce, err = messages.BuildCommandAnnounce(messages.CmdHealth)
if err != nil { if err != nil {
return fmt.Errorf("build command announce failed: %w", err) return fmt.Errorf("build command announce failed: %w", err)
@ -213,14 +204,9 @@ func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
Announce: announce, Announce: announce,
Command: &messages.HealthCommand{}, Command: &messages.HealthCommand{},
} }
healthTimer.Reset(healthInterval)
case <-crlTimer.C:
g.commands <- &protocol.Command{
Announce: announce,
Command: &messages.FetchCRLCommand{IssuerID: "sub-ecc_person_2022"},
}
} }
timer.Reset(healthInterval)
} }
} }
@ -229,11 +215,106 @@ type clientSimulator struct {
logger *logrus.Logger logger *logrus.Logger
lock sync.Mutex lock sync.Mutex
framesIn chan []byte framesIn chan []byte
framesOut chan []byte
framer protocol.Framer
commandGenerator *TestCommandGenerator commandGenerator *TestCommandGenerator
} }
func (c *clientSimulator) readFrames() error {
const readInterval = 50 * time.Millisecond
var frame []byte
buffer := &bytes.Buffer{}
delimiter := []byte{cobsConfig.SpecialByte}
for {
readBytes, err := c.readFromStdin()
if err != nil {
c.logger.WithError(err).Error("stdin read error")
close(c.framesIn)
return err
}
if len(readBytes) == 0 {
time.Sleep(readInterval)
continue
}
c.logger.Tracef("read %d bytes", len(readBytes))
buffer.Write(readBytes)
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
rest := buffer.Bytes()
if !bytes.Contains(rest, delimiter) {
c.logger.Tracef("read data does not contain the delimiter %x", delimiter)
continue
}
for bytes.Contains(rest, delimiter) {
parts := bytes.SplitAfterN(rest, delimiter, 2)
frame, rest = parts[0], parts[1]
c.logger.Tracef("frame of length %d", len(frame))
if len(frame) == 0 {
continue
}
err = cobs.Verify(frame, cobsConfig)
if err != nil {
return fmt.Errorf("frame verification failed: %w", err)
}
decoded := cobs.Decode(frame, cobsConfig)
c.logger.Tracef("frame decoded to length %d", len(decoded))
c.framesIn <- decoded
c.logger.Tracef("%d bytes remaining", len(rest))
}
buffer.Truncate(0)
buffer.Write(rest)
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (c *clientSimulator) writeFrame(frame []byte) error {
encoded := cobs.Encode(frame, cobsConfig)
if _, err := io.Copy(os.Stdout, bytes.NewBuffer(encoded)); err != nil {
return fmt.Errorf("could not write data: %w", err)
}
return nil
}
func (c *clientSimulator) readFromStdin() ([]byte, error) {
const bufferSize = 1024
buf := make([]byte, bufferSize)
c.logger.Trace("waiting for input")
count, err := os.Stdin.Read(buf)
if err != nil {
return nil, fmt.Errorf("reading input failed: %w", err)
}
c.logger.Tracef("read %d bytes from stdin", count)
return buf[:count], nil
}
func (c *clientSimulator) writeCmdAnnouncement() error { func (c *clientSimulator) writeCmdAnnouncement() error {
frame, err := c.commandGenerator.CmdAnnouncement() frame, err := c.commandGenerator.CmdAnnouncement()
if err != nil { if err != nil {
@ -242,7 +323,9 @@ func (c *clientSimulator) writeCmdAnnouncement() error {
c.logger.Trace("writing command announcement") c.logger.Trace("writing command announcement")
c.framesOut <- frame if err := c.writeFrame(frame); err != nil {
return err
}
if err := c.nextState(); err != nil { if err := c.nextState(); err != nil {
return err return err
@ -259,7 +342,9 @@ func (c *clientSimulator) writeCommand() error {
c.logger.Trace("writing command data") c.logger.Trace("writing command data")
c.framesOut <- frame if err := c.writeFrame(frame); err != nil {
return err
}
if err := c.nextState(); err != nil { if err := c.nextState(); err != nil {
return err return err
@ -268,31 +353,21 @@ func (c *clientSimulator) writeCommand() error {
return nil return nil
} }
const responseAnnounceTimeout = 30 * time.Second
const responseDataTimeout = 2 * time.Second
func (c *clientSimulator) handleResponseAnnounce() error { func (c *clientSimulator) handleResponseAnnounce() error {
c.logger.Trace("waiting for response announce") c.logger.Trace("waiting for response announce")
select { frame := <-c.framesIn
case frame := <-c.framesIn:
if frame == nil {
return nil
}
if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil {
return fmt.Errorf("response announce handling failed: %w", err)
}
if err := c.nextState(); err != nil { if frame == nil {
return err return nil
} }
case <-time.After(responseAnnounceTimeout):
c.logger.Warn("response announce timeout expired")
c.protocolState = cmdAnnounce if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil {
return fmt.Errorf("response announce handling failed: %w", err)
}
return nil if err := c.nextState(); err != nil {
return err
} }
return nil return nil
@ -301,28 +376,21 @@ func (c *clientSimulator) handleResponseAnnounce() error {
func (c *clientSimulator) handleResponseData() error { func (c *clientSimulator) handleResponseData() error {
c.logger.Trace("waiting for response data") c.logger.Trace("waiting for response data")
select { frame := <-c.framesIn
case frame := <-c.framesIn:
if frame == nil {
return nil
}
if err := c.commandGenerator.HandleResponse(frame); err != nil {
return fmt.Errorf("response handler failed: %w", err)
}
if err := c.nextState(); err != nil {
return err
}
if frame == nil {
return nil return nil
case <-time.After(responseDataTimeout): }
c.logger.Warn("response data timeout expired")
c.protocolState = cmdAnnounce if err := c.commandGenerator.HandleResponse(frame); err != nil {
return fmt.Errorf("response handler failed: %w", err)
}
return nil if err := c.nextState(); err != nil {
return err
} }
return nil
} }
func (c *clientSimulator) Run(ctx context.Context) error { func (c *clientSimulator) Run(ctx context.Context) error {
@ -330,13 +398,7 @@ func (c *clientSimulator) Run(ctx context.Context) error {
errors := make(chan error) errors := make(chan error)
go func() { go func() {
err := c.framer.ReadFrames(os.Stdin, c.framesIn) err := c.readFrames()
errors <- err
}()
go func() {
err := c.framer.WriteFrames(os.Stdout, c.framesOut)
errors <- err errors <- err
}() }()
@ -418,10 +480,8 @@ func main() {
logger: logger, logger: logger,
commands: make(chan *protocol.Command), commands: make(chan *protocol.Command),
}, },
logger: logger, logger: logger,
framesIn: make(chan []byte), framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
} }
err := sim.Run(context.Background()) err := sim.Run(context.Background())

@ -21,7 +21,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"time"
"github.com/shamaton/msgpackgen/msgpack" "github.com/shamaton/msgpackgen/msgpack"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -33,186 +32,185 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/messages"
) )
const readCommandTimeOut = 5 * time.Second // MsgPackHandler is a Handler implementation for the msgpack serialization format.
var errReadCommandTimeout = errors.New("read command timeout expired")
// MsgPackHandler is a ServerHandler implementation for the msgpack serialization format.
type MsgPackHandler struct { type MsgPackHandler struct {
logger *logrus.Logger logger *logrus.Logger
healthHandler *health.Handler healthHandler *health.Handler
fetchCRLHandler *revoking.FetchCRLHandler fetchCRLHandler *revoking.FetchCRLHandler
currentCommand *protocol.Command
currentResponse *protocol.Response
lock sync.Mutex lock sync.Mutex
} }
func (m *MsgPackHandler) CommandAnnounce(frames chan []byte) (*protocol.Command, error) { func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) error {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()
frame := <-frames
var ann messages.CommandAnnounce var ann messages.CommandAnnounce
if err := msgpack.Unmarshal(frame, &ann); err != nil { if err := msgpack.Unmarshal(frame, &ann); err != nil {
return nil, fmt.Errorf("could not unmarshal command announcement: %w", err) return fmt.Errorf("could not unmarshal command announcement: %w", err)
} }
m.logger.WithField("announcement", &ann).Info("received command announcement") m.logger.WithField("announcement", &ann).Info("received command announcement")
return &protocol.Command{Announce: &ann}, nil m.currentCommand = &protocol.Command{Announce: &ann}
return nil
} }
func (m *MsgPackHandler) CommandData(frames chan []byte, command *protocol.Command) error { func (m *MsgPackHandler) HandleCommand(frame []byte) error {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()
select { err := m.parseCommand(frame)
case frame := <-frames: if err != nil {
err := m.parseCommand(frame, command) m.currentResponse = m.buildErrorResponse(err.Error())
if err != nil {
return err m.logCommandResponse()
}
return nil return nil
case <-time.After(readCommandTimeOut):
return errReadCommandTimeout
} }
}
func (m *MsgPackHandler) HandleCommand(command *protocol.Command) (*protocol.Response, error) { err = m.handleCommand()
m.lock.Lock()
defer m.lock.Unlock()
var (
response *protocol.Response
err error
)
response, err = m.handleCommand(command)
if err != nil { if err != nil {
m.logger.WithError(err).Error("command handling failed") m.logger.WithError(err).Error("command handling failed")
response = m.buildErrorResponse(command.Announce.ID, "command handling failed") return err
} }
m.logCommandResponse(command, response) m.logCommandResponse()
return response, nil m.currentCommand = nil
return nil
} }
func (m *MsgPackHandler) logCommandResponse(command *protocol.Command, response *protocol.Response) { func (m *MsgPackHandler) logCommandResponse() {
m.logger.WithField("command", command.Announce).Info("handled command") m.logger.WithField("command", m.currentCommand.Announce).Info("handled command")
m.logger.WithField("command", command).WithField("response", response).Debug("command and response") m.logger.WithField(
"command",
m.currentCommand,
).WithField(
"response",
m.currentResponse,
).Debug("command and response")
} }
func (m *MsgPackHandler) Respond(response *protocol.Response, out chan []byte) error { func (m *MsgPackHandler) ResponseAnnounce() ([]byte, error) {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()
announce, err := msgpack.Marshal(response) announceData, err := msgpack.Marshal(m.currentResponse.Announce)
if err != nil { if err != nil {
return fmt.Errorf("could not marshal response announcement: %w", err) return nil, fmt.Errorf("could not marshal response announcement: %w", err)
} }
m.logger.WithField("length", len(announce)).Debug("write response announcement") m.logger.WithField("announcement", m.currentResponse.Announce).Debug("write response announcement")
return announceData, nil
}
out <- announce func (m *MsgPackHandler) ResponseData() ([]byte, error) {
m.lock.Lock()
defer m.lock.Unlock()
data, err := msgpack.Marshal(response.Response) responseData, err := msgpack.Marshal(m.currentResponse.Response)
if err != nil { if err != nil {
return fmt.Errorf("could not marshal response: %w", err) return nil, fmt.Errorf("could not marshal response: %w", err)
} }
m.logger.WithField("length", len(data)).Debug("write response") m.logger.WithField("response", m.currentResponse.Response).Debug("write response")
out <- announce return responseData, nil
return nil
} }
func (m *MsgPackHandler) parseHealthCommand(frame []byte) (*messages.HealthCommand, error) { func (m *MsgPackHandler) parseHealthCommand(frame []byte) error {
var command messages.HealthCommand var command messages.HealthCommand
if err := msgpack.Unmarshal(frame, &command); err != nil { if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed") m.logger.WithError(err).Error("unmarshal failed")
return nil, errors.New("could not unmarshal health command") return errors.New("could not unmarshal health command")
} }
return &command, nil m.currentCommand.Command = &command
return nil
} }
func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) (*messages.FetchCRLCommand, error) { func (m *MsgPackHandler) parseFetchCRLCommand(frame []byte) error {
var command messages.FetchCRLCommand var command messages.FetchCRLCommand
if err := msgpack.Unmarshal(frame, &command); err != nil { if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed") m.logger.WithError(err).Error("unmarshal failed")
return nil, errors.New("could not unmarshal fetch crl command") return errors.New("could not unmarshal fetch crl command")
} }
return &command, nil m.currentCommand.Command = &command
return nil
}
func (m *MsgPackHandler) currentID() string {
return m.currentCommand.Announce.ID
} }
func (m *MsgPackHandler) handleCommand(command *protocol.Command) (*protocol.Response, error) { func (m *MsgPackHandler) handleCommand() error {
var ( var (
responseCode messages.ResponseCode err error
responseData interface{} responseData interface{}
responseCode messages.ResponseCode
) )
switch cmd := command.Command.(type) { switch m.currentCommand.Command.(type) {
case *messages.HealthCommand: case *messages.HealthCommand:
response, err := m.handleHealthCommand() response, err := m.handleHealthCommand()
if err != nil { if err != nil {
return nil, err return err
} }
responseCode, responseData = messages.RespHealth, response responseCode, responseData = messages.RespHealth, response
case *messages.FetchCRLCommand: case *messages.FetchCRLCommand:
response, err := m.handleFetchCRLCommand(cmd) response, err := m.handleFetchCRLCommand()
if err != nil { if err != nil {
return nil, err return err
} }
responseCode, responseData = messages.RespFetchCRL, response responseCode, responseData = messages.RespFetchCRL, response
default: default:
return nil, fmt.Errorf("unhandled command %s", command.Announce) return fmt.Errorf("unhandled command %s", m.currentCommand.Announce)
} }
return &protocol.Response{ if err != nil {
Announce: messages.BuildResponseAnnounce(responseCode, command.Announce.ID), return fmt.Errorf("error from command handler: %w", err)
}
m.currentResponse = &protocol.Response{
Announce: messages.BuildResponseAnnounce(responseCode, m.currentID()),
Response: responseData, Response: responseData,
}, nil }
return nil
} }
func (m *MsgPackHandler) buildErrorResponse(commandID string, errMsg string) *protocol.Response { func (m *MsgPackHandler) buildErrorResponse(errMsg string) *protocol.Response {
return &protocol.Response{ return &protocol.Response{
Announce: messages.BuildResponseAnnounce(messages.RespError, commandID), Announce: messages.BuildResponseAnnounce(messages.RespError, m.currentID()),
Response: &messages.ErrorResponse{Message: errMsg}, Response: &messages.ErrorResponse{Message: errMsg},
} }
} }
func (m *MsgPackHandler) parseCommand(frame []byte, command *protocol.Command) error { func (m *MsgPackHandler) parseCommand(frame []byte) error {
switch command.Announce.Code { switch m.currentCommand.Announce.Code {
case messages.CmdHealth: case messages.CmdHealth:
healthCommand, err := m.parseHealthCommand(frame) return m.parseHealthCommand(frame)
if err != nil {
return err
}
command.Command = healthCommand
case messages.CmdFetchCRL: case messages.CmdFetchCRL:
fetchCRLCommand, err := m.parseFetchCRLCommand(frame) return m.parseFetchCRLCommand(frame)
if err != nil {
return err
}
command.Command = fetchCRLCommand
default: default:
return fmt.Errorf("unhandled command code %s", command.Announce.Code) return fmt.Errorf("unhandled command code %s", m.currentCommand.Announce.Code)
} }
return nil
} }
func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error) { func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error) {
@ -237,20 +235,27 @@ func (m *MsgPackHandler) handleHealthCommand() (*messages.HealthResponse, error)
return response, nil return response, nil
} }
func (m *MsgPackHandler) handleFetchCRLCommand(command *messages.FetchCRLCommand) (*messages.FetchCRLResponse, error) { func (m *MsgPackHandler) handleFetchCRLCommand() (*messages.FetchCRLResponse, error) {
res, err := m.fetchCRLHandler.FetchCRL(command.IssuerID) fetchCRLPayload, ok := m.currentCommand.Command.(*messages.FetchCRLCommand)
if !ok {
return nil, fmt.Errorf("could not use payload as FetchCRLPayload")
}
res, err := m.fetchCRLHandler.FetchCRL(fetchCRLPayload.IssuerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not fetch CRL: %w", err) return nil, fmt.Errorf("could not fetch CRL: %w", err)
} }
return &messages.FetchCRLResponse{ response := &messages.FetchCRLResponse{
IsDelta: false, IsDelta: false,
CRLNumber: res.Number, CRLNumber: res.Number,
CRLData: res.CRLData, CRLData: res.CRLData,
}, nil }
return response, nil
} }
func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.ServerHandler, error) { func New(logger *logrus.Logger, handlers ...RegisterHandler) (protocol.Handler, error) {
messages.RegisterGeneratedResolver() messages.RegisterGeneratedResolver()
h := &MsgPackHandler{ h := &MsgPackHandler{

@ -15,14 +15,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
// Package serial provides a handler for the serial connection of the signer machine. // Package seriallink provides a handler for the serial connection of the signer machine.
package serial package serial
import ( import (
"bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io"
"sync"
"time"
"github.com/justincpresley/go-cobs"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tarm/serial" "github.com/tarm/serial"
@ -35,15 +39,22 @@ type protocolState int8
const ( const (
cmdAnnounce protocolState = iota cmdAnnounce protocolState = iota
cmdData cmdData
handleCommand respAnnounce
respond respData
) )
var validTransitions = map[protocolState]protocolState{
cmdAnnounce: cmdData,
cmdData: respAnnounce,
respAnnounce: respData,
respData: cmdAnnounce,
}
var protocolStateNames = map[protocolState]string{ var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE", cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA", cmdData: "CMD DATA",
handleCommand: "RESP ANNOUNCE", respAnnounce: "RESP ANNOUNCE",
respond: "RESP DATA", respData: "RESP DATA",
} }
func (p protocolState) String() string { func (p protocolState) String() string {
@ -55,14 +66,13 @@ func (p protocolState) String() string {
} }
type Handler struct { type Handler struct {
protocolHandler protocol.ServerHandler protocolHandler protocol.Handler
protocolState protocolState protocolState protocolState
framer protocol.Framer
config *serial.Config config *serial.Config
port *serial.Port port *serial.Port
logger *logrus.Logger logger *logrus.Logger
lock sync.Mutex
framesIn chan []byte framesIn chan []byte
framesOut chan []byte
} }
func (h *Handler) setupConnection() error { func (h *Handler) setupConnection() error {
@ -85,137 +95,256 @@ func (h *Handler) Close() error {
return nil return nil
} }
var cobsConfig = cobs.Config{SpecialByte: protocol.CobsDelimiter, Delimiter: true, EndingSave: true}
func (h *Handler) Run(ctx context.Context) error { func (h *Handler) Run(ctx context.Context) error {
h.protocolState = cmdAnnounce h.protocolState = cmdAnnounce
protocolErrors, framerErrors := make(chan error), make(chan error) errors := make(chan error)
go func() {
err := h.framer.ReadFrames(h.port, h.framesIn)
framerErrors <- err
}()
go func() { go func() {
err := h.framer.WriteFrames(h.port, h.framesOut) err := h.readFrames()
framerErrors <- err errors <- err
}()
go func() {
err := h.handleProtocolState()
protocolErrors <- err
}() }()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case err := <-framerErrors:
if err != nil {
return fmt.Errorf("error from framer: %w", err)
}
return nil case err := <-errors:
case err := <-protocolErrors:
if err != nil { if err != nil {
return fmt.Errorf("error from protocol handler: %w", err) return fmt.Errorf("error from handler loop: %w", err)
} }
return nil return nil
default:
if err := h.handleProtocolState(); err != nil {
return err
}
} }
} }
} }
var errCommandExpected = errors.New("command must not be nil") func (h *Handler) readFrames() error {
var errResponseExpected = errors.New("response must not be nil") const (
readInterval = 50 * time.Millisecond
func (h *Handler) handleProtocolState() error {
var (
command *protocol.Command
response *protocol.Response
err error
) )
var frame []byte
buffer := &bytes.Buffer{}
delimiter := []byte{cobsConfig.SpecialByte}
for { for {
h.logger.Debugf("handling protocol state %s", h.protocolState) readBytes, err := h.readFromPort()
if err != nil {
close(h.framesIn)
switch h.protocolState { return err
case cmdAnnounce: }
command, err = h.protocolHandler.CommandAnnounce(h.framesIn)
if err != nil {
h.logger.WithError(err).Error("could not handle command announce")
break if len(readBytes) == 0 {
} time.Sleep(readInterval)
h.protocolState = cmdData continue
case cmdData: }
if command == nil {
return errCommandExpected
}
err = h.protocolHandler.CommandData(h.framesIn, command) h.logger.Tracef("read %d bytes", len(readBytes))
if err != nil {
h.logger.WithError(err).Error("could not handle command data")
h.protocolState = cmdAnnounce buffer.Write(readBytes)
break h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
h.protocolState = handleCommand rest := buffer.Bytes()
case handleCommand:
if command == nil {
return errCommandExpected
}
response, err = h.protocolHandler.HandleCommand(command) if !bytes.Contains(rest, delimiter) {
if err != nil { continue
h.logger.WithError(err).Error("could not handle command") }
for bytes.Contains(rest, delimiter) {
parts := bytes.SplitAfterN(rest, delimiter, 2)
frame, rest = parts[0], parts[1]
h.protocolState = cmdAnnounce h.logger.Tracef("frame of length %d", len(frame))
break if len(frame) == 0 {
continue
} }
command = nil if err := cobs.Verify(frame, cobsConfig); err != nil {
close(h.framesIn)
h.protocolState = respond return fmt.Errorf("could not verify COBS frame: %w", err)
case respond:
if response == nil {
return errResponseExpected
} }
err = h.protocolHandler.Respond(response, h.framesOut) decoded := cobs.Decode(frame, cobsConfig)
if err != nil {
h.logger.WithError(err).Error("could not respond")
h.protocolState = cmdAnnounce h.logger.Tracef("frame decoded to length %d", len(decoded))
break h.framesIn <- decoded
} }
response = nil buffer.Truncate(0)
buffer.Write(rest)
h.protocolState = cmdAnnounce h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
default: }
return fmt.Errorf("unknown protocol state %s", h.protocolState) }
func (h *Handler) writeFrame(frame []byte) error {
encoded := cobs.Encode(frame, cobsConfig)
return h.writeToPort(encoded)
}
func (h *Handler) nextState() error {
next, ok := validTransitions[h.protocolState]
if !ok {
return fmt.Errorf("illegal protocol state %s", h.protocolState)
}
h.protocolState = next
return nil
}
func (h *Handler) handleProtocolState() error {
h.logger.Tracef("handling protocol state %s", h.protocolState)
h.lock.Lock()
defer h.lock.Unlock()
switch h.protocolState {
case cmdAnnounce:
if err := h.handleCmdAnnounce(); err != nil {
return err
}
case cmdData:
if err := h.handleCmdData(); err != nil {
return err
}
case respAnnounce:
if err := h.handleRespAnnounce(); err != nil {
return err
} }
case respData:
if err := h.handleRespData(); err != nil {
return err
}
default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
} }
return nil
}
func (h *Handler) writeToPort(data []byte) error {
reader := bytes.NewReader(data)
n, err := io.Copy(h.port, reader)
if err != nil {
return fmt.Errorf("could not write data: %w", err)
}
h.logger.Tracef("wrote %d bytes", n)
return nil
}
func (h *Handler) readFromPort() ([]byte, error) {
const bufferSize = 1024
buf := make([]byte, bufferSize)
count, err := h.port.Read(buf)
if err != nil {
return nil, fmt.Errorf("could not read from serial port: %w", err)
}
return buf[:count], nil
}
func (h *Handler) handleCmdAnnounce() error {
h.logger.Trace("waiting for command announce")
frame := <-h.framesIn
if frame == nil {
return nil
}
if err := h.protocolHandler.HandleCommandAnnounce(frame); err != nil {
return fmt.Errorf("command announce handling failed: %w", err)
}
if err := h.nextState(); err != nil {
return err
}
return nil
}
func (h *Handler) handleCmdData() error {
h.logger.Trace("waiting for command data")
frame := <-h.framesIn
if frame == nil {
return nil
}
if err := h.protocolHandler.HandleCommand(frame); err != nil {
return fmt.Errorf("command handler failed: %w", err)
}
if err := h.nextState(); err != nil {
return err
}
return nil
}
func (h *Handler) handleRespAnnounce() error {
frame, err := h.protocolHandler.ResponseAnnounce()
if err != nil {
return fmt.Errorf("could not get response announcement: %w", err)
}
if err := h.writeFrame(frame); err != nil {
return err
}
if err := h.nextState(); err != nil {
return err
}
return nil
}
func (h *Handler) handleRespData() error {
frame, err := h.protocolHandler.ResponseData()
if err != nil {
return fmt.Errorf("could not get response data: %w", err)
}
if err := h.writeFrame(frame); err != nil {
return err
}
if err := h.nextState(); err != nil {
return err
}
return nil
} }
func New( func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Handler) (*Handler, error) {
cfg *config.Serial,
logger *logrus.Logger,
protocolHandler protocol.ServerHandler,
) (*Handler, error) {
h := &Handler{ h := &Handler{
protocolHandler: protocolHandler, protocolHandler: protocolHandler,
logger: logger, logger: logger,
framesIn: make(chan []byte), framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
} }
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout} h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}

@ -43,7 +43,7 @@ const (
var commandNames = map[CommandCode]string{ var commandNames = map[CommandCode]string{
CmdHealth: "HEALTH", CmdHealth: "HEALTH",
CmdFetchCRL: "FETCH CRL", CmdFetchCRL: "FETCH URL",
} }
func (c CommandCode) String() string { func (c CommandCode) String() string {

@ -19,14 +19,7 @@ limitations under the License.
package protocol package protocol
import ( import (
"bytes"
"errors"
"fmt" "fmt"
"io"
"time"
"github.com/justincpresley/go-cobs"
"github.com/sirupsen/logrus"
"git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/messages"
) )
@ -51,130 +44,14 @@ func (r *Response) String() string {
return fmt.Sprintf("Rsp[announce={%s}, data={%s}]", r.Announce, r.Response) return fmt.Sprintf("Rsp[announce={%s}, data={%s}]", r.Announce, r.Response)
} }
// ServerHandler is responsible for parsing incoming frames and calling commands // Handler is responsible for parsing incoming frames and calling commands
type ServerHandler interface { type Handler interface {
// CommandAnnounce handles the initial announcement of a command. // HandleCommandAnnounce handles the initial announcement of a command.
CommandAnnounce(chan []byte) (*Command, error) HandleCommandAnnounce([]byte) error
// CommandData handles the command data. // HandleCommand handles the command data.
CommandData(chan []byte, *Command) error HandleCommand([]byte) error
// HandleCommand executes the command, generating a response. // ResponseAnnounce generates the announcement for a response.
HandleCommand(*Command) (*Response, error) ResponseAnnounce() ([]byte, error)
// Respond generates the response for a command. // ResponseData generates the response data.
Respond(*Response, chan []byte) error ResponseData() ([]byte, error)
}
// Framer handles bytes on the wire by adding or removing framing information.
type Framer interface {
// ReadFrames reads data frames and publishes unframed data to the channel.
ReadFrames(io.Reader, chan []byte) error
// WriteFrames takes data from the channel and writes framed data to the writer.
WriteFrames(io.Writer, chan []byte) error
}
const bufferSize = 1024
const readInterval = 50 * time.Millisecond
type COBSFramer struct {
config cobs.Config
logger *logrus.Logger
}
func NewCOBSFramer(logger *logrus.Logger) *COBSFramer {
return &COBSFramer{
config: cobs.Config{SpecialByte: CobsDelimiter, Delimiter: true, EndingSave: true},
logger: logger,
}
}
func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error {
var (
err error
raw, data, frame []byte
)
buffer := &bytes.Buffer{}
for {
raw, err = c.readRaw(reader)
if err != nil {
close(frameChan)
return err
}
if len(raw) == 0 {
time.Sleep(readInterval)
continue
}
c.logger.Tracef("read %d raw bytes", len(raw))
buffer.Write(raw)
for {
data, err = buffer.ReadBytes(c.config.SpecialByte)
if err != nil {
if errors.Is(err, io.EOF) {
buffer.Write(data)
break
}
return fmt.Errorf("could not read from buffer: %w", err)
}
if err = cobs.Verify(data, c.config); err != nil {
c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data))
break
}
frame = cobs.Decode(data, c.config)
c.logger.Tracef("frame decoded to length %d", len(frame))
frameChan <- frame
}
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) {
buf := make([]byte, bufferSize)
count, err := reader.Read(buf)
if err != nil {
if errors.Is(err, io.EOF) {
return []byte{}, nil
}
return nil, fmt.Errorf("could not read data: %w", err)
}
raw := buf[:count]
return raw, nil
}
func (c *COBSFramer) WriteFrames(writer io.Writer, frameChan chan []byte) error {
for {
frame := <-frameChan
if frame == nil {
c.logger.Debug("channel closed")
return nil
}
encoded := cobs.Encode(frame, c.config)
n, err := io.Copy(writer, bytes.NewReader(encoded))
if err != nil {
return fmt.Errorf("cold not write data: %w", err)
}
c.logger.Tracef("wrote %d bytes", n)
}
} }

Loading…
Cancel
Save