diff --git a/cmd/clientsim/main.go b/cmd/clientsim/main.go index 5a01f13..8027487 100644 --- a/cmd/clientsim/main.go +++ b/cmd/clientsim/main.go @@ -23,6 +23,7 @@ import ( "bytes" "context" "fmt" + "io" "os" "sync" "time" @@ -36,39 +37,135 @@ import ( "git.cacert.org/cacert-gosigner/pkg/messages" ) -const cobsDelimiter = 0x00 +var cobsConfig = cobs.Config{SpecialByte: 0x00, Delimiter: true, EndingSave: true} -var cobsConfig = cobs.Config{SpecialByte: cobsDelimiter, Delimiter: true, EndingSave: true} +type protocolState int8 -func main() { - logger := logrus.New() - logger.SetOutput(os.Stderr) - logger.SetLevel(logrus.InfoLevel) +const ( + cmdAnnounce protocolState = iota + cmdData + respAnnounce + respData +) - sim := &clientSimulator{ - logger: logger, +var validTransitions = map[protocolState]protocolState{ + cmdAnnounce: cmdData, + cmdData: respAnnounce, + respAnnounce: respData, + respData: cmdAnnounce, +} + +var protocolStateNames = map[protocolState]string{ + cmdAnnounce: "CMD ANNOUNCE", + cmdData: "CMD DATA", + respAnnounce: "RESP ANNOUNCE", + respData: "RESP DATA", +} + +func (p protocolState) String() string { + if name, ok := protocolStateNames[p]; ok { + return name } - err := sim.Run() + return fmt.Sprintf("unknown %d", p) +} + +type TestCommandGenerator struct { + logger *logrus.Logger + currentCommand *protocol.Command + currentResponse *protocol.Response + commands chan *protocol.Command + lock sync.Mutex +} + +func (g *TestCommandGenerator) CmdAnnouncement() ([]byte, error) { + g.lock.Lock() + defer g.lock.Unlock() + + select { + case g.currentCommand = <-g.commands: + announceData, err := msgpack.Marshal(g.currentCommand.Announce) + if err != nil { + return nil, fmt.Errorf("could not marshal command annoucement: %w", err) + } + + g.logger.WithField("announcement", &g.currentCommand.Announce).Info("write command announcement") + + return announceData, nil + } +} + +func (g *TestCommandGenerator) CmdData() ([]byte, error) { + g.lock.Lock() + defer g.lock.Unlock() + + cmdData, err := msgpack.Marshal(g.currentCommand.Command) if err != nil { - logger.WithError(err).Error("simulator returned an error") + return nil, fmt.Errorf("could not marshal command data: %w", err) } + + g.logger.WithField("command", &g.currentCommand.Command).Info("write command data") + + return cmdData, nil } -type clientSimulator struct { - logger *logrus.Logger - commands chan *protocol.Command - responses chan [][]byte +func (g *TestCommandGenerator) HandleResponseAnnounce(frame []byte) error { + g.lock.Lock() + defer g.lock.Unlock() + + var ann messages.ResponseAnnounce + + if err := msgpack.Unmarshal(frame, &ann); err != nil { + return fmt.Errorf("could not unmarshal response announcement") + } + + g.logger.WithField("announcement", &ann).Info("received response announcement") + + g.currentResponse = &protocol.Response{Announce: &ann} + + return nil } -func (c *clientSimulator) writeTestCommands(ctx context.Context) error { - messages.RegisterGeneratedResolver() +func (g *TestCommandGenerator) HandleResponse(frame []byte) error { + g.lock.Lock() + defer g.lock.Unlock() + + switch g.currentResponse.Announce.Code { + case messages.RespHealth: + var response messages.HealthResponse + + if err := msgpack.Unmarshal(frame, &response); err != nil { + return fmt.Errorf("unmarshal failed: %w", err) + } + + g.currentResponse.Response = response + case messages.RespFetchCRL: + var response messages.FetchCRLResponse + + if err := msgpack.Unmarshal(frame, &response); err != nil { + return fmt.Errorf("unmarshal failed: %w", err) + } + } + g.logger.WithField( + "command", + g.currentCommand, + ).WithField( + "response", + g.currentResponse, + ).Info("handled health response") + + return nil +} + +func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error { const healthInterval = 10 * time.Second + g.logger.Info("start generating commands") + time.Sleep(healthInterval) - c.commands <- &protocol.Command{ + g.commands <- &protocol.Command{ Announce: messages.BuildCommandAnnounce(messages.CmdFetchCRL), Command: &messages.FetchCRLCommand{IssuerID: "sub-ecc_person_2022"}, } @@ -80,11 +177,11 @@ func (c *clientSimulator) writeTestCommands(ctx context.Context) error { case <-ctx.Done(): _ = timer.Stop() - c.logger.Info("stopping health check loop") + g.logger.Info("stopping health check loop") return nil case <-timer.C: - c.commands <- &protocol.Command{ + g.commands <- &protocol.Command{ Announce: messages.BuildCommandAnnounce(messages.CmdHealth), Command: &messages.HealthCommand{}, } @@ -94,218 +191,292 @@ func (c *clientSimulator) writeTestCommands(ctx context.Context) error { } } -func (c *clientSimulator) handleInput(ctx context.Context) error { - const ( - bufferSize = 1024 * 1024 - readInterval = 50 * time.Millisecond - ) +type clientSimulator struct { + protocolState protocolState + logger *logrus.Logger + lock sync.Mutex + framesIn chan []byte + commandGenerator *TestCommandGenerator +} - buf := make([]byte, bufferSize) +func (c *clientSimulator) readFrames() error { + const readInterval = 50 * time.Millisecond - type protocolState int8 + var frame []byte + buffer := &bytes.Buffer{} + delimiter := []byte{cobsConfig.SpecialByte} - const ( - stAnn protocolState = iota - stResp - ) + for { + readBytes, err := c.readFromStdin() + if err != nil { + c.logger.WithError(err).Error("stdin read error") - state := stAnn + close(c.framesIn) - var announce []byte + return err + } -reading: - for { - select { - case <-ctx.Done(): - return nil - default: - count, err := os.Stdin.Read(buf) - if err != nil { - return fmt.Errorf("reading input failed: %w", err) - } + if len(readBytes) == 0 { + time.Sleep(readInterval) - if count == 0 { - time.Sleep(readInterval) + continue + } - continue - } + c.logger.Tracef("read %d bytes", len(readBytes)) + + buffer.Write(readBytes) - data := buf[:count] + c.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) - for _, frame := range bytes.SplitAfter(data, []byte{cobsConfig.SpecialByte}) { - if len(frame) == 0 { - continue reading - } + rest := buffer.Bytes() - err = cobs.Verify(frame, cobsConfig) - if err != nil { - return fmt.Errorf("frame verification failed: %w", err) - } + if !bytes.Contains(rest, delimiter) { + continue + } - if state == stAnn { - announce = cobs.Decode(frame, cobsConfig) + for bytes.Contains(rest, delimiter) { + parts := bytes.SplitAfterN(rest, delimiter, 2) + frame, rest = parts[0], parts[1] - state = stResp - } else { - c.responses <- [][]byte{announce, cobs.Decode(frame, cobsConfig)} + c.logger.Tracef("frame of length %d", len(frame)) - state = stAnn - } + if len(frame) == 0 { + continue } - } - } -} -func (c *clientSimulator) handleCommands(ctx context.Context) error { - for { - select { - case command := <-c.commands: - if err := writeCommandAnnouncement(command); err != nil { - return err + err = cobs.Verify(frame, cobsConfig) + if err != nil { + return fmt.Errorf("frame verification failed: %w", err) } - if err := writeCommand(command); err != nil { - return err - } + decoded := cobs.Decode(frame, cobsConfig) - respData := <-c.responses + c.logger.Tracef("frame decoded to length %d", len(decoded)) - c.logger.WithField("respdata", respData).Trace("read response data") + c.framesIn <- decoded + } - response := &protocol.Response{} + buffer.Truncate(0) + buffer.Write(rest) - if err := msgpack.Unmarshal(respData[0], &response.Announce); err != nil { - return fmt.Errorf("could not unmarshal response announcement: %w", err) - } + c.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) + } +} - if err := c.handleResponse(response, respData[1]); err != nil { - return err - } +func (c *clientSimulator) writeFrame(frame []byte) error { + encoded := cobs.Encode(frame, cobsConfig) - case <-ctx.Done(): - return nil - } + c.lock.Lock() + defer c.lock.Unlock() + + if _, err := io.Copy(os.Stdout, bytes.NewBuffer(encoded)); err != nil { + return fmt.Errorf("could not write data: %w", err) } + + return nil } -func (c *clientSimulator) handleResponse(response *protocol.Response, respBytes []byte) error { - switch response.Announce.Code { - case messages.RespHealth: - data := messages.HealthResponse{} +func (c *clientSimulator) readFromStdin() ([]byte, error) { + const bufferSize = 1024 - if err := msgpack.Unmarshal(respBytes, &data); err != nil { - return fmt.Errorf("could not unmarshal health data: %w", err) - } + buf := make([]byte, bufferSize) - c.logger.WithField( - "announce", - response.Announce, - ).WithField( - "data", - &data, - ).Infof("received response") - case messages.RespFetchCRL: - data := messages.FetchCRLResponse{} + count, err := os.Stdin.Read(buf) + if err != nil { + return nil, fmt.Errorf("reading input failed: %w", err) + } - if err := msgpack.Unmarshal(respBytes, &data); err != nil { - return fmt.Errorf("could not unmarshal fetch CRL data: %w", err) - } + return buf[:count], nil +} - c.logger.WithField( - "announce", - response.Announce, - ).WithField( - "data", - &data, - ).Infof("received response") - case messages.RespError: - data := messages.ErrorResponse{} - - if err := msgpack.Unmarshal(respBytes, &data); err != nil { - return fmt.Errorf("could not unmarshal error data: %w", err) - } +func (c *clientSimulator) writeCmdAnnouncement() error { + frame, err := c.commandGenerator.CmdAnnouncement() + if err != nil { + return fmt.Errorf("could not get command annoucement: %w", err) + } - c.logger.WithField( - "announce", - response.Announce, - ).WithField( - "data", - &data, - ).Infof("received response") - default: - if err := msgpack.Unmarshal(respBytes, &response.Response); err != nil { - return fmt.Errorf("could not unmarshal response: %w", err) - } + c.logger.Trace("writing command announcement") + + if err := c.writeFrame(frame); err != nil { + return err + } - c.logger.WithField("response", response).Infof("received response") + if err := c.nextState(); err != nil { + return err } return nil } -func writeCommandAnnouncement(command *protocol.Command) error { - cmdAnnounceBytes, err := msgpack.Marshal(&command.Announce) +func (c *clientSimulator) writeCommandAnnouncement() error { + frame, err := c.commandGenerator.CmdAnnouncement() if err != nil { - return fmt.Errorf("could not marshal command announcement bytes: %w", err) + return fmt.Errorf("could not get command announcement: %w", err) } - if _, err = os.Stdout.Write(cobs.Encode(cmdAnnounceBytes, cobsConfig)); err != nil { - return fmt.Errorf("command announcement write failed: %w", err) + c.logger.Trace("writing command announcement") + + if err := c.writeFrame(frame); err != nil { + return err + } + + if err := c.nextState(); err != nil { + return err } return nil } -func writeCommand(command *protocol.Command) error { - cmdBytes, err := msgpack.Marshal(&command.Command) +func (c *clientSimulator) writeCommand() error { + frame, err := c.commandGenerator.CmdData() if err != nil { - return fmt.Errorf("could not marshal command bytes: %w", err) + return fmt.Errorf("could not get command data: %w", err) + } + + c.logger.Trace("writing command data") + + if err := c.writeFrame(frame); err != nil { + return err } - if _, err = os.Stdout.Write(cobs.Encode(cmdBytes, cobsConfig)); err != nil { - return fmt.Errorf("command write failed: %w", err) + if err := c.nextState(); err != nil { + return err } return nil } -func (c *clientSimulator) Run() error { - ctx, cancel := context.WithCancel(context.Background()) +func (c *clientSimulator) handleResponseAnnounce() error { + c.logger.Trace("waiting for response announce") - c.commands = make(chan *protocol.Command) - c.responses = make(chan [][]byte) + select { + case frame := <-c.framesIn: + if frame == nil { + return nil + } - wg := sync.WaitGroup{} - wg.Add(2) + if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil { + return fmt.Errorf("response announce handling failed: %w", err) + } - go func() { - if err := c.handleInput(ctx); err != nil { - c.logger.WithError(err).Error("input handling failed") + if err := c.nextState(); err != nil { + return err } + } + + return nil +} + +func (c *clientSimulator) handleResponseData() error { + c.logger.Trace("waiting for response data") + + select { + case frame := <-c.framesIn: + if frame == nil { + return nil + } + + if err := c.commandGenerator.HandleResponse(frame); err != nil { + return fmt.Errorf("response handler failed: %w", err) + } + + if err := c.nextState(); err != nil { + return err + } + } + + return nil +} + +func (c *clientSimulator) Run(ctx context.Context) error { + c.protocolState = cmdAnnounce + errors := make(chan error) - cancel() + go func() { + err := c.readFrames() - wg.Done() + errors <- err }() go func() { - if err := c.handleCommands(ctx); err != nil { - c.logger.WithError(err).Error("command handling failed") + err := c.commandGenerator.GenerateCommands(ctx) + + errors <- err + }() + + for { + select { + case <-ctx.Done(): + return nil + case err := <-errors: + if err != nil { + return fmt.Errorf("error from handler loop: %w", err) + } + return nil + default: + if err := c.handleProtocolState(); err != nil { + return err + } } + } +} - cancel() +func (c *clientSimulator) handleProtocolState() error { + c.logger.Tracef("handling protocol state %s", c.protocolState) - wg.Done() - }() + switch c.protocolState { + case cmdAnnounce: + if err := c.writeCmdAnnouncement(); err != nil { + return err + } + case cmdData: + if err := c.writeCommand(); err != nil { + return err + } + case respAnnounce: + if err := c.handleResponseAnnounce(); err != nil { + return err + } + case respData: + if err := c.handleResponseData(); err != nil { + return err + } + default: + return fmt.Errorf("unknown protocol state %s", c.protocolState) + } - var result error + return nil +} - if err := c.writeTestCommands(ctx); err != nil { - c.logger.WithError(err).Error("test commands failed") +func (c *clientSimulator) nextState() error { + next, ok := validTransitions[c.protocolState] + if !ok { + return fmt.Errorf("illegal protocol state %s", c.protocolState) } - cancel() - wg.Wait() + c.protocolState = next - return result + return nil +} + +func main() { + logger := logrus.New() + logger.SetOutput(os.Stderr) + logger.SetLevel(logrus.TraceLevel) + + messages.RegisterGeneratedResolver() + + sim := &clientSimulator{ + commandGenerator: &TestCommandGenerator{ + logger: logger, + commands: make(chan *protocol.Command, 0), + }, + logger: logger, + framesIn: make(chan []byte, 0), + } + + err := sim.Run(context.Background()) + if err != nil { + logger.WithError(err).Error("simulator returned an error") + } } diff --git a/cmd/signer/main.go b/cmd/signer/main.go index e0911b8..c0296d6 100644 --- a/cmd/signer/main.go +++ b/cmd/signer/main.go @@ -18,6 +18,7 @@ limitations under the License. package main import ( + "context" "flag" "fmt" "os" @@ -101,18 +102,18 @@ func main() { logger.WithError(err).Fatal("could not setup protocol handler") } - serialHandler, err := seriallink.New(caConfig.GetSerial(), proto) + serialHandler, err := seriallink.New(caConfig.GetSerial(), logger, proto) if err != nil { logger.WithError(err).Fatal("could not setup serial link handler") } defer func() { _ = serialHandler.Close() }() - if err = serialHandler.Run(); err != nil { + logger.Info("setup complete, starting signer operation") + + if err = serialHandler.Run(context.Background()); err != nil { logger.WithError(err).Fatal("error in serial handler") } - - logger.Info("setup complete, starting signer operation") } func configureRepositories( diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 4c0c798..cacd2c4 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -162,8 +162,9 @@ func TestPrivateKeyInfo_UnmarshalYAML(t *testing.T) { algorithm: "RSA" rsa-bits: 2048`, expected: &config.PrivateKeyInfo{ - Algorithm: x509.RSA, - RSABits: 2048, + Algorithm: x509.RSA, + RSABits: 2048, + CRLSignatureAlgorithm: x509.SHA256WithRSA, }, }, { @@ -172,8 +173,9 @@ rsa-bits: 2048`, algorithm: "EC" ecc-curve: "P-224"`, expected: &config.PrivateKeyInfo{ - Algorithm: x509.ECDSA, - EccCurve: elliptic.P224(), + Algorithm: x509.ECDSA, + EccCurve: elliptic.P224(), + CRLSignatureAlgorithm: x509.ECDSAWithSHA256, }, }, { @@ -182,8 +184,9 @@ ecc-curve: "P-224"`, algorithm: "EC" ecc-curve: "P-256"`, expected: &config.PrivateKeyInfo{ - Algorithm: x509.ECDSA, - EccCurve: elliptic.P256(), + Algorithm: x509.ECDSA, + EccCurve: elliptic.P256(), + CRLSignatureAlgorithm: x509.ECDSAWithSHA256, }, }, { @@ -192,8 +195,9 @@ ecc-curve: "P-256"`, algorithm: "EC" ecc-curve: "P-384"`, expected: &config.PrivateKeyInfo{ - Algorithm: x509.ECDSA, - EccCurve: elliptic.P384(), + Algorithm: x509.ECDSA, + EccCurve: elliptic.P384(), + CRLSignatureAlgorithm: x509.ECDSAWithSHA256, }, }, { @@ -202,8 +206,9 @@ ecc-curve: "P-384"`, algorithm: "EC" ecc-curve: "P-521"`, expected: &config.PrivateKeyInfo{ - Algorithm: x509.ECDSA, - EccCurve: elliptic.P521(), + Algorithm: x509.ECDSA, + EccCurve: elliptic.P521(), + CRLSignatureAlgorithm: x509.ECDSAWithSHA256, }, }, { @@ -308,8 +313,9 @@ storage: root `, expected: config.CaCertificateEntry{ KeyInfo: &config.PrivateKeyInfo{ - Algorithm: x509.ECDSA, - EccCurve: elliptic.P521(), + Algorithm: x509.ECDSA, + EccCurve: elliptic.P521(), + CRLSignatureAlgorithm: x509.ECDSAWithSHA256, }, CommonName: "My Little Test Root CA", Storage: "root", @@ -329,8 +335,9 @@ ext-key-usages: `, expected: config.CaCertificateEntry{ KeyInfo: &config.PrivateKeyInfo{ - Algorithm: x509.ECDSA, - EccCurve: elliptic.P256(), + Algorithm: x509.ECDSA, + EccCurve: elliptic.P256(), + CRLSignatureAlgorithm: x509.ECDSAWithSHA256, }, Parent: "root", CommonName: "My Little Test Sub CA", @@ -357,8 +364,9 @@ ext-key-usages: `, expected: config.CaCertificateEntry{ KeyInfo: &config.PrivateKeyInfo{ - Algorithm: x509.ECDSA, - EccCurve: elliptic.P256(), + Algorithm: x509.ECDSA, + EccCurve: elliptic.P256(), + CRLSignatureAlgorithm: x509.ECDSAWithSHA256, }, CommonName: "My Little Test Sub CA", Storage: "default", diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go index 0914405..d97ea21 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/protocol/protocol.go @@ -21,6 +21,7 @@ package protocol import ( "errors" "fmt" + "sync" "github.com/shamaton/msgpackgen/msgpack" "github.com/sirupsen/logrus" @@ -31,50 +32,77 @@ import ( "git.cacert.org/cacert-gosigner/pkg/messages" ) +type Command struct { + Announce *messages.CommandAnnounce + Command interface{} +} + +type Response struct { + Announce *messages.ResponseAnnounce + Response interface{} +} + +func (r *Response) String() string { + return fmt.Sprintf("Response[Code=%s] created=%s data=%s", r.Announce.Code, r.Announce.Created, r.Response) +} + // Handler is responsible for parsing incoming frames and calling commands type Handler interface { - HandleCommandAnnounce([]byte) (*messages.CommandAnnounce, error) - HandleCommand(*messages.CommandAnnounce, []byte) ([]byte, []byte, error) + HandleCommandAnnounce([]byte) error + HandleCommand([]byte) error + ResponseAnnounce() ([]byte, error) + ResponseData() ([]byte, error) } type MsgPackHandler struct { logger *logrus.Logger healthHandler *health.Handler fetchCRLHandler *revoking.FetchCRLHandler + currentCommand *Command + currentResponse *Response + lock sync.Mutex } -func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) (*messages.CommandAnnounce, error) { +func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) error { + m.lock.Lock() + defer m.lock.Unlock() + var ann messages.CommandAnnounce if err := msgpack.Unmarshal(frame, &ann); err != nil { - return nil, fmt.Errorf("could not unmarshal command announcement: %w", err) + return fmt.Errorf("could not unmarshal command announcement: %w", err) } - m.logger.Infof("received command announcement %+v", ann) + m.logger.WithField("announcement", &ann).Info("received command announcement") + + m.currentCommand = &Command{Announce: &ann} - return &ann, nil + return nil } -func (m *MsgPackHandler) HandleCommand(announce *messages.CommandAnnounce, frame []byte) ([]byte, []byte, error) { - var ( - response *Response - clientError, err error - ) +func (m *MsgPackHandler) HandleCommand(frame []byte) error { + m.lock.Lock() + defer m.lock.Unlock() + + var clientError error - switch announce.Code { + switch m.currentCommand.Announce.Code { case messages.CmdHealth: // health has no payload, ignore the frame - response, err = m.handleCommand(&Command{Announce: announce, Command: nil}) + response, err := m.handleCommand() if err != nil { m.logger.WithError(err).Error("health handling failed") clientError = errors.New("could not handle request") + + break } + + m.currentResponse = response case messages.CmdFetchCRL: var command messages.FetchCRLCommand - err = msgpack.Unmarshal(frame, &command) - if err != nil { + if err := msgpack.Unmarshal(frame, &command); err != nil { m.logger.WithError(err).Error("unmarshal failed") clientError = errors.New("could not unmarshal fetch crl command") @@ -82,62 +110,82 @@ func (m *MsgPackHandler) HandleCommand(announce *messages.CommandAnnounce, frame break } - response, err = m.handleCommand(&Command{Announce: announce, Command: command}) + m.currentCommand.Command = command + + response, err := m.handleCommand() if err != nil { m.logger.WithError(err).Error("fetch CRL handling failed") clientError = errors.New("could not handle request") + + break } + + m.currentResponse = response } if clientError != nil { - response = buildErrorResponse(clientError.Error()) + m.currentResponse = buildErrorResponse(clientError.Error()) } - announceData, err := msgpack.Marshal(response.Announce) - if err != nil { - return nil, nil, fmt.Errorf("could not marshal response announcement: %w", err) - } + m.logger.WithField( + "command", + m.currentCommand, + ).WithField( + "response", + m.currentResponse, + ).Info("handled command") + + m.currentCommand = nil + + return nil +} + +func (m *MsgPackHandler) ResponseAnnounce() ([]byte, error) { + m.lock.Lock() + defer m.lock.Unlock() - responseData, err := msgpack.Marshal(response.Response) + announceData, err := msgpack.Marshal(m.currentResponse.Announce) if err != nil { - return nil, nil, fmt.Errorf("could not marshal response: %w", err) + return nil, fmt.Errorf("could not marshal response announcement: %w", err) } - return announceData, responseData, nil -} + m.logger.WithField("announcement", &m.currentResponse.Announce).Info("write response announcement") -type Command struct { - Announce *messages.CommandAnnounce - Command interface{} + return announceData, nil } -type Response struct { - Announce *messages.ResponseAnnounce - Response interface{} -} +func (m *MsgPackHandler) ResponseData() ([]byte, error) { + m.lock.Lock() + defer m.lock.Unlock() -func (r *Response) String() string { - return fmt.Sprintf("Response[Code=%s] created=%s data=%s", r.Announce.Code, r.Announce.Created, r.Response) + responseData, err := msgpack.Marshal(m.currentResponse.Response) + if err != nil { + return nil, fmt.Errorf("could not marshal response: %w", err) + } + + m.logger.WithField("response", &m.currentResponse.Response).Info("write response") + + return responseData, nil } -func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) { +func (m *MsgPackHandler) handleCommand() (*Response, error) { var ( err error responseData interface{} responseCode messages.ResponseCode ) - switch command.Announce.Code { + switch m.currentCommand.Announce.Code { case messages.CmdHealth: var res *health.Result res, err = m.healthHandler.CheckHealth() if err != nil { - break + return nil, err } - response := messages.HealthResponse{ + response := &messages.HealthResponse{ Version: res.Version, Healthy: res.Healthy, } @@ -154,24 +202,24 @@ func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) { case messages.CmdFetchCRL: var res *revoking.Result - fetchCRLPayload, ok := command.Command.(messages.FetchCRLCommand) + fetchCRLPayload, ok := m.currentCommand.Command.(messages.FetchCRLCommand) if !ok { return nil, fmt.Errorf("could not use payload as FetchCRLPayload") } res, err = m.fetchCRLHandler.FetchCRL(fetchCRLPayload.IssuerID) if err != nil { - break + return nil, err } - response := messages.FetchCRLResponse{ + response := &messages.FetchCRLResponse{ IsDelta: false, CRLData: res.Data, } responseCode, responseData = messages.RespFetchCRL, response default: - return nil, fmt.Errorf("unhandled command %v", command) + return nil, fmt.Errorf("unhandled command %s", m.currentCommand.Announce) } if err != nil { @@ -187,7 +235,7 @@ func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) { func buildErrorResponse(errMsg string) *Response { return &Response{ Announce: messages.BuildResponseAnnounce(messages.RespError), - Response: messages.ErrorResponse{Message: errMsg}, + Response: &messages.ErrorResponse{Message: errMsg}, } } diff --git a/pkg/seriallink/seriallink.go b/pkg/seriallink/seriallink.go index 1e6c402..b9acb77 100644 --- a/pkg/seriallink/seriallink.go +++ b/pkg/seriallink/seriallink.go @@ -20,10 +20,14 @@ package seriallink import ( "bytes" + "context" "fmt" + "io" + "sync" "time" "github.com/justincpresley/go-cobs" + "github.com/sirupsen/logrus" "github.com/tarm/serial" "git.cacert.org/cacert-gosigner/pkg/config" @@ -33,16 +37,42 @@ import ( type protocolState int8 const ( - stAnnounce protocolState = iota - stCommand + cmdAnnounce protocolState = iota + cmdData + respAnnounce + respData ) +var validTransitions = map[protocolState]protocolState{ + cmdAnnounce: cmdData, + cmdData: respAnnounce, + respAnnounce: respData, + respData: cmdAnnounce, +} + +var protocolStateNames = map[protocolState]string{ + cmdAnnounce: "CMD ANNOUNCE", + cmdData: "CMD DATA", + respAnnounce: "RESP ANNOUNCE", + respData: "RESP DATA", +} + +func (p protocolState) String() string { + if name, ok := protocolStateNames[p]; ok { + return name + } + + return fmt.Sprintf("unknown %d", p) +} + type Handler struct { protocolHandler protocol.Handler protocolState protocolState - currentCommand *protocol.Command config *serial.Config port *serial.Port + logger *logrus.Logger + lock sync.Mutex + framesIn chan []byte } func (h *Handler) setupConnection() error { @@ -65,147 +95,267 @@ func (h *Handler) Close() error { return nil } -const cobsDelimiter = 0x00 +var cobsConfig = cobs.Config{SpecialByte: 0x00, Delimiter: true, EndingSave: true} + +func (h *Handler) Run(ctx context.Context) error { + h.protocolState = cmdAnnounce + errors := make(chan error) + + go func() { + err := h.readFrames() + + errors <- err + }() + + for { + select { + case <-ctx.Done(): + return nil + + case err := <-errors: + if err != nil { + return fmt.Errorf("error from handler loop: %w", err) + } -var cobsConfig = cobs.Config{SpecialByte: cobsDelimiter, Delimiter: true, EndingSave: true} + return nil -func (h *Handler) Run() error { + default: + if err := h.handleProtocolState(); err != nil { + return err + } + } + } +} + +func (h *Handler) readFrames() error { const ( - bufferSize = 1024 * 1024 readInterval = 50 * time.Millisecond ) - errors := make(chan error) + var frame []byte + buffer := &bytes.Buffer{} + delimiter := []byte{cobsConfig.SpecialByte} - h.protocolState = stAnnounce + for { + readBytes, err := h.readFromPort() + if err != nil { + close(h.framesIn) - go func() { - buf := make([]byte, bufferSize) + return err + } - for { - count, err := h.port.Read(buf) - if err != nil { - errors <- err + if len(readBytes) == 0 { + time.Sleep(readInterval) - return - } + continue + } + + h.logger.Tracef("read %d bytes", len(readBytes)) - if count == 0 { - time.Sleep(readInterval) + buffer.Write(readBytes) + + h.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) + + rest := buffer.Bytes() + + if !bytes.Contains(rest, delimiter) { + continue + } + for bytes.Contains(rest, delimiter) { + parts := bytes.SplitAfterN(rest, delimiter, 2) + frame, rest = parts[0], parts[1] + + h.logger.Tracef("frame of length %d", len(frame)) + + if len(frame) == 0 { continue } - frames := bytes.SplitAfter(buf[:count], []byte{cobsDelimiter}) + if err := cobs.Verify(frame, cobsConfig); err != nil { + close(h.framesIn) - if err := h.handleFrames(frames); err != nil { - errors <- err - - return + return fmt.Errorf("could not verify COBS frame: %w", err) } + + decoded := cobs.Decode(frame, cobsConfig) + + h.logger.Tracef("frame decoded to length %d", len(decoded)) + + h.framesIn <- decoded } - }() - err := <-errors - if err != nil { - return fmt.Errorf("error from handler loop: %w", err) + buffer.Truncate(0) + buffer.Write(rest) + + h.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) } +} - return nil +func (h *Handler) writeFrame(frame []byte) error { + encoded := cobs.Encode(frame, cobsConfig) + + return h.writeToPort(encoded) } -func (h *Handler) handleFrames(frames [][]byte) error { - for _, frame := range frames { - if len(frame) == 0 { - return nil - } +func (h *Handler) nextState() error { + next, ok := validTransitions[h.protocolState] + if !ok { + return fmt.Errorf("illegal protocol state %s", h.protocolState) + } - if err := cobs.Verify(frame, cobsConfig); err != nil { - return fmt.Errorf("could not verify COBS frame: %w", err) - } + h.protocolState = next - // perform COBS decoding - decoded := cobs.Decode(frame, cobsConfig) + return nil +} - if h.protocolState == stAnnounce { - if err := h.handleCommandAnnounce(decoded); err != nil { - return err - } - } +func (h *Handler) handleProtocolState() error { + h.logger.Tracef("handling protocol state %s", h.protocolState) - if h.protocolState == stCommand { - if err := h.handleCommandData(decoded); err != nil { - return err - } + switch h.protocolState { + case cmdAnnounce: + if err := h.handleCmdAnnounce(); err != nil { + return err } - - if err := h.nextState(); err != nil { + case cmdData: + if err := h.handleCmdData(); err != nil { return err } + case respAnnounce: + if err := h.handleRespAnnounce(); err != nil { + return err + } + case respData: + if err := h.handleRespData(); err != nil { + return err + } + default: + return fmt.Errorf("unknown protocol state %s", h.protocolState) } return nil } -func (h *Handler) handleCommandData(decoded []byte) error { - respAnn, msg, err := h.protocolHandler.HandleCommand(h.currentCommand.Announce, decoded) +func (h *Handler) writeToPort(data []byte) error { + h.lock.Lock() + defer h.lock.Unlock() + + reader := bytes.NewReader(data) + + n, err := io.Copy(h.port, reader) if err != nil { - return fmt.Errorf("command handler for %s failed: %w", h.currentCommand.Announce.Code, err) + return fmt.Errorf("could not write data: %w", err) } - if err := h.writeResponse(respAnn, msg, cobsConfig); err != nil { - return err + h.logger.Tracef("wrote %d bytes", n) + + if err := h.port.Flush(); err != nil { + return fmt.Errorf("could not flush data: %w", err) } return nil } -func (h *Handler) handleCommandAnnounce(decoded []byte) error { - announce, err := h.protocolHandler.HandleCommandAnnounce(decoded) +func (h *Handler) readFromPort() ([]byte, error) { + const bufferSize = 1024 + + buf := make([]byte, bufferSize) + + count, err := h.port.Read(buf) if err != nil { - return fmt.Errorf("command announce handling failed: %w", err) + return nil, fmt.Errorf("could not read from serial port: %w", err) } - h.currentCommand = &protocol.Command{Announce: announce} + return buf[:count], nil +} + +func (h *Handler) handleCmdAnnounce() error { + h.logger.Trace("waiting for command announce") + + select { + case frame := <-h.framesIn: + if frame == nil { + return nil + } + + if err := h.protocolHandler.HandleCommandAnnounce(frame); err != nil { + return fmt.Errorf("command announce handling failed: %w", err) + } + + if err := h.nextState(); err != nil { + return err + } + } return nil } -func (h *Handler) writeResponse(ann, msg []byte, cobsConfig cobs.Config) error { - encoded := cobs.Encode(ann, cobsConfig) +func (h *Handler) handleCmdData() error { + h.logger.Trace("waiting for command data") - if _, err := h.port.Write(encoded); err != nil { - return fmt.Errorf("could not write response announcement: %w", err) + select { + case frame := <-h.framesIn: + + if frame == nil { + return nil + } + + if err := h.protocolHandler.HandleCommand(frame); err != nil { + return fmt.Errorf("command handler failed: %w", err) + } + + if err := h.nextState(); err != nil { + return err + } + } + + return nil +} + +func (h *Handler) handleRespAnnounce() error { + frame, err := h.protocolHandler.ResponseAnnounce() + if err != nil { + return fmt.Errorf("could not get response announcement: %w", err) } - encoded = cobs.Encode(msg, cobsConfig) + h.logger.Trace("writing response announce") - if _, err := h.port.Write(encoded); err != nil { - return fmt.Errorf("could not write response: %w", err) + if err := h.writeFrame(frame); err != nil { + return err + } + + if err := h.nextState(); err != nil { + return err } return nil } -func (h *Handler) nextState() error { - var next protocolState +func (h *Handler) handleRespData() error { + frame, err := h.protocolHandler.ResponseData() + if err != nil { + return fmt.Errorf("could not get response data: %w", err) + } - switch h.protocolState { - case stAnnounce: - next = stCommand - case stCommand: - next = stAnnounce - default: - return fmt.Errorf("illegal protocol state %d", int(h.protocolState)) + h.logger.Trace("writing response data") + + if err := h.writeFrame(frame); err != nil { + return err } - h.protocolState = next + if err := h.nextState(); err != nil { + return err + } return nil } -func New(cfg *config.Serial, protocolHandler protocol.Handler) (*Handler, error) { - h := &Handler{protocolHandler: protocolHandler} +func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Handler) (*Handler, error) { + h := &Handler{ + protocolHandler: protocolHandler, + logger: logger, + framesIn: make(chan []byte, 0), + } h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout} err := h.setupConnection()