Refactor COBS wire protocol

Wire protocol handling has been moved to protocol.Framer and its
implementation protocol.COBSFramer
This commit is contained in:
Jan Dittberner 2022-11-29 09:57:23 +01:00
parent faaadbe5aa
commit e5dcf7afa9
4 changed files with 150 additions and 225 deletions

View file

@ -20,15 +20,12 @@ limitations under the License.
package main package main
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io"
"os" "os"
"sync" "sync"
"time" "time"
"github.com/justincpresley/go-cobs"
"github.com/shamaton/msgpackgen/msgpack" "github.com/shamaton/msgpackgen/msgpack"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -37,8 +34,6 @@ import (
"git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/messages"
) )
var cobsConfig = cobs.Config{SpecialByte: protocol.CobsDelimiter, Delimiter: true, EndingSave: true}
type protocolState int8 type protocolState int8
const ( const (
@ -215,106 +210,11 @@ type clientSimulator struct {
logger *logrus.Logger logger *logrus.Logger
lock sync.Mutex lock sync.Mutex
framesIn chan []byte framesIn chan []byte
framesOut chan []byte
framer protocol.Framer
commandGenerator *TestCommandGenerator commandGenerator *TestCommandGenerator
} }
func (c *clientSimulator) readFrames() error {
const readInterval = 50 * time.Millisecond
var frame []byte
buffer := &bytes.Buffer{}
delimiter := []byte{cobsConfig.SpecialByte}
for {
readBytes, err := c.readFromStdin()
if err != nil {
c.logger.WithError(err).Error("stdin read error")
close(c.framesIn)
return err
}
if len(readBytes) == 0 {
time.Sleep(readInterval)
continue
}
c.logger.Tracef("read %d bytes", len(readBytes))
buffer.Write(readBytes)
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
rest := buffer.Bytes()
if !bytes.Contains(rest, delimiter) {
c.logger.Tracef("read data does not contain the delimiter %x", delimiter)
continue
}
for bytes.Contains(rest, delimiter) {
parts := bytes.SplitAfterN(rest, delimiter, 2)
frame, rest = parts[0], parts[1]
c.logger.Tracef("frame of length %d", len(frame))
if len(frame) == 0 {
continue
}
err = cobs.Verify(frame, cobsConfig)
if err != nil {
return fmt.Errorf("frame verification failed: %w", err)
}
decoded := cobs.Decode(frame, cobsConfig)
c.logger.Tracef("frame decoded to length %d", len(decoded))
c.framesIn <- decoded
c.logger.Tracef("%d bytes remaining", len(rest))
}
buffer.Truncate(0)
buffer.Write(rest)
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (c *clientSimulator) writeFrame(frame []byte) error {
encoded := cobs.Encode(frame, cobsConfig)
if _, err := io.Copy(os.Stdout, bytes.NewBuffer(encoded)); err != nil {
return fmt.Errorf("could not write data: %w", err)
}
return nil
}
func (c *clientSimulator) readFromStdin() ([]byte, error) {
const bufferSize = 1024
buf := make([]byte, bufferSize)
c.logger.Trace("waiting for input")
count, err := os.Stdin.Read(buf)
if err != nil {
return nil, fmt.Errorf("reading input failed: %w", err)
}
c.logger.Tracef("read %d bytes from stdin", count)
return buf[:count], nil
}
func (c *clientSimulator) writeCmdAnnouncement() error { func (c *clientSimulator) writeCmdAnnouncement() error {
frame, err := c.commandGenerator.CmdAnnouncement() frame, err := c.commandGenerator.CmdAnnouncement()
if err != nil { if err != nil {
@ -323,9 +223,7 @@ func (c *clientSimulator) writeCmdAnnouncement() error {
c.logger.Trace("writing command announcement") c.logger.Trace("writing command announcement")
if err := c.writeFrame(frame); err != nil { c.framesOut <- frame
return err
}
if err := c.nextState(); err != nil { if err := c.nextState(); err != nil {
return err return err
@ -342,9 +240,7 @@ func (c *clientSimulator) writeCommand() error {
c.logger.Trace("writing command data") c.logger.Trace("writing command data")
if err := c.writeFrame(frame); err != nil { c.framesOut <- frame
return err
}
if err := c.nextState(); err != nil { if err := c.nextState(); err != nil {
return err return err
@ -398,7 +294,13 @@ func (c *clientSimulator) Run(ctx context.Context) error {
errors := make(chan error) errors := make(chan error)
go func() { go func() {
err := c.readFrames() err := c.framer.ReadFrames(os.Stdin, c.framesIn)
errors <- err
}()
go func() {
err := c.framer.WriteFrames(os.Stdout, c.framesOut)
errors <- err errors <- err
}() }()
@ -482,6 +384,8 @@ func main() {
}, },
logger: logger, logger: logger,
framesIn: make(chan []byte), framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
} }
err := sim.Run(context.Background()) err := sim.Run(context.Background())

View file

@ -15,18 +15,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
// Package seriallink provides a handler for the serial connection of the signer machine. // Package serial provides a handler for the serial connection of the signer machine.
package serial package serial
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io"
"sync" "sync"
"time"
"github.com/justincpresley/go-cobs"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tarm/serial" "github.com/tarm/serial"
@ -68,11 +64,13 @@ func (p protocolState) String() string {
type Handler struct { type Handler struct {
protocolHandler protocol.Handler protocolHandler protocol.Handler
protocolState protocolState protocolState protocolState
framer protocol.Framer
config *serial.Config config *serial.Config
port *serial.Port port *serial.Port
logger *logrus.Logger logger *logrus.Logger
lock sync.Mutex lock sync.Mutex
framesIn chan []byte framesIn chan []byte
framesOut chan []byte
} }
func (h *Handler) setupConnection() error { func (h *Handler) setupConnection() error {
@ -95,14 +93,18 @@ func (h *Handler) Close() error {
return nil return nil
} }
var cobsConfig = cobs.Config{SpecialByte: protocol.CobsDelimiter, Delimiter: true, EndingSave: true}
func (h *Handler) Run(ctx context.Context) error { func (h *Handler) Run(ctx context.Context) error {
h.protocolState = cmdAnnounce h.protocolState = cmdAnnounce
errors := make(chan error) errors := make(chan error)
go func() { go func() {
err := h.readFrames() err := h.framer.ReadFrames(h.port, h.framesIn)
errors <- err
}()
go func() {
err := h.framer.WriteFrames(h.port, h.framesOut)
errors <- err errors <- err
}() }()
@ -127,79 +129,6 @@ func (h *Handler) Run(ctx context.Context) error {
} }
} }
func (h *Handler) readFrames() error {
const (
readInterval = 50 * time.Millisecond
)
var frame []byte
buffer := &bytes.Buffer{}
delimiter := []byte{cobsConfig.SpecialByte}
for {
readBytes, err := h.readFromPort()
if err != nil {
close(h.framesIn)
return err
}
if len(readBytes) == 0 {
time.Sleep(readInterval)
continue
}
h.logger.Tracef("read %d bytes", len(readBytes))
buffer.Write(readBytes)
h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
rest := buffer.Bytes()
if !bytes.Contains(rest, delimiter) {
continue
}
for bytes.Contains(rest, delimiter) {
parts := bytes.SplitAfterN(rest, delimiter, 2)
frame, rest = parts[0], parts[1]
h.logger.Tracef("frame of length %d", len(frame))
if len(frame) == 0 {
continue
}
if err := cobs.Verify(frame, cobsConfig); err != nil {
close(h.framesIn)
return fmt.Errorf("could not verify COBS frame: %w", err)
}
decoded := cobs.Decode(frame, cobsConfig)
h.logger.Tracef("frame decoded to length %d", len(decoded))
h.framesIn <- decoded
}
buffer.Truncate(0)
buffer.Write(rest)
h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (h *Handler) writeFrame(frame []byte) error {
encoded := cobs.Encode(frame, cobsConfig)
return h.writeToPort(encoded)
}
func (h *Handler) nextState() error { func (h *Handler) nextState() error {
next, ok := validTransitions[h.protocolState] next, ok := validTransitions[h.protocolState]
if !ok { if !ok {
@ -241,36 +170,11 @@ func (h *Handler) handleProtocolState() error {
return nil return nil
} }
func (h *Handler) writeToPort(data []byte) error {
reader := bytes.NewReader(data)
n, err := io.Copy(h.port, reader)
if err != nil {
return fmt.Errorf("could not write data: %w", err)
}
h.logger.Tracef("wrote %d bytes", n)
return nil
}
func (h *Handler) readFromPort() ([]byte, error) {
const bufferSize = 1024
buf := make([]byte, bufferSize)
count, err := h.port.Read(buf)
if err != nil {
return nil, fmt.Errorf("could not read from serial port: %w", err)
}
return buf[:count], nil
}
func (h *Handler) handleCmdAnnounce() error { func (h *Handler) handleCmdAnnounce() error {
h.logger.Trace("waiting for command announce") h.logger.Trace("waiting for command announce")
frame := <-h.framesIn frame := <-h.framesIn
if frame == nil { if frame == nil {
return nil return nil
} }
@ -312,9 +216,7 @@ func (h *Handler) handleRespAnnounce() error {
return fmt.Errorf("could not get response announcement: %w", err) return fmt.Errorf("could not get response announcement: %w", err)
} }
if err := h.writeFrame(frame); err != nil { h.framesOut <- frame
return err
}
if err := h.nextState(); err != nil { if err := h.nextState(); err != nil {
return err return err
@ -329,9 +231,7 @@ func (h *Handler) handleRespData() error {
return fmt.Errorf("could not get response data: %w", err) return fmt.Errorf("could not get response data: %w", err)
} }
if err := h.writeFrame(frame); err != nil { h.framesOut <- frame
return err
}
if err := h.nextState(); err != nil { if err := h.nextState(); err != nil {
return err return err
@ -345,6 +245,8 @@ func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Han
protocolHandler: protocolHandler, protocolHandler: protocolHandler,
logger: logger, logger: logger,
framesIn: make(chan []byte), framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: protocol.NewCOBSFramer(logger),
} }
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout} h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}

View file

@ -43,7 +43,7 @@ const (
var commandNames = map[CommandCode]string{ var commandNames = map[CommandCode]string{
CmdHealth: "HEALTH", CmdHealth: "HEALTH",
CmdFetchCRL: "FETCH URL", CmdFetchCRL: "FETCH CRL",
} }
func (c CommandCode) String() string { func (c CommandCode) String() string {

View file

@ -19,7 +19,14 @@ limitations under the License.
package protocol package protocol
import ( import (
"bytes"
"errors"
"fmt" "fmt"
"io"
"time"
"github.com/justincpresley/go-cobs"
"github.com/sirupsen/logrus"
"git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/messages"
) )
@ -55,3 +62,115 @@ type Handler interface {
// ResponseData generates the response data. // ResponseData generates the response data.
ResponseData() ([]byte, error) ResponseData() ([]byte, error)
} }
// Framer handles bytes on the wire by adding or removing framing information.
type Framer interface {
// ReadFrames reads data frames and publishes unframed data to the channel.
ReadFrames(io.Reader, chan []byte) error
// WriteFrames takes data from the channel and writes framed data to the writer.
WriteFrames(io.Writer, chan []byte) error
}
const bufferSize = 1024
const readInterval = 50 * time.Millisecond
type COBSFramer struct {
config cobs.Config
logger *logrus.Logger
}
func NewCOBSFramer(logger *logrus.Logger) *COBSFramer {
return &COBSFramer{
config: cobs.Config{SpecialByte: CobsDelimiter, Delimiter: true, EndingSave: true},
logger: logger,
}
}
func (c *COBSFramer) ReadFrames(reader io.Reader, frameChan chan []byte) error {
var (
err error
raw, data, frame []byte
)
buffer := &bytes.Buffer{}
for {
raw, err = c.readRaw(reader)
if err != nil {
close(frameChan)
return err
}
if len(raw) == 0 {
time.Sleep(readInterval)
continue
}
c.logger.Tracef("read %d raw bytes", len(raw))
buffer.Write(raw)
for {
data, err = buffer.ReadBytes(c.config.SpecialByte)
if err != nil {
if errors.Is(err, io.EOF) {
buffer.Write(data)
break
}
return fmt.Errorf("could not read from buffer: %w", err)
}
if err = cobs.Verify(data, c.config); err != nil {
c.logger.WithError(err).Warnf("skipping invalid frame of %d bytes", len(data))
break
}
frame = cobs.Decode(data, c.config)
c.logger.Tracef("frame decoded to length %d", len(frame))
frameChan <- frame
}
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
}
}
func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) {
buf := make([]byte, bufferSize)
count, err := reader.Read(buf)
if err != nil {
return nil, fmt.Errorf("could not read data: %w", err)
}
raw := buf[:count]
return raw, nil
}
func (c *COBSFramer) WriteFrames(writer io.Writer, frameChan chan []byte) error {
for {
frame := <-frameChan
if frame == nil {
c.logger.Debug("channel closed")
return nil
}
encoded := cobs.Encode(frame, c.config)
n, err := io.Copy(writer, bytes.NewReader(encoded))
if err != nil {
return fmt.Errorf("cold not write data: %w", err)
}
c.logger.Tracef("wrote %d bytes", n)
}
}