Implement protocol improvements

This commit implements a client and server side state machine
for the serial protocol.
This commit is contained in:
Jan Dittberner 2022-11-21 08:26:50 +01:00
parent 2de592d30c
commit 8e443bd8b4
5 changed files with 766 additions and 388 deletions

View file

@ -23,6 +23,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"os"
"sync"
"time"
@ -36,39 +37,135 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages"
)
const cobsDelimiter = 0x00
var cobsConfig = cobs.Config{SpecialByte: 0x00, Delimiter: true, EndingSave: true}
var cobsConfig = cobs.Config{SpecialByte: cobsDelimiter, Delimiter: true, EndingSave: true}
type protocolState int8
func main() {
logger := logrus.New()
logger.SetOutput(os.Stderr)
logger.SetLevel(logrus.InfoLevel)
const (
cmdAnnounce protocolState = iota
cmdData
respAnnounce
respData
)
sim := &clientSimulator{
logger: logger,
var validTransitions = map[protocolState]protocolState{
cmdAnnounce: cmdData,
cmdData: respAnnounce,
respAnnounce: respData,
respData: cmdAnnounce,
}
var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
respAnnounce: "RESP ANNOUNCE",
respData: "RESP DATA",
}
func (p protocolState) String() string {
if name, ok := protocolStateNames[p]; ok {
return name
}
err := sim.Run()
return fmt.Sprintf("unknown %d", p)
}
type TestCommandGenerator struct {
logger *logrus.Logger
currentCommand *protocol.Command
currentResponse *protocol.Response
commands chan *protocol.Command
lock sync.Mutex
}
func (g *TestCommandGenerator) CmdAnnouncement() ([]byte, error) {
g.lock.Lock()
defer g.lock.Unlock()
select {
case g.currentCommand = <-g.commands:
announceData, err := msgpack.Marshal(g.currentCommand.Announce)
if err != nil {
return nil, fmt.Errorf("could not marshal command annoucement: %w", err)
}
g.logger.WithField("announcement", &g.currentCommand.Announce).Info("write command announcement")
return announceData, nil
}
}
func (g *TestCommandGenerator) CmdData() ([]byte, error) {
g.lock.Lock()
defer g.lock.Unlock()
cmdData, err := msgpack.Marshal(g.currentCommand.Command)
if err != nil {
logger.WithError(err).Error("simulator returned an error")
return nil, fmt.Errorf("could not marshal command data: %w", err)
}
g.logger.WithField("command", &g.currentCommand.Command).Info("write command data")
return cmdData, nil
}
type clientSimulator struct {
logger *logrus.Logger
commands chan *protocol.Command
responses chan [][]byte
func (g *TestCommandGenerator) HandleResponseAnnounce(frame []byte) error {
g.lock.Lock()
defer g.lock.Unlock()
var ann messages.ResponseAnnounce
if err := msgpack.Unmarshal(frame, &ann); err != nil {
return fmt.Errorf("could not unmarshal response announcement")
}
g.logger.WithField("announcement", &ann).Info("received response announcement")
g.currentResponse = &protocol.Response{Announce: &ann}
return nil
}
func (c *clientSimulator) writeTestCommands(ctx context.Context) error {
messages.RegisterGeneratedResolver()
func (g *TestCommandGenerator) HandleResponse(frame []byte) error {
g.lock.Lock()
defer g.lock.Unlock()
switch g.currentResponse.Announce.Code {
case messages.RespHealth:
var response messages.HealthResponse
if err := msgpack.Unmarshal(frame, &response); err != nil {
return fmt.Errorf("unmarshal failed: %w", err)
}
g.currentResponse.Response = response
case messages.RespFetchCRL:
var response messages.FetchCRLResponse
if err := msgpack.Unmarshal(frame, &response); err != nil {
return fmt.Errorf("unmarshal failed: %w", err)
}
}
g.logger.WithField(
"command",
g.currentCommand,
).WithField(
"response",
g.currentResponse,
).Info("handled health response")
return nil
}
func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
const healthInterval = 10 * time.Second
g.logger.Info("start generating commands")
time.Sleep(healthInterval)
c.commands <- &protocol.Command{
g.commands <- &protocol.Command{
Announce: messages.BuildCommandAnnounce(messages.CmdFetchCRL),
Command: &messages.FetchCRLCommand{IssuerID: "sub-ecc_person_2022"},
}
@ -80,11 +177,11 @@ func (c *clientSimulator) writeTestCommands(ctx context.Context) error {
case <-ctx.Done():
_ = timer.Stop()
c.logger.Info("stopping health check loop")
g.logger.Info("stopping health check loop")
return nil
case <-timer.C:
c.commands <- &protocol.Command{
g.commands <- &protocol.Command{
Announce: messages.BuildCommandAnnounce(messages.CmdHealth),
Command: &messages.HealthCommand{},
}
@ -94,218 +191,292 @@ func (c *clientSimulator) writeTestCommands(ctx context.Context) error {
}
}
func (c *clientSimulator) handleInput(ctx context.Context) error {
const (
bufferSize = 1024 * 1024
readInterval = 50 * time.Millisecond
)
type clientSimulator struct {
protocolState protocolState
logger *logrus.Logger
lock sync.Mutex
framesIn chan []byte
commandGenerator *TestCommandGenerator
}
buf := make([]byte, bufferSize)
func (c *clientSimulator) readFrames() error {
const readInterval = 50 * time.Millisecond
type protocolState int8
var frame []byte
buffer := &bytes.Buffer{}
delimiter := []byte{cobsConfig.SpecialByte}
const (
stAnn protocolState = iota
stResp
)
state := stAnn
var announce []byte
reading:
for {
select {
case <-ctx.Done():
return nil
default:
count, err := os.Stdin.Read(buf)
if err != nil {
return fmt.Errorf("reading input failed: %w", err)
}
readBytes, err := c.readFromStdin()
if err != nil {
c.logger.WithError(err).Error("stdin read error")
if count == 0 {
time.Sleep(readInterval)
close(c.framesIn)
return err
}
if len(readBytes) == 0 {
time.Sleep(readInterval)
continue
}
c.logger.Tracef("read %d bytes", len(readBytes))
buffer.Write(readBytes)
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
rest := buffer.Bytes()
if !bytes.Contains(rest, delimiter) {
continue
}
for bytes.Contains(rest, delimiter) {
parts := bytes.SplitAfterN(rest, delimiter, 2)
frame, rest = parts[0], parts[1]
c.logger.Tracef("frame of length %d", len(frame))
if len(frame) == 0 {
continue
}
data := buf[:count]
for _, frame := range bytes.SplitAfter(data, []byte{cobsConfig.SpecialByte}) {
if len(frame) == 0 {
continue reading
}
err = cobs.Verify(frame, cobsConfig)
if err != nil {
return fmt.Errorf("frame verification failed: %w", err)
}
if state == stAnn {
announce = cobs.Decode(frame, cobsConfig)
state = stResp
} else {
c.responses <- [][]byte{announce, cobs.Decode(frame, cobsConfig)}
state = stAnn
}
err = cobs.Verify(frame, cobsConfig)
if err != nil {
return fmt.Errorf("frame verification failed: %w", err)
}
decoded := cobs.Decode(frame, cobsConfig)
c.logger.Tracef("frame decoded to length %d", len(decoded))
c.framesIn <- decoded
}
buffer.Truncate(0)
buffer.Write(rest)
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (c *clientSimulator) handleCommands(ctx context.Context) error {
for {
select {
case command := <-c.commands:
if err := writeCommandAnnouncement(command); err != nil {
return err
}
func (c *clientSimulator) writeFrame(frame []byte) error {
encoded := cobs.Encode(frame, cobsConfig)
if err := writeCommand(command); err != nil {
return err
}
c.lock.Lock()
defer c.lock.Unlock()
respData := <-c.responses
if _, err := io.Copy(os.Stdout, bytes.NewBuffer(encoded)); err != nil {
return fmt.Errorf("could not write data: %w", err)
}
c.logger.WithField("respdata", respData).Trace("read response data")
return nil
}
response := &protocol.Response{}
func (c *clientSimulator) readFromStdin() ([]byte, error) {
const bufferSize = 1024
if err := msgpack.Unmarshal(respData[0], &response.Announce); err != nil {
return fmt.Errorf("could not unmarshal response announcement: %w", err)
}
buf := make([]byte, bufferSize)
if err := c.handleResponse(response, respData[1]); err != nil {
return err
}
count, err := os.Stdin.Read(buf)
if err != nil {
return nil, fmt.Errorf("reading input failed: %w", err)
}
case <-ctx.Done():
return buf[:count], nil
}
func (c *clientSimulator) writeCmdAnnouncement() error {
frame, err := c.commandGenerator.CmdAnnouncement()
if err != nil {
return fmt.Errorf("could not get command annoucement: %w", err)
}
c.logger.Trace("writing command announcement")
if err := c.writeFrame(frame); err != nil {
return err
}
if err := c.nextState(); err != nil {
return err
}
return nil
}
func (c *clientSimulator) writeCommandAnnouncement() error {
frame, err := c.commandGenerator.CmdAnnouncement()
if err != nil {
return fmt.Errorf("could not get command announcement: %w", err)
}
c.logger.Trace("writing command announcement")
if err := c.writeFrame(frame); err != nil {
return err
}
if err := c.nextState(); err != nil {
return err
}
return nil
}
func (c *clientSimulator) writeCommand() error {
frame, err := c.commandGenerator.CmdData()
if err != nil {
return fmt.Errorf("could not get command data: %w", err)
}
c.logger.Trace("writing command data")
if err := c.writeFrame(frame); err != nil {
return err
}
if err := c.nextState(); err != nil {
return err
}
return nil
}
func (c *clientSimulator) handleResponseAnnounce() error {
c.logger.Trace("waiting for response announce")
select {
case frame := <-c.framesIn:
if frame == nil {
return nil
}
if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil {
return fmt.Errorf("response announce handling failed: %w", err)
}
if err := c.nextState(); err != nil {
return err
}
}
return nil
}
func (c *clientSimulator) handleResponseData() error {
c.logger.Trace("waiting for response data")
select {
case frame := <-c.framesIn:
if frame == nil {
return nil
}
if err := c.commandGenerator.HandleResponse(frame); err != nil {
return fmt.Errorf("response handler failed: %w", err)
}
if err := c.nextState(); err != nil {
return err
}
}
return nil
}
func (c *clientSimulator) Run(ctx context.Context) error {
c.protocolState = cmdAnnounce
errors := make(chan error)
go func() {
err := c.readFrames()
errors <- err
}()
go func() {
err := c.commandGenerator.GenerateCommands(ctx)
errors <- err
}()
for {
select {
case <-ctx.Done():
return nil
case err := <-errors:
if err != nil {
return fmt.Errorf("error from handler loop: %w", err)
}
return nil
default:
if err := c.handleProtocolState(); err != nil {
return err
}
}
}
}
func (c *clientSimulator) handleResponse(response *protocol.Response, respBytes []byte) error {
switch response.Announce.Code {
case messages.RespHealth:
data := messages.HealthResponse{}
func (c *clientSimulator) handleProtocolState() error {
c.logger.Tracef("handling protocol state %s", c.protocolState)
if err := msgpack.Unmarshal(respBytes, &data); err != nil {
return fmt.Errorf("could not unmarshal health data: %w", err)
switch c.protocolState {
case cmdAnnounce:
if err := c.writeCmdAnnouncement(); err != nil {
return err
}
c.logger.WithField(
"announce",
response.Announce,
).WithField(
"data",
&data,
).Infof("received response")
case messages.RespFetchCRL:
data := messages.FetchCRLResponse{}
if err := msgpack.Unmarshal(respBytes, &data); err != nil {
return fmt.Errorf("could not unmarshal fetch CRL data: %w", err)
case cmdData:
if err := c.writeCommand(); err != nil {
return err
}
c.logger.WithField(
"announce",
response.Announce,
).WithField(
"data",
&data,
).Infof("received response")
case messages.RespError:
data := messages.ErrorResponse{}
if err := msgpack.Unmarshal(respBytes, &data); err != nil {
return fmt.Errorf("could not unmarshal error data: %w", err)
case respAnnounce:
if err := c.handleResponseAnnounce(); err != nil {
return err
}
case respData:
if err := c.handleResponseData(); err != nil {
return err
}
c.logger.WithField(
"announce",
response.Announce,
).WithField(
"data",
&data,
).Infof("received response")
default:
if err := msgpack.Unmarshal(respBytes, &response.Response); err != nil {
return fmt.Errorf("could not unmarshal response: %w", err)
}
c.logger.WithField("response", response).Infof("received response")
return fmt.Errorf("unknown protocol state %s", c.protocolState)
}
return nil
}
func writeCommandAnnouncement(command *protocol.Command) error {
cmdAnnounceBytes, err := msgpack.Marshal(&command.Announce)
func (c *clientSimulator) nextState() error {
next, ok := validTransitions[c.protocolState]
if !ok {
return fmt.Errorf("illegal protocol state %s", c.protocolState)
}
c.protocolState = next
return nil
}
func main() {
logger := logrus.New()
logger.SetOutput(os.Stderr)
logger.SetLevel(logrus.TraceLevel)
messages.RegisterGeneratedResolver()
sim := &clientSimulator{
commandGenerator: &TestCommandGenerator{
logger: logger,
commands: make(chan *protocol.Command, 0),
},
logger: logger,
framesIn: make(chan []byte, 0),
}
err := sim.Run(context.Background())
if err != nil {
return fmt.Errorf("could not marshal command announcement bytes: %w", err)
logger.WithError(err).Error("simulator returned an error")
}
if _, err = os.Stdout.Write(cobs.Encode(cmdAnnounceBytes, cobsConfig)); err != nil {
return fmt.Errorf("command announcement write failed: %w", err)
}
return nil
}
func writeCommand(command *protocol.Command) error {
cmdBytes, err := msgpack.Marshal(&command.Command)
if err != nil {
return fmt.Errorf("could not marshal command bytes: %w", err)
}
if _, err = os.Stdout.Write(cobs.Encode(cmdBytes, cobsConfig)); err != nil {
return fmt.Errorf("command write failed: %w", err)
}
return nil
}
func (c *clientSimulator) Run() error {
ctx, cancel := context.WithCancel(context.Background())
c.commands = make(chan *protocol.Command)
c.responses = make(chan [][]byte)
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
if err := c.handleInput(ctx); err != nil {
c.logger.WithError(err).Error("input handling failed")
}
cancel()
wg.Done()
}()
go func() {
if err := c.handleCommands(ctx); err != nil {
c.logger.WithError(err).Error("command handling failed")
}
cancel()
wg.Done()
}()
var result error
if err := c.writeTestCommands(ctx); err != nil {
c.logger.WithError(err).Error("test commands failed")
}
cancel()
wg.Wait()
return result
}

View file

@ -18,6 +18,7 @@ limitations under the License.
package main
import (
"context"
"flag"
"fmt"
"os"
@ -101,18 +102,18 @@ func main() {
logger.WithError(err).Fatal("could not setup protocol handler")
}
serialHandler, err := seriallink.New(caConfig.GetSerial(), proto)
serialHandler, err := seriallink.New(caConfig.GetSerial(), logger, proto)
if err != nil {
logger.WithError(err).Fatal("could not setup serial link handler")
}
defer func() { _ = serialHandler.Close() }()
if err = serialHandler.Run(); err != nil {
logger.Info("setup complete, starting signer operation")
if err = serialHandler.Run(context.Background()); err != nil {
logger.WithError(err).Fatal("error in serial handler")
}
logger.Info("setup complete, starting signer operation")
}
func configureRepositories(

View file

@ -162,8 +162,9 @@ func TestPrivateKeyInfo_UnmarshalYAML(t *testing.T) {
algorithm: "RSA"
rsa-bits: 2048`,
expected: &config.PrivateKeyInfo{
Algorithm: x509.RSA,
RSABits: 2048,
Algorithm: x509.RSA,
RSABits: 2048,
CRLSignatureAlgorithm: x509.SHA256WithRSA,
},
},
{
@ -172,8 +173,9 @@ rsa-bits: 2048`,
algorithm: "EC"
ecc-curve: "P-224"`,
expected: &config.PrivateKeyInfo{
Algorithm: x509.ECDSA,
EccCurve: elliptic.P224(),
Algorithm: x509.ECDSA,
EccCurve: elliptic.P224(),
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
},
},
{
@ -182,8 +184,9 @@ ecc-curve: "P-224"`,
algorithm: "EC"
ecc-curve: "P-256"`,
expected: &config.PrivateKeyInfo{
Algorithm: x509.ECDSA,
EccCurve: elliptic.P256(),
Algorithm: x509.ECDSA,
EccCurve: elliptic.P256(),
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
},
},
{
@ -192,8 +195,9 @@ ecc-curve: "P-256"`,
algorithm: "EC"
ecc-curve: "P-384"`,
expected: &config.PrivateKeyInfo{
Algorithm: x509.ECDSA,
EccCurve: elliptic.P384(),
Algorithm: x509.ECDSA,
EccCurve: elliptic.P384(),
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
},
},
{
@ -202,8 +206,9 @@ ecc-curve: "P-384"`,
algorithm: "EC"
ecc-curve: "P-521"`,
expected: &config.PrivateKeyInfo{
Algorithm: x509.ECDSA,
EccCurve: elliptic.P521(),
Algorithm: x509.ECDSA,
EccCurve: elliptic.P521(),
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
},
},
{
@ -308,8 +313,9 @@ storage: root
`,
expected: config.CaCertificateEntry{
KeyInfo: &config.PrivateKeyInfo{
Algorithm: x509.ECDSA,
EccCurve: elliptic.P521(),
Algorithm: x509.ECDSA,
EccCurve: elliptic.P521(),
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
},
CommonName: "My Little Test Root CA",
Storage: "root",
@ -329,8 +335,9 @@ ext-key-usages:
`,
expected: config.CaCertificateEntry{
KeyInfo: &config.PrivateKeyInfo{
Algorithm: x509.ECDSA,
EccCurve: elliptic.P256(),
Algorithm: x509.ECDSA,
EccCurve: elliptic.P256(),
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
},
Parent: "root",
CommonName: "My Little Test Sub CA",
@ -357,8 +364,9 @@ ext-key-usages:
`,
expected: config.CaCertificateEntry{
KeyInfo: &config.PrivateKeyInfo{
Algorithm: x509.ECDSA,
EccCurve: elliptic.P256(),
Algorithm: x509.ECDSA,
EccCurve: elliptic.P256(),
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
},
CommonName: "My Little Test Sub CA",
Storage: "default",

View file

@ -21,6 +21,7 @@ package protocol
import (
"errors"
"fmt"
"sync"
"github.com/shamaton/msgpackgen/msgpack"
"github.com/sirupsen/logrus"
@ -31,82 +32,6 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages"
)
// Handler is responsible for parsing incoming frames and calling commands
type Handler interface {
HandleCommandAnnounce([]byte) (*messages.CommandAnnounce, error)
HandleCommand(*messages.CommandAnnounce, []byte) ([]byte, []byte, error)
}
type MsgPackHandler struct {
logger *logrus.Logger
healthHandler *health.Handler
fetchCRLHandler *revoking.FetchCRLHandler
}
func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) (*messages.CommandAnnounce, error) {
var ann messages.CommandAnnounce
if err := msgpack.Unmarshal(frame, &ann); err != nil {
return nil, fmt.Errorf("could not unmarshal command announcement: %w", err)
}
m.logger.Infof("received command announcement %+v", ann)
return &ann, nil
}
func (m *MsgPackHandler) HandleCommand(announce *messages.CommandAnnounce, frame []byte) ([]byte, []byte, error) {
var (
response *Response
clientError, err error
)
switch announce.Code {
case messages.CmdHealth:
// health has no payload, ignore the frame
response, err = m.handleCommand(&Command{Announce: announce, Command: nil})
if err != nil {
m.logger.WithError(err).Error("health handling failed")
clientError = errors.New("could not handle request")
}
case messages.CmdFetchCRL:
var command messages.FetchCRLCommand
err = msgpack.Unmarshal(frame, &command)
if err != nil {
m.logger.WithError(err).Error("unmarshal failed")
clientError = errors.New("could not unmarshal fetch crl command")
break
}
response, err = m.handleCommand(&Command{Announce: announce, Command: command})
if err != nil {
m.logger.WithError(err).Error("fetch CRL handling failed")
clientError = errors.New("could not handle request")
}
}
if clientError != nil {
response = buildErrorResponse(clientError.Error())
}
announceData, err := msgpack.Marshal(response.Announce)
if err != nil {
return nil, nil, fmt.Errorf("could not marshal response announcement: %w", err)
}
responseData, err := msgpack.Marshal(response.Response)
if err != nil {
return nil, nil, fmt.Errorf("could not marshal response: %w", err)
}
return announceData, responseData, nil
}
type Command struct {
Announce *messages.CommandAnnounce
Command interface{}
@ -121,23 +46,146 @@ func (r *Response) String() string {
return fmt.Sprintf("Response[Code=%s] created=%s data=%s", r.Announce.Code, r.Announce.Created, r.Response)
}
func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) {
// Handler is responsible for parsing incoming frames and calling commands
type Handler interface {
HandleCommandAnnounce([]byte) error
HandleCommand([]byte) error
ResponseAnnounce() ([]byte, error)
ResponseData() ([]byte, error)
}
type MsgPackHandler struct {
logger *logrus.Logger
healthHandler *health.Handler
fetchCRLHandler *revoking.FetchCRLHandler
currentCommand *Command
currentResponse *Response
lock sync.Mutex
}
func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) error {
m.lock.Lock()
defer m.lock.Unlock()
var ann messages.CommandAnnounce
if err := msgpack.Unmarshal(frame, &ann); err != nil {
return fmt.Errorf("could not unmarshal command announcement: %w", err)
}
m.logger.WithField("announcement", &ann).Info("received command announcement")
m.currentCommand = &Command{Announce: &ann}
return nil
}
func (m *MsgPackHandler) HandleCommand(frame []byte) error {
m.lock.Lock()
defer m.lock.Unlock()
var clientError error
switch m.currentCommand.Announce.Code {
case messages.CmdHealth:
// health has no payload, ignore the frame
response, err := m.handleCommand()
if err != nil {
m.logger.WithError(err).Error("health handling failed")
clientError = errors.New("could not handle request")
break
}
m.currentResponse = response
case messages.CmdFetchCRL:
var command messages.FetchCRLCommand
if err := msgpack.Unmarshal(frame, &command); err != nil {
m.logger.WithError(err).Error("unmarshal failed")
clientError = errors.New("could not unmarshal fetch crl command")
break
}
m.currentCommand.Command = command
response, err := m.handleCommand()
if err != nil {
m.logger.WithError(err).Error("fetch CRL handling failed")
clientError = errors.New("could not handle request")
break
}
m.currentResponse = response
}
if clientError != nil {
m.currentResponse = buildErrorResponse(clientError.Error())
}
m.logger.WithField(
"command",
m.currentCommand,
).WithField(
"response",
m.currentResponse,
).Info("handled command")
m.currentCommand = nil
return nil
}
func (m *MsgPackHandler) ResponseAnnounce() ([]byte, error) {
m.lock.Lock()
defer m.lock.Unlock()
announceData, err := msgpack.Marshal(m.currentResponse.Announce)
if err != nil {
return nil, fmt.Errorf("could not marshal response announcement: %w", err)
}
m.logger.WithField("announcement", &m.currentResponse.Announce).Info("write response announcement")
return announceData, nil
}
func (m *MsgPackHandler) ResponseData() ([]byte, error) {
m.lock.Lock()
defer m.lock.Unlock()
responseData, err := msgpack.Marshal(m.currentResponse.Response)
if err != nil {
return nil, fmt.Errorf("could not marshal response: %w", err)
}
m.logger.WithField("response", &m.currentResponse.Response).Info("write response")
return responseData, nil
}
func (m *MsgPackHandler) handleCommand() (*Response, error) {
var (
err error
responseData interface{}
responseCode messages.ResponseCode
)
switch command.Announce.Code {
switch m.currentCommand.Announce.Code {
case messages.CmdHealth:
var res *health.Result
res, err = m.healthHandler.CheckHealth()
if err != nil {
break
return nil, err
}
response := messages.HealthResponse{
response := &messages.HealthResponse{
Version: res.Version,
Healthy: res.Healthy,
}
@ -154,24 +202,24 @@ func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) {
case messages.CmdFetchCRL:
var res *revoking.Result
fetchCRLPayload, ok := command.Command.(messages.FetchCRLCommand)
fetchCRLPayload, ok := m.currentCommand.Command.(messages.FetchCRLCommand)
if !ok {
return nil, fmt.Errorf("could not use payload as FetchCRLPayload")
}
res, err = m.fetchCRLHandler.FetchCRL(fetchCRLPayload.IssuerID)
if err != nil {
break
return nil, err
}
response := messages.FetchCRLResponse{
response := &messages.FetchCRLResponse{
IsDelta: false,
CRLData: res.Data,
}
responseCode, responseData = messages.RespFetchCRL, response
default:
return nil, fmt.Errorf("unhandled command %v", command)
return nil, fmt.Errorf("unhandled command %s", m.currentCommand.Announce)
}
if err != nil {
@ -187,7 +235,7 @@ func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) {
func buildErrorResponse(errMsg string) *Response {
return &Response{
Announce: messages.BuildResponseAnnounce(messages.RespError),
Response: messages.ErrorResponse{Message: errMsg},
Response: &messages.ErrorResponse{Message: errMsg},
}
}

View file

@ -20,10 +20,14 @@ package seriallink
import (
"bytes"
"context"
"fmt"
"io"
"sync"
"time"
"github.com/justincpresley/go-cobs"
"github.com/sirupsen/logrus"
"github.com/tarm/serial"
"git.cacert.org/cacert-gosigner/pkg/config"
@ -33,16 +37,42 @@ import (
type protocolState int8
const (
stAnnounce protocolState = iota
stCommand
cmdAnnounce protocolState = iota
cmdData
respAnnounce
respData
)
var validTransitions = map[protocolState]protocolState{
cmdAnnounce: cmdData,
cmdData: respAnnounce,
respAnnounce: respData,
respData: cmdAnnounce,
}
var protocolStateNames = map[protocolState]string{
cmdAnnounce: "CMD ANNOUNCE",
cmdData: "CMD DATA",
respAnnounce: "RESP ANNOUNCE",
respData: "RESP DATA",
}
func (p protocolState) String() string {
if name, ok := protocolStateNames[p]; ok {
return name
}
return fmt.Sprintf("unknown %d", p)
}
type Handler struct {
protocolHandler protocol.Handler
protocolState protocolState
currentCommand *protocol.Command
config *serial.Config
port *serial.Port
logger *logrus.Logger
lock sync.Mutex
framesIn chan []byte
}
func (h *Handler) setupConnection() error {
@ -65,78 +95,191 @@ func (h *Handler) Close() error {
return nil
}
const cobsDelimiter = 0x00
var cobsConfig = cobs.Config{SpecialByte: 0x00, Delimiter: true, EndingSave: true}
var cobsConfig = cobs.Config{SpecialByte: cobsDelimiter, Delimiter: true, EndingSave: true}
func (h *Handler) Run(ctx context.Context) error {
h.protocolState = cmdAnnounce
errors := make(chan error)
func (h *Handler) Run() error {
go func() {
err := h.readFrames()
errors <- err
}()
for {
select {
case <-ctx.Done():
return nil
case err := <-errors:
if err != nil {
return fmt.Errorf("error from handler loop: %w", err)
}
return nil
default:
if err := h.handleProtocolState(); err != nil {
return err
}
}
}
}
func (h *Handler) readFrames() error {
const (
bufferSize = 1024 * 1024
readInterval = 50 * time.Millisecond
)
errors := make(chan error)
var frame []byte
buffer := &bytes.Buffer{}
delimiter := []byte{cobsConfig.SpecialByte}
h.protocolState = stAnnounce
for {
readBytes, err := h.readFromPort()
if err != nil {
close(h.framesIn)
go func() {
buf := make([]byte, bufferSize)
return err
}
for {
count, err := h.port.Read(buf)
if err != nil {
errors <- err
if len(readBytes) == 0 {
time.Sleep(readInterval)
return
}
continue
}
if count == 0 {
time.Sleep(readInterval)
h.logger.Tracef("read %d bytes", len(readBytes))
buffer.Write(readBytes)
h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
rest := buffer.Bytes()
if !bytes.Contains(rest, delimiter) {
continue
}
for bytes.Contains(rest, delimiter) {
parts := bytes.SplitAfterN(rest, delimiter, 2)
frame, rest = parts[0], parts[1]
h.logger.Tracef("frame of length %d", len(frame))
if len(frame) == 0 {
continue
}
frames := bytes.SplitAfter(buf[:count], []byte{cobsDelimiter})
if err := cobs.Verify(frame, cobsConfig); err != nil {
close(h.framesIn)
if err := h.handleFrames(frames); err != nil {
errors <- err
return
return fmt.Errorf("could not verify COBS frame: %w", err)
}
}
}()
err := <-errors
if err != nil {
return fmt.Errorf("error from handler loop: %w", err)
decoded := cobs.Decode(frame, cobsConfig)
h.logger.Tracef("frame decoded to length %d", len(decoded))
h.framesIn <- decoded
}
buffer.Truncate(0)
buffer.Write(rest)
h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (h *Handler) writeFrame(frame []byte) error {
encoded := cobs.Encode(frame, cobsConfig)
return h.writeToPort(encoded)
}
func (h *Handler) nextState() error {
next, ok := validTransitions[h.protocolState]
if !ok {
return fmt.Errorf("illegal protocol state %s", h.protocolState)
}
h.protocolState = next
return nil
}
func (h *Handler) handleProtocolState() error {
h.logger.Tracef("handling protocol state %s", h.protocolState)
switch h.protocolState {
case cmdAnnounce:
if err := h.handleCmdAnnounce(); err != nil {
return err
}
case cmdData:
if err := h.handleCmdData(); err != nil {
return err
}
case respAnnounce:
if err := h.handleRespAnnounce(); err != nil {
return err
}
case respData:
if err := h.handleRespData(); err != nil {
return err
}
default:
return fmt.Errorf("unknown protocol state %s", h.protocolState)
}
return nil
}
func (h *Handler) handleFrames(frames [][]byte) error {
for _, frame := range frames {
if len(frame) == 0 {
func (h *Handler) writeToPort(data []byte) error {
h.lock.Lock()
defer h.lock.Unlock()
reader := bytes.NewReader(data)
n, err := io.Copy(h.port, reader)
if err != nil {
return fmt.Errorf("could not write data: %w", err)
}
h.logger.Tracef("wrote %d bytes", n)
if err := h.port.Flush(); err != nil {
return fmt.Errorf("could not flush data: %w", err)
}
return nil
}
func (h *Handler) readFromPort() ([]byte, error) {
const bufferSize = 1024
buf := make([]byte, bufferSize)
count, err := h.port.Read(buf)
if err != nil {
return nil, fmt.Errorf("could not read from serial port: %w", err)
}
return buf[:count], nil
}
func (h *Handler) handleCmdAnnounce() error {
h.logger.Trace("waiting for command announce")
select {
case frame := <-h.framesIn:
if frame == nil {
return nil
}
if err := cobs.Verify(frame, cobsConfig); err != nil {
return fmt.Errorf("could not verify COBS frame: %w", err)
}
// perform COBS decoding
decoded := cobs.Decode(frame, cobsConfig)
if h.protocolState == stAnnounce {
if err := h.handleCommandAnnounce(decoded); err != nil {
return err
}
}
if h.protocolState == stCommand {
if err := h.handleCommandData(decoded); err != nil {
return err
}
if err := h.protocolHandler.HandleCommandAnnounce(frame); err != nil {
return fmt.Errorf("command announce handling failed: %w", err)
}
if err := h.nextState(); err != nil {
@ -147,65 +290,72 @@ func (h *Handler) handleFrames(frames [][]byte) error {
return nil
}
func (h *Handler) handleCommandData(decoded []byte) error {
respAnn, msg, err := h.protocolHandler.HandleCommand(h.currentCommand.Announce, decoded)
if err != nil {
return fmt.Errorf("command handler for %s failed: %w", h.currentCommand.Announce.Code, err)
func (h *Handler) handleCmdData() error {
h.logger.Trace("waiting for command data")
select {
case frame := <-h.framesIn:
if frame == nil {
return nil
}
if err := h.protocolHandler.HandleCommand(frame); err != nil {
return fmt.Errorf("command handler failed: %w", err)
}
if err := h.nextState(); err != nil {
return err
}
}
if err := h.writeResponse(respAnn, msg, cobsConfig); err != nil {
return nil
}
func (h *Handler) handleRespAnnounce() error {
frame, err := h.protocolHandler.ResponseAnnounce()
if err != nil {
return fmt.Errorf("could not get response announcement: %w", err)
}
h.logger.Trace("writing response announce")
if err := h.writeFrame(frame); err != nil {
return err
}
if err := h.nextState(); err != nil {
return err
}
return nil
}
func (h *Handler) handleCommandAnnounce(decoded []byte) error {
announce, err := h.protocolHandler.HandleCommandAnnounce(decoded)
func (h *Handler) handleRespData() error {
frame, err := h.protocolHandler.ResponseData()
if err != nil {
return fmt.Errorf("command announce handling failed: %w", err)
return fmt.Errorf("could not get response data: %w", err)
}
h.currentCommand = &protocol.Command{Announce: announce}
h.logger.Trace("writing response data")
return nil
}
func (h *Handler) writeResponse(ann, msg []byte, cobsConfig cobs.Config) error {
encoded := cobs.Encode(ann, cobsConfig)
if _, err := h.port.Write(encoded); err != nil {
return fmt.Errorf("could not write response announcement: %w", err)
if err := h.writeFrame(frame); err != nil {
return err
}
encoded = cobs.Encode(msg, cobsConfig)
if _, err := h.port.Write(encoded); err != nil {
return fmt.Errorf("could not write response: %w", err)
if err := h.nextState(); err != nil {
return err
}
return nil
}
func (h *Handler) nextState() error {
var next protocolState
switch h.protocolState {
case stAnnounce:
next = stCommand
case stCommand:
next = stAnnounce
default:
return fmt.Errorf("illegal protocol state %d", int(h.protocolState))
func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Handler) (*Handler, error) {
h := &Handler{
protocolHandler: protocolHandler,
logger: logger,
framesIn: make(chan []byte, 0),
}
h.protocolState = next
return nil
}
func New(cfg *config.Serial, protocolHandler protocol.Handler) (*Handler, error) {
h := &Handler{protocolHandler: protocolHandler}
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}
err := h.setupConnection()