From e5dcf7afa9fead614d00dd316570adbfa6fb5693 Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Tue, 29 Nov 2022 09:57:23 +0100 Subject: [PATCH] Refactor COBS wire protocol Wire protocol handling has been moved to protocol.Framer and its implementation protocol.COBSFramer --- cmd/clientsim/main.go | 126 ++++----------------------------- internal/serial/seriallink.go | 128 ++++------------------------------ pkg/messages/messages.go | 2 +- pkg/protocol/protocol.go | 119 +++++++++++++++++++++++++++++++ 4 files changed, 150 insertions(+), 225 deletions(-) diff --git a/cmd/clientsim/main.go b/cmd/clientsim/main.go index 5a9309c..e82e17b 100644 --- a/cmd/clientsim/main.go +++ b/cmd/clientsim/main.go @@ -20,15 +20,12 @@ limitations under the License. package main import ( - "bytes" "context" "fmt" - "io" "os" "sync" "time" - "github.com/justincpresley/go-cobs" "github.com/shamaton/msgpackgen/msgpack" "github.com/sirupsen/logrus" @@ -37,8 +34,6 @@ import ( "git.cacert.org/cacert-gosigner/pkg/messages" ) -var cobsConfig = cobs.Config{SpecialByte: protocol.CobsDelimiter, Delimiter: true, EndingSave: true} - type protocolState int8 const ( @@ -215,106 +210,11 @@ type clientSimulator struct { logger *logrus.Logger lock sync.Mutex framesIn chan []byte + framesOut chan []byte + framer protocol.Framer commandGenerator *TestCommandGenerator } -func (c *clientSimulator) readFrames() error { - const readInterval = 50 * time.Millisecond - - var frame []byte - - buffer := &bytes.Buffer{} - - delimiter := []byte{cobsConfig.SpecialByte} - - for { - readBytes, err := c.readFromStdin() - if err != nil { - c.logger.WithError(err).Error("stdin read error") - - close(c.framesIn) - - return err - } - - if len(readBytes) == 0 { - time.Sleep(readInterval) - - continue - } - - c.logger.Tracef("read %d bytes", len(readBytes)) - - buffer.Write(readBytes) - - c.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) - - rest := buffer.Bytes() - - if !bytes.Contains(rest, delimiter) { - c.logger.Tracef("read data does not contain the delimiter %x", delimiter) - - continue - } - - for bytes.Contains(rest, delimiter) { - parts := bytes.SplitAfterN(rest, delimiter, 2) - frame, rest = parts[0], parts[1] - - c.logger.Tracef("frame of length %d", len(frame)) - - if len(frame) == 0 { - continue - } - - err = cobs.Verify(frame, cobsConfig) - if err != nil { - return fmt.Errorf("frame verification failed: %w", err) - } - - decoded := cobs.Decode(frame, cobsConfig) - - c.logger.Tracef("frame decoded to length %d", len(decoded)) - - c.framesIn <- decoded - - c.logger.Tracef("%d bytes remaining", len(rest)) - } - - buffer.Truncate(0) - buffer.Write(rest) - - c.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) - } -} - -func (c *clientSimulator) writeFrame(frame []byte) error { - encoded := cobs.Encode(frame, cobsConfig) - - if _, err := io.Copy(os.Stdout, bytes.NewBuffer(encoded)); err != nil { - return fmt.Errorf("could not write data: %w", err) - } - - return nil -} - -func (c *clientSimulator) readFromStdin() ([]byte, error) { - const bufferSize = 1024 - - buf := make([]byte, bufferSize) - - c.logger.Trace("waiting for input") - - count, err := os.Stdin.Read(buf) - if err != nil { - return nil, fmt.Errorf("reading input failed: %w", err) - } - - c.logger.Tracef("read %d bytes from stdin", count) - - return buf[:count], nil -} - func (c *clientSimulator) writeCmdAnnouncement() error { frame, err := c.commandGenerator.CmdAnnouncement() if err != nil { @@ -323,9 +223,7 @@ func (c *clientSimulator) writeCmdAnnouncement() error { c.logger.Trace("writing command announcement") - if err := c.writeFrame(frame); err != nil { - return err - } + c.framesOut <- frame if err := c.nextState(); err != nil { return err @@ -342,9 +240,7 @@ func (c *clientSimulator) writeCommand() error { c.logger.Trace("writing command data") - if err := c.writeFrame(frame); err != nil { - return err - } + c.framesOut <- frame if err := c.nextState(); err != nil { return err @@ -398,7 +294,13 @@ func (c *clientSimulator) Run(ctx context.Context) error { errors := make(chan error) go func() { - err := c.readFrames() + err := c.framer.ReadFrames(os.Stdin, c.framesIn) + + errors <- err + }() + + go func() { + err := c.framer.WriteFrames(os.Stdout, c.framesOut) errors <- err }() @@ -480,8 +382,10 @@ func main() { logger: logger, commands: make(chan *protocol.Command), }, - logger: logger, - framesIn: make(chan []byte), + logger: logger, + framesIn: make(chan []byte), + framesOut: make(chan []byte), + framer: protocol.NewCOBSFramer(logger), } err := sim.Run(context.Background()) diff --git a/internal/serial/seriallink.go b/internal/serial/seriallink.go index 1320424..41cc9a7 100644 --- a/internal/serial/seriallink.go +++ b/internal/serial/seriallink.go @@ -15,18 +15,14 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package seriallink provides a handler for the serial connection of the signer machine. +// Package serial provides a handler for the serial connection of the signer machine. package serial import ( - "bytes" "context" "fmt" - "io" "sync" - "time" - "github.com/justincpresley/go-cobs" "github.com/sirupsen/logrus" "github.com/tarm/serial" @@ -68,11 +64,13 @@ func (p protocolState) String() string { type Handler struct { protocolHandler protocol.Handler protocolState protocolState + framer protocol.Framer config *serial.Config port *serial.Port logger *logrus.Logger lock sync.Mutex framesIn chan []byte + framesOut chan []byte } func (h *Handler) setupConnection() error { @@ -95,14 +93,18 @@ func (h *Handler) Close() error { return nil } -var cobsConfig = cobs.Config{SpecialByte: protocol.CobsDelimiter, Delimiter: true, EndingSave: true} - func (h *Handler) Run(ctx context.Context) error { h.protocolState = cmdAnnounce errors := make(chan error) go func() { - err := h.readFrames() + err := h.framer.ReadFrames(h.port, h.framesIn) + + errors <- err + }() + + go func() { + err := h.framer.WriteFrames(h.port, h.framesOut) errors <- err }() @@ -127,79 +129,6 @@ func (h *Handler) Run(ctx context.Context) error { } } -func (h *Handler) readFrames() error { - const ( - readInterval = 50 * time.Millisecond - ) - - var frame []byte - - buffer := &bytes.Buffer{} - - delimiter := []byte{cobsConfig.SpecialByte} - - for { - readBytes, err := h.readFromPort() - if err != nil { - close(h.framesIn) - - return err - } - - if len(readBytes) == 0 { - time.Sleep(readInterval) - - continue - } - - h.logger.Tracef("read %d bytes", len(readBytes)) - - buffer.Write(readBytes) - - h.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) - - rest := buffer.Bytes() - - if !bytes.Contains(rest, delimiter) { - continue - } - - for bytes.Contains(rest, delimiter) { - parts := bytes.SplitAfterN(rest, delimiter, 2) - frame, rest = parts[0], parts[1] - - h.logger.Tracef("frame of length %d", len(frame)) - - if len(frame) == 0 { - continue - } - - if err := cobs.Verify(frame, cobsConfig); err != nil { - close(h.framesIn) - - return fmt.Errorf("could not verify COBS frame: %w", err) - } - - decoded := cobs.Decode(frame, cobsConfig) - - h.logger.Tracef("frame decoded to length %d", len(decoded)) - - h.framesIn <- decoded - } - - buffer.Truncate(0) - buffer.Write(rest) - - h.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) - } -} - -func (h *Handler) writeFrame(frame []byte) error { - encoded := cobs.Encode(frame, cobsConfig) - - return h.writeToPort(encoded) -} - func (h *Handler) nextState() error { next, ok := validTransitions[h.protocolState] if !ok { @@ -241,36 +170,11 @@ func (h *Handler) handleProtocolState() error { return nil } -func (h *Handler) writeToPort(data []byte) error { - reader := bytes.NewReader(data) - - n, err := io.Copy(h.port, reader) - if err != nil { - return fmt.Errorf("could not write data: %w", err) - } - - h.logger.Tracef("wrote %d bytes", n) - - return nil -} - -func (h *Handler) readFromPort() ([]byte, error) { - const bufferSize = 1024 - - buf := make([]byte, bufferSize) - - count, err := h.port.Read(buf) - if err != nil { - return nil, fmt.Errorf("could not read from serial port: %w", err) - } - - return buf[:count], nil -} - func (h *Handler) handleCmdAnnounce() error { h.logger.Trace("waiting for command announce") frame := <-h.framesIn + if frame == nil { return nil } @@ -312,9 +216,7 @@ func (h *Handler) handleRespAnnounce() error { return fmt.Errorf("could not get response announcement: %w", err) } - if err := h.writeFrame(frame); err != nil { - return err - } + h.framesOut <- frame if err := h.nextState(); err != nil { return err @@ -329,9 +231,7 @@ func (h *Handler) handleRespData() error { return fmt.Errorf("could not get response data: %w", err) } - if err := h.writeFrame(frame); err != nil { - return err - } + h.framesOut <- frame if err := h.nextState(); err != nil { return err @@ -345,6 +245,8 @@ func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Han protocolHandler: protocolHandler, logger: logger, framesIn: make(chan []byte), + framesOut: make(chan []byte), + framer: protocol.NewCOBSFramer(logger), } h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout} diff --git a/pkg/messages/messages.go b/pkg/messages/messages.go index 1523713..7b76751 100644 --- a/pkg/messages/messages.go +++ b/pkg/messages/messages.go @@ -43,7 +43,7 @@ const ( var commandNames = map[CommandCode]string{ CmdHealth: "HEALTH", - CmdFetchCRL: "FETCH URL", + CmdFetchCRL: "FETCH CRL", } func (c CommandCode) String() string { diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go index e2b5216..2e1c765 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/protocol/protocol.go @@ -19,7 +19,14 @@ limitations under the License. package protocol import ( + "bytes" + "errors" "fmt" + "io" + "time" + + "github.com/justincpresley/go-cobs" + "github.com/sirupsen/logrus" "git.cacert.org/cacert-gosigner/pkg/messages" ) @@ -55,3 +62,115 @@ type Handler interface { // ResponseData generates the response data. ResponseData() ([]byte, error) } + +// Framer handles bytes on the wire by adding or removing framing information. +type Framer interface { + // ReadFrames reads data frames and publishes unframed data to the channel. + ReadFrames(io.Reader, chan []byte) error + // WriteFrames takes data from the channel and writes framed data to the writer. + WriteFrames(io.Writer, chan []byte) error +} + +const bufferSize = 1024 +const readInterval = 50 * time.Millisecond + +type COBSFramer struct { + config cobs.Config + logger *logrus.Logger +} + +func NewCOBSFramer(logger *logrus.Logger) *COBSFramer { + return &COBSFramer{ + config: cobs.Config{SpecialByte: CobsDelimiter, Delimiter: true, EndingSave: true}, + logger: logger, + } +} + +func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error { + var ( + err error + raw, data, frame []byte + ) + + buffer := &bytes.Buffer{} + + for { + raw, err = c.readRaw(reader) + if err != nil { + close(frameChan) + + return err + } + + if len(raw) == 0 { + time.Sleep(readInterval) + + continue + } + + c.logger.Tracef("read %d raw bytes", len(raw)) + + buffer.Write(raw) + + for { + data, err = buffer.ReadBytes(c.config.SpecialByte) + if err != nil { + if errors.Is(err, io.EOF) { + buffer.Write(data) + + break + } + + return fmt.Errorf("could not read from buffer: %w", err) + } + + if err = cobs.Verify(data, c.config); err != nil { + c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data)) + + break + } + + frame = cobs.Decode(data, c.config) + + c.logger.Tracef("frame decoded to length %d", len(frame)) + + frameChan <- frame + } + + c.logger.Tracef("read buffer is now %d bytes long", buffer.Len()) + } +} + +func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) { + buf := make([]byte, bufferSize) + + count, err := reader.Read(buf) + if err != nil { + return nil, fmt.Errorf("could not read data: %w", err) + } + + raw := buf[:count] + + return raw, nil +} + +func (c *COBSFramer) WriteFrames(writer io.Writer, frameChan chan []byte) error { + for { + frame := <-frameChan + + if frame == nil { + c.logger.Debug("channel closed") + + return nil + } + + encoded := cobs.Encode(frame, c.config) + + n, err := io.Copy(writer, bytes.NewReader(encoded)) + if err != nil { + return fmt.Errorf("cold not write data: %w", err) + } + + c.logger.Tracef("wrote %d bytes", n) + } +}