Implement protocol improvements
This commit implements a client and server side state machine for the serial protocol.
This commit is contained in:
parent
2de592d30c
commit
8e443bd8b4
5 changed files with 766 additions and 388 deletions
|
@ -23,6 +23,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -36,39 +37,135 @@ import (
|
|||
"git.cacert.org/cacert-gosigner/pkg/messages"
|
||||
)
|
||||
|
||||
const cobsDelimiter = 0x00
|
||||
var cobsConfig = cobs.Config{SpecialByte: 0x00, Delimiter: true, EndingSave: true}
|
||||
|
||||
var cobsConfig = cobs.Config{SpecialByte: cobsDelimiter, Delimiter: true, EndingSave: true}
|
||||
type protocolState int8
|
||||
|
||||
func main() {
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(os.Stderr)
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
const (
|
||||
cmdAnnounce protocolState = iota
|
||||
cmdData
|
||||
respAnnounce
|
||||
respData
|
||||
)
|
||||
|
||||
sim := &clientSimulator{
|
||||
logger: logger,
|
||||
var validTransitions = map[protocolState]protocolState{
|
||||
cmdAnnounce: cmdData,
|
||||
cmdData: respAnnounce,
|
||||
respAnnounce: respData,
|
||||
respData: cmdAnnounce,
|
||||
}
|
||||
|
||||
var protocolStateNames = map[protocolState]string{
|
||||
cmdAnnounce: "CMD ANNOUNCE",
|
||||
cmdData: "CMD DATA",
|
||||
respAnnounce: "RESP ANNOUNCE",
|
||||
respData: "RESP DATA",
|
||||
}
|
||||
|
||||
func (p protocolState) String() string {
|
||||
if name, ok := protocolStateNames[p]; ok {
|
||||
return name
|
||||
}
|
||||
|
||||
err := sim.Run()
|
||||
return fmt.Sprintf("unknown %d", p)
|
||||
}
|
||||
|
||||
type TestCommandGenerator struct {
|
||||
logger *logrus.Logger
|
||||
currentCommand *protocol.Command
|
||||
currentResponse *protocol.Response
|
||||
commands chan *protocol.Command
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (g *TestCommandGenerator) CmdAnnouncement() ([]byte, error) {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
|
||||
select {
|
||||
case g.currentCommand = <-g.commands:
|
||||
announceData, err := msgpack.Marshal(g.currentCommand.Announce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not marshal command annoucement: %w", err)
|
||||
}
|
||||
|
||||
g.logger.WithField("announcement", &g.currentCommand.Announce).Info("write command announcement")
|
||||
|
||||
return announceData, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (g *TestCommandGenerator) CmdData() ([]byte, error) {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
|
||||
cmdData, err := msgpack.Marshal(g.currentCommand.Command)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("simulator returned an error")
|
||||
return nil, fmt.Errorf("could not marshal command data: %w", err)
|
||||
}
|
||||
|
||||
g.logger.WithField("command", &g.currentCommand.Command).Info("write command data")
|
||||
|
||||
return cmdData, nil
|
||||
}
|
||||
|
||||
type clientSimulator struct {
|
||||
logger *logrus.Logger
|
||||
commands chan *protocol.Command
|
||||
responses chan [][]byte
|
||||
func (g *TestCommandGenerator) HandleResponseAnnounce(frame []byte) error {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
|
||||
var ann messages.ResponseAnnounce
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &ann); err != nil {
|
||||
return fmt.Errorf("could not unmarshal response announcement")
|
||||
}
|
||||
|
||||
g.logger.WithField("announcement", &ann).Info("received response announcement")
|
||||
|
||||
g.currentResponse = &protocol.Response{Announce: &ann}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) writeTestCommands(ctx context.Context) error {
|
||||
messages.RegisterGeneratedResolver()
|
||||
func (g *TestCommandGenerator) HandleResponse(frame []byte) error {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
|
||||
switch g.currentResponse.Announce.Code {
|
||||
case messages.RespHealth:
|
||||
var response messages.HealthResponse
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &response); err != nil {
|
||||
return fmt.Errorf("unmarshal failed: %w", err)
|
||||
}
|
||||
|
||||
g.currentResponse.Response = response
|
||||
case messages.RespFetchCRL:
|
||||
var response messages.FetchCRLResponse
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &response); err != nil {
|
||||
return fmt.Errorf("unmarshal failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
g.logger.WithField(
|
||||
"command",
|
||||
g.currentCommand,
|
||||
).WithField(
|
||||
"response",
|
||||
g.currentResponse,
|
||||
).Info("handled health response")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *TestCommandGenerator) GenerateCommands(ctx context.Context) error {
|
||||
const healthInterval = 10 * time.Second
|
||||
|
||||
g.logger.Info("start generating commands")
|
||||
|
||||
time.Sleep(healthInterval)
|
||||
|
||||
c.commands <- &protocol.Command{
|
||||
g.commands <- &protocol.Command{
|
||||
Announce: messages.BuildCommandAnnounce(messages.CmdFetchCRL),
|
||||
Command: &messages.FetchCRLCommand{IssuerID: "sub-ecc_person_2022"},
|
||||
}
|
||||
|
@ -80,11 +177,11 @@ func (c *clientSimulator) writeTestCommands(ctx context.Context) error {
|
|||
case <-ctx.Done():
|
||||
_ = timer.Stop()
|
||||
|
||||
c.logger.Info("stopping health check loop")
|
||||
g.logger.Info("stopping health check loop")
|
||||
|
||||
return nil
|
||||
case <-timer.C:
|
||||
c.commands <- &protocol.Command{
|
||||
g.commands <- &protocol.Command{
|
||||
Announce: messages.BuildCommandAnnounce(messages.CmdHealth),
|
||||
Command: &messages.HealthCommand{},
|
||||
}
|
||||
|
@ -94,218 +191,292 @@ func (c *clientSimulator) writeTestCommands(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *clientSimulator) handleInput(ctx context.Context) error {
|
||||
const (
|
||||
bufferSize = 1024 * 1024
|
||||
readInterval = 50 * time.Millisecond
|
||||
)
|
||||
type clientSimulator struct {
|
||||
protocolState protocolState
|
||||
logger *logrus.Logger
|
||||
lock sync.Mutex
|
||||
framesIn chan []byte
|
||||
commandGenerator *TestCommandGenerator
|
||||
}
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
func (c *clientSimulator) readFrames() error {
|
||||
const readInterval = 50 * time.Millisecond
|
||||
|
||||
type protocolState int8
|
||||
var frame []byte
|
||||
buffer := &bytes.Buffer{}
|
||||
delimiter := []byte{cobsConfig.SpecialByte}
|
||||
|
||||
const (
|
||||
stAnn protocolState = iota
|
||||
stResp
|
||||
)
|
||||
|
||||
state := stAnn
|
||||
|
||||
var announce []byte
|
||||
|
||||
reading:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
count, err := os.Stdin.Read(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading input failed: %w", err)
|
||||
}
|
||||
readBytes, err := c.readFromStdin()
|
||||
if err != nil {
|
||||
c.logger.WithError(err).Error("stdin read error")
|
||||
|
||||
if count == 0 {
|
||||
time.Sleep(readInterval)
|
||||
close(c.framesIn)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if len(readBytes) == 0 {
|
||||
time.Sleep(readInterval)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Tracef("read %d bytes", len(readBytes))
|
||||
|
||||
buffer.Write(readBytes)
|
||||
|
||||
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
|
||||
|
||||
rest := buffer.Bytes()
|
||||
|
||||
if !bytes.Contains(rest, delimiter) {
|
||||
continue
|
||||
}
|
||||
|
||||
for bytes.Contains(rest, delimiter) {
|
||||
parts := bytes.SplitAfterN(rest, delimiter, 2)
|
||||
frame, rest = parts[0], parts[1]
|
||||
|
||||
c.logger.Tracef("frame of length %d", len(frame))
|
||||
|
||||
if len(frame) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
data := buf[:count]
|
||||
|
||||
for _, frame := range bytes.SplitAfter(data, []byte{cobsConfig.SpecialByte}) {
|
||||
if len(frame) == 0 {
|
||||
continue reading
|
||||
}
|
||||
|
||||
err = cobs.Verify(frame, cobsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("frame verification failed: %w", err)
|
||||
}
|
||||
|
||||
if state == stAnn {
|
||||
announce = cobs.Decode(frame, cobsConfig)
|
||||
|
||||
state = stResp
|
||||
} else {
|
||||
c.responses <- [][]byte{announce, cobs.Decode(frame, cobsConfig)}
|
||||
|
||||
state = stAnn
|
||||
}
|
||||
err = cobs.Verify(frame, cobsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("frame verification failed: %w", err)
|
||||
}
|
||||
|
||||
decoded := cobs.Decode(frame, cobsConfig)
|
||||
|
||||
c.logger.Tracef("frame decoded to length %d", len(decoded))
|
||||
|
||||
c.framesIn <- decoded
|
||||
}
|
||||
|
||||
buffer.Truncate(0)
|
||||
buffer.Write(rest)
|
||||
|
||||
c.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientSimulator) handleCommands(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case command := <-c.commands:
|
||||
if err := writeCommandAnnouncement(command); err != nil {
|
||||
return err
|
||||
}
|
||||
func (c *clientSimulator) writeFrame(frame []byte) error {
|
||||
encoded := cobs.Encode(frame, cobsConfig)
|
||||
|
||||
if err := writeCommand(command); err != nil {
|
||||
return err
|
||||
}
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
respData := <-c.responses
|
||||
if _, err := io.Copy(os.Stdout, bytes.NewBuffer(encoded)); err != nil {
|
||||
return fmt.Errorf("could not write data: %w", err)
|
||||
}
|
||||
|
||||
c.logger.WithField("respdata", respData).Trace("read response data")
|
||||
return nil
|
||||
}
|
||||
|
||||
response := &protocol.Response{}
|
||||
func (c *clientSimulator) readFromStdin() ([]byte, error) {
|
||||
const bufferSize = 1024
|
||||
|
||||
if err := msgpack.Unmarshal(respData[0], &response.Announce); err != nil {
|
||||
return fmt.Errorf("could not unmarshal response announcement: %w", err)
|
||||
}
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
if err := c.handleResponse(response, respData[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
count, err := os.Stdin.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading input failed: %w", err)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return buf[:count], nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) writeCmdAnnouncement() error {
|
||||
frame, err := c.commandGenerator.CmdAnnouncement()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get command annoucement: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Trace("writing command announcement")
|
||||
|
||||
if err := c.writeFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) writeCommandAnnouncement() error {
|
||||
frame, err := c.commandGenerator.CmdAnnouncement()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get command announcement: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Trace("writing command announcement")
|
||||
|
||||
if err := c.writeFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) writeCommand() error {
|
||||
frame, err := c.commandGenerator.CmdData()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get command data: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Trace("writing command data")
|
||||
|
||||
if err := c.writeFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) handleResponseAnnounce() error {
|
||||
c.logger.Trace("waiting for response announce")
|
||||
|
||||
select {
|
||||
case frame := <-c.framesIn:
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.commandGenerator.HandleResponseAnnounce(frame); err != nil {
|
||||
return fmt.Errorf("response announce handling failed: %w", err)
|
||||
}
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) handleResponseData() error {
|
||||
c.logger.Trace("waiting for response data")
|
||||
|
||||
select {
|
||||
case frame := <-c.framesIn:
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.commandGenerator.HandleResponse(frame); err != nil {
|
||||
return fmt.Errorf("response handler failed: %w", err)
|
||||
}
|
||||
|
||||
if err := c.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) Run(ctx context.Context) error {
|
||||
c.protocolState = cmdAnnounce
|
||||
errors := make(chan error)
|
||||
|
||||
go func() {
|
||||
err := c.readFrames()
|
||||
|
||||
errors <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := c.commandGenerator.GenerateCommands(ctx)
|
||||
|
||||
errors <- err
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case err := <-errors:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from handler loop: %w", err)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
if err := c.handleProtocolState(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientSimulator) handleResponse(response *protocol.Response, respBytes []byte) error {
|
||||
switch response.Announce.Code {
|
||||
case messages.RespHealth:
|
||||
data := messages.HealthResponse{}
|
||||
func (c *clientSimulator) handleProtocolState() error {
|
||||
c.logger.Tracef("handling protocol state %s", c.protocolState)
|
||||
|
||||
if err := msgpack.Unmarshal(respBytes, &data); err != nil {
|
||||
return fmt.Errorf("could not unmarshal health data: %w", err)
|
||||
switch c.protocolState {
|
||||
case cmdAnnounce:
|
||||
if err := c.writeCmdAnnouncement(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.WithField(
|
||||
"announce",
|
||||
response.Announce,
|
||||
).WithField(
|
||||
"data",
|
||||
&data,
|
||||
).Infof("received response")
|
||||
case messages.RespFetchCRL:
|
||||
data := messages.FetchCRLResponse{}
|
||||
|
||||
if err := msgpack.Unmarshal(respBytes, &data); err != nil {
|
||||
return fmt.Errorf("could not unmarshal fetch CRL data: %w", err)
|
||||
case cmdData:
|
||||
if err := c.writeCommand(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.WithField(
|
||||
"announce",
|
||||
response.Announce,
|
||||
).WithField(
|
||||
"data",
|
||||
&data,
|
||||
).Infof("received response")
|
||||
case messages.RespError:
|
||||
data := messages.ErrorResponse{}
|
||||
|
||||
if err := msgpack.Unmarshal(respBytes, &data); err != nil {
|
||||
return fmt.Errorf("could not unmarshal error data: %w", err)
|
||||
case respAnnounce:
|
||||
if err := c.handleResponseAnnounce(); err != nil {
|
||||
return err
|
||||
}
|
||||
case respData:
|
||||
if err := c.handleResponseData(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.WithField(
|
||||
"announce",
|
||||
response.Announce,
|
||||
).WithField(
|
||||
"data",
|
||||
&data,
|
||||
).Infof("received response")
|
||||
default:
|
||||
if err := msgpack.Unmarshal(respBytes, &response.Response); err != nil {
|
||||
return fmt.Errorf("could not unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
c.logger.WithField("response", response).Infof("received response")
|
||||
return fmt.Errorf("unknown protocol state %s", c.protocolState)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeCommandAnnouncement(command *protocol.Command) error {
|
||||
cmdAnnounceBytes, err := msgpack.Marshal(&command.Announce)
|
||||
func (c *clientSimulator) nextState() error {
|
||||
next, ok := validTransitions[c.protocolState]
|
||||
if !ok {
|
||||
return fmt.Errorf("illegal protocol state %s", c.protocolState)
|
||||
}
|
||||
|
||||
c.protocolState = next
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(os.Stderr)
|
||||
logger.SetLevel(logrus.TraceLevel)
|
||||
|
||||
messages.RegisterGeneratedResolver()
|
||||
|
||||
sim := &clientSimulator{
|
||||
commandGenerator: &TestCommandGenerator{
|
||||
logger: logger,
|
||||
commands: make(chan *protocol.Command, 0),
|
||||
},
|
||||
logger: logger,
|
||||
framesIn: make(chan []byte, 0),
|
||||
}
|
||||
|
||||
err := sim.Run(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal command announcement bytes: %w", err)
|
||||
logger.WithError(err).Error("simulator returned an error")
|
||||
}
|
||||
|
||||
if _, err = os.Stdout.Write(cobs.Encode(cmdAnnounceBytes, cobsConfig)); err != nil {
|
||||
return fmt.Errorf("command announcement write failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeCommand(command *protocol.Command) error {
|
||||
cmdBytes, err := msgpack.Marshal(&command.Command)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal command bytes: %w", err)
|
||||
}
|
||||
|
||||
if _, err = os.Stdout.Write(cobs.Encode(cmdBytes, cobsConfig)); err != nil {
|
||||
return fmt.Errorf("command write failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientSimulator) Run() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
c.commands = make(chan *protocol.Command)
|
||||
c.responses = make(chan [][]byte)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
if err := c.handleInput(ctx); err != nil {
|
||||
c.logger.WithError(err).Error("input handling failed")
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := c.handleCommands(ctx); err != nil {
|
||||
c.logger.WithError(err).Error("command handling failed")
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
var result error
|
||||
|
||||
if err := c.writeTestCommands(ctx); err != nil {
|
||||
c.logger.WithError(err).Error("test commands failed")
|
||||
}
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
|
||||
return result
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -101,18 +102,18 @@ func main() {
|
|||
logger.WithError(err).Fatal("could not setup protocol handler")
|
||||
}
|
||||
|
||||
serialHandler, err := seriallink.New(caConfig.GetSerial(), proto)
|
||||
serialHandler, err := seriallink.New(caConfig.GetSerial(), logger, proto)
|
||||
if err != nil {
|
||||
logger.WithError(err).Fatal("could not setup serial link handler")
|
||||
}
|
||||
|
||||
defer func() { _ = serialHandler.Close() }()
|
||||
|
||||
if err = serialHandler.Run(); err != nil {
|
||||
logger.Info("setup complete, starting signer operation")
|
||||
|
||||
if err = serialHandler.Run(context.Background()); err != nil {
|
||||
logger.WithError(err).Fatal("error in serial handler")
|
||||
}
|
||||
|
||||
logger.Info("setup complete, starting signer operation")
|
||||
}
|
||||
|
||||
func configureRepositories(
|
||||
|
|
|
@ -162,8 +162,9 @@ func TestPrivateKeyInfo_UnmarshalYAML(t *testing.T) {
|
|||
algorithm: "RSA"
|
||||
rsa-bits: 2048`,
|
||||
expected: &config.PrivateKeyInfo{
|
||||
Algorithm: x509.RSA,
|
||||
RSABits: 2048,
|
||||
Algorithm: x509.RSA,
|
||||
RSABits: 2048,
|
||||
CRLSignatureAlgorithm: x509.SHA256WithRSA,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -172,8 +173,9 @@ rsa-bits: 2048`,
|
|||
algorithm: "EC"
|
||||
ecc-curve: "P-224"`,
|
||||
expected: &config.PrivateKeyInfo{
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P224(),
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P224(),
|
||||
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -182,8 +184,9 @@ ecc-curve: "P-224"`,
|
|||
algorithm: "EC"
|
||||
ecc-curve: "P-256"`,
|
||||
expected: &config.PrivateKeyInfo{
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P256(),
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P256(),
|
||||
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -192,8 +195,9 @@ ecc-curve: "P-256"`,
|
|||
algorithm: "EC"
|
||||
ecc-curve: "P-384"`,
|
||||
expected: &config.PrivateKeyInfo{
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P384(),
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P384(),
|
||||
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -202,8 +206,9 @@ ecc-curve: "P-384"`,
|
|||
algorithm: "EC"
|
||||
ecc-curve: "P-521"`,
|
||||
expected: &config.PrivateKeyInfo{
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P521(),
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P521(),
|
||||
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -308,8 +313,9 @@ storage: root
|
|||
`,
|
||||
expected: config.CaCertificateEntry{
|
||||
KeyInfo: &config.PrivateKeyInfo{
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P521(),
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P521(),
|
||||
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
},
|
||||
CommonName: "My Little Test Root CA",
|
||||
Storage: "root",
|
||||
|
@ -329,8 +335,9 @@ ext-key-usages:
|
|||
`,
|
||||
expected: config.CaCertificateEntry{
|
||||
KeyInfo: &config.PrivateKeyInfo{
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P256(),
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P256(),
|
||||
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
},
|
||||
Parent: "root",
|
||||
CommonName: "My Little Test Sub CA",
|
||||
|
@ -357,8 +364,9 @@ ext-key-usages:
|
|||
`,
|
||||
expected: config.CaCertificateEntry{
|
||||
KeyInfo: &config.PrivateKeyInfo{
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P256(),
|
||||
Algorithm: x509.ECDSA,
|
||||
EccCurve: elliptic.P256(),
|
||||
CRLSignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
},
|
||||
CommonName: "My Little Test Sub CA",
|
||||
Storage: "default",
|
||||
|
|
|
@ -21,6 +21,7 @@ package protocol
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/shamaton/msgpackgen/msgpack"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -31,82 +32,6 @@ import (
|
|||
"git.cacert.org/cacert-gosigner/pkg/messages"
|
||||
)
|
||||
|
||||
// Handler is responsible for parsing incoming frames and calling commands
|
||||
type Handler interface {
|
||||
HandleCommandAnnounce([]byte) (*messages.CommandAnnounce, error)
|
||||
HandleCommand(*messages.CommandAnnounce, []byte) ([]byte, []byte, error)
|
||||
}
|
||||
|
||||
type MsgPackHandler struct {
|
||||
logger *logrus.Logger
|
||||
healthHandler *health.Handler
|
||||
fetchCRLHandler *revoking.FetchCRLHandler
|
||||
}
|
||||
|
||||
func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) (*messages.CommandAnnounce, error) {
|
||||
var ann messages.CommandAnnounce
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &ann); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal command announcement: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Infof("received command announcement %+v", ann)
|
||||
|
||||
return &ann, nil
|
||||
}
|
||||
|
||||
func (m *MsgPackHandler) HandleCommand(announce *messages.CommandAnnounce, frame []byte) ([]byte, []byte, error) {
|
||||
var (
|
||||
response *Response
|
||||
clientError, err error
|
||||
)
|
||||
|
||||
switch announce.Code {
|
||||
case messages.CmdHealth:
|
||||
// health has no payload, ignore the frame
|
||||
response, err = m.handleCommand(&Command{Announce: announce, Command: nil})
|
||||
if err != nil {
|
||||
m.logger.WithError(err).Error("health handling failed")
|
||||
|
||||
clientError = errors.New("could not handle request")
|
||||
}
|
||||
case messages.CmdFetchCRL:
|
||||
var command messages.FetchCRLCommand
|
||||
|
||||
err = msgpack.Unmarshal(frame, &command)
|
||||
if err != nil {
|
||||
m.logger.WithError(err).Error("unmarshal failed")
|
||||
|
||||
clientError = errors.New("could not unmarshal fetch crl command")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
response, err = m.handleCommand(&Command{Announce: announce, Command: command})
|
||||
if err != nil {
|
||||
m.logger.WithError(err).Error("fetch CRL handling failed")
|
||||
|
||||
clientError = errors.New("could not handle request")
|
||||
}
|
||||
}
|
||||
|
||||
if clientError != nil {
|
||||
response = buildErrorResponse(clientError.Error())
|
||||
}
|
||||
|
||||
announceData, err := msgpack.Marshal(response.Announce)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("could not marshal response announcement: %w", err)
|
||||
}
|
||||
|
||||
responseData, err := msgpack.Marshal(response.Response)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("could not marshal response: %w", err)
|
||||
}
|
||||
|
||||
return announceData, responseData, nil
|
||||
}
|
||||
|
||||
type Command struct {
|
||||
Announce *messages.CommandAnnounce
|
||||
Command interface{}
|
||||
|
@ -121,23 +46,146 @@ func (r *Response) String() string {
|
|||
return fmt.Sprintf("Response[Code=%s] created=%s data=%s", r.Announce.Code, r.Announce.Created, r.Response)
|
||||
}
|
||||
|
||||
func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) {
|
||||
// Handler is responsible for parsing incoming frames and calling commands
|
||||
type Handler interface {
|
||||
HandleCommandAnnounce([]byte) error
|
||||
HandleCommand([]byte) error
|
||||
ResponseAnnounce() ([]byte, error)
|
||||
ResponseData() ([]byte, error)
|
||||
}
|
||||
|
||||
type MsgPackHandler struct {
|
||||
logger *logrus.Logger
|
||||
healthHandler *health.Handler
|
||||
fetchCRLHandler *revoking.FetchCRLHandler
|
||||
currentCommand *Command
|
||||
currentResponse *Response
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (m *MsgPackHandler) HandleCommandAnnounce(frame []byte) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
var ann messages.CommandAnnounce
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &ann); err != nil {
|
||||
return fmt.Errorf("could not unmarshal command announcement: %w", err)
|
||||
}
|
||||
|
||||
m.logger.WithField("announcement", &ann).Info("received command announcement")
|
||||
|
||||
m.currentCommand = &Command{Announce: &ann}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MsgPackHandler) HandleCommand(frame []byte) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
var clientError error
|
||||
|
||||
switch m.currentCommand.Announce.Code {
|
||||
case messages.CmdHealth:
|
||||
// health has no payload, ignore the frame
|
||||
response, err := m.handleCommand()
|
||||
if err != nil {
|
||||
m.logger.WithError(err).Error("health handling failed")
|
||||
|
||||
clientError = errors.New("could not handle request")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
m.currentResponse = response
|
||||
case messages.CmdFetchCRL:
|
||||
var command messages.FetchCRLCommand
|
||||
|
||||
if err := msgpack.Unmarshal(frame, &command); err != nil {
|
||||
m.logger.WithError(err).Error("unmarshal failed")
|
||||
|
||||
clientError = errors.New("could not unmarshal fetch crl command")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
m.currentCommand.Command = command
|
||||
|
||||
response, err := m.handleCommand()
|
||||
if err != nil {
|
||||
m.logger.WithError(err).Error("fetch CRL handling failed")
|
||||
|
||||
clientError = errors.New("could not handle request")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
m.currentResponse = response
|
||||
}
|
||||
|
||||
if clientError != nil {
|
||||
m.currentResponse = buildErrorResponse(clientError.Error())
|
||||
}
|
||||
|
||||
m.logger.WithField(
|
||||
"command",
|
||||
m.currentCommand,
|
||||
).WithField(
|
||||
"response",
|
||||
m.currentResponse,
|
||||
).Info("handled command")
|
||||
|
||||
m.currentCommand = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MsgPackHandler) ResponseAnnounce() ([]byte, error) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
announceData, err := msgpack.Marshal(m.currentResponse.Announce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not marshal response announcement: %w", err)
|
||||
}
|
||||
|
||||
m.logger.WithField("announcement", &m.currentResponse.Announce).Info("write response announcement")
|
||||
|
||||
return announceData, nil
|
||||
}
|
||||
|
||||
func (m *MsgPackHandler) ResponseData() ([]byte, error) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
responseData, err := msgpack.Marshal(m.currentResponse.Response)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not marshal response: %w", err)
|
||||
}
|
||||
|
||||
m.logger.WithField("response", &m.currentResponse.Response).Info("write response")
|
||||
|
||||
return responseData, nil
|
||||
}
|
||||
|
||||
func (m *MsgPackHandler) handleCommand() (*Response, error) {
|
||||
var (
|
||||
err error
|
||||
responseData interface{}
|
||||
responseCode messages.ResponseCode
|
||||
)
|
||||
|
||||
switch command.Announce.Code {
|
||||
switch m.currentCommand.Announce.Code {
|
||||
case messages.CmdHealth:
|
||||
var res *health.Result
|
||||
|
||||
res, err = m.healthHandler.CheckHealth()
|
||||
if err != nil {
|
||||
break
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := messages.HealthResponse{
|
||||
response := &messages.HealthResponse{
|
||||
Version: res.Version,
|
||||
Healthy: res.Healthy,
|
||||
}
|
||||
|
@ -154,24 +202,24 @@ func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) {
|
|||
case messages.CmdFetchCRL:
|
||||
var res *revoking.Result
|
||||
|
||||
fetchCRLPayload, ok := command.Command.(messages.FetchCRLCommand)
|
||||
fetchCRLPayload, ok := m.currentCommand.Command.(messages.FetchCRLCommand)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not use payload as FetchCRLPayload")
|
||||
}
|
||||
|
||||
res, err = m.fetchCRLHandler.FetchCRL(fetchCRLPayload.IssuerID)
|
||||
if err != nil {
|
||||
break
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := messages.FetchCRLResponse{
|
||||
response := &messages.FetchCRLResponse{
|
||||
IsDelta: false,
|
||||
CRLData: res.Data,
|
||||
}
|
||||
|
||||
responseCode, responseData = messages.RespFetchCRL, response
|
||||
default:
|
||||
return nil, fmt.Errorf("unhandled command %v", command)
|
||||
return nil, fmt.Errorf("unhandled command %s", m.currentCommand.Announce)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -187,7 +235,7 @@ func (m *MsgPackHandler) handleCommand(command *Command) (*Response, error) {
|
|||
func buildErrorResponse(errMsg string) *Response {
|
||||
return &Response{
|
||||
Announce: messages.BuildResponseAnnounce(messages.RespError),
|
||||
Response: messages.ErrorResponse{Message: errMsg},
|
||||
Response: &messages.ErrorResponse{Message: errMsg},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,10 +20,14 @@ package seriallink
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/justincpresley/go-cobs"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tarm/serial"
|
||||
|
||||
"git.cacert.org/cacert-gosigner/pkg/config"
|
||||
|
@ -33,16 +37,42 @@ import (
|
|||
type protocolState int8
|
||||
|
||||
const (
|
||||
stAnnounce protocolState = iota
|
||||
stCommand
|
||||
cmdAnnounce protocolState = iota
|
||||
cmdData
|
||||
respAnnounce
|
||||
respData
|
||||
)
|
||||
|
||||
var validTransitions = map[protocolState]protocolState{
|
||||
cmdAnnounce: cmdData,
|
||||
cmdData: respAnnounce,
|
||||
respAnnounce: respData,
|
||||
respData: cmdAnnounce,
|
||||
}
|
||||
|
||||
var protocolStateNames = map[protocolState]string{
|
||||
cmdAnnounce: "CMD ANNOUNCE",
|
||||
cmdData: "CMD DATA",
|
||||
respAnnounce: "RESP ANNOUNCE",
|
||||
respData: "RESP DATA",
|
||||
}
|
||||
|
||||
func (p protocolState) String() string {
|
||||
if name, ok := protocolStateNames[p]; ok {
|
||||
return name
|
||||
}
|
||||
|
||||
return fmt.Sprintf("unknown %d", p)
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
protocolHandler protocol.Handler
|
||||
protocolState protocolState
|
||||
currentCommand *protocol.Command
|
||||
config *serial.Config
|
||||
port *serial.Port
|
||||
logger *logrus.Logger
|
||||
lock sync.Mutex
|
||||
framesIn chan []byte
|
||||
}
|
||||
|
||||
func (h *Handler) setupConnection() error {
|
||||
|
@ -65,78 +95,191 @@ func (h *Handler) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
const cobsDelimiter = 0x00
|
||||
var cobsConfig = cobs.Config{SpecialByte: 0x00, Delimiter: true, EndingSave: true}
|
||||
|
||||
var cobsConfig = cobs.Config{SpecialByte: cobsDelimiter, Delimiter: true, EndingSave: true}
|
||||
func (h *Handler) Run(ctx context.Context) error {
|
||||
h.protocolState = cmdAnnounce
|
||||
errors := make(chan error)
|
||||
|
||||
func (h *Handler) Run() error {
|
||||
go func() {
|
||||
err := h.readFrames()
|
||||
|
||||
errors <- err
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
|
||||
case err := <-errors:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from handler loop: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
default:
|
||||
if err := h.handleProtocolState(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) readFrames() error {
|
||||
const (
|
||||
bufferSize = 1024 * 1024
|
||||
readInterval = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
errors := make(chan error)
|
||||
var frame []byte
|
||||
buffer := &bytes.Buffer{}
|
||||
delimiter := []byte{cobsConfig.SpecialByte}
|
||||
|
||||
h.protocolState = stAnnounce
|
||||
for {
|
||||
readBytes, err := h.readFromPort()
|
||||
if err != nil {
|
||||
close(h.framesIn)
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, bufferSize)
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
count, err := h.port.Read(buf)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
if len(readBytes) == 0 {
|
||||
time.Sleep(readInterval)
|
||||
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
time.Sleep(readInterval)
|
||||
h.logger.Tracef("read %d bytes", len(readBytes))
|
||||
|
||||
buffer.Write(readBytes)
|
||||
|
||||
h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
|
||||
|
||||
rest := buffer.Bytes()
|
||||
|
||||
if !bytes.Contains(rest, delimiter) {
|
||||
continue
|
||||
}
|
||||
|
||||
for bytes.Contains(rest, delimiter) {
|
||||
parts := bytes.SplitAfterN(rest, delimiter, 2)
|
||||
frame, rest = parts[0], parts[1]
|
||||
|
||||
h.logger.Tracef("frame of length %d", len(frame))
|
||||
|
||||
if len(frame) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
frames := bytes.SplitAfter(buf[:count], []byte{cobsDelimiter})
|
||||
if err := cobs.Verify(frame, cobsConfig); err != nil {
|
||||
close(h.framesIn)
|
||||
|
||||
if err := h.handleFrames(frames); err != nil {
|
||||
errors <- err
|
||||
|
||||
return
|
||||
return fmt.Errorf("could not verify COBS frame: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err := <-errors
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from handler loop: %w", err)
|
||||
decoded := cobs.Decode(frame, cobsConfig)
|
||||
|
||||
h.logger.Tracef("frame decoded to length %d", len(decoded))
|
||||
|
||||
h.framesIn <- decoded
|
||||
}
|
||||
|
||||
buffer.Truncate(0)
|
||||
buffer.Write(rest)
|
||||
|
||||
h.logger.Tracef("read buffer is now %d bytes long", buffer.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeFrame(frame []byte) error {
|
||||
encoded := cobs.Encode(frame, cobsConfig)
|
||||
|
||||
return h.writeToPort(encoded)
|
||||
}
|
||||
|
||||
func (h *Handler) nextState() error {
|
||||
next, ok := validTransitions[h.protocolState]
|
||||
if !ok {
|
||||
return fmt.Errorf("illegal protocol state %s", h.protocolState)
|
||||
}
|
||||
|
||||
h.protocolState = next
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleProtocolState() error {
|
||||
h.logger.Tracef("handling protocol state %s", h.protocolState)
|
||||
|
||||
switch h.protocolState {
|
||||
case cmdAnnounce:
|
||||
if err := h.handleCmdAnnounce(); err != nil {
|
||||
return err
|
||||
}
|
||||
case cmdData:
|
||||
if err := h.handleCmdData(); err != nil {
|
||||
return err
|
||||
}
|
||||
case respAnnounce:
|
||||
if err := h.handleRespAnnounce(); err != nil {
|
||||
return err
|
||||
}
|
||||
case respData:
|
||||
if err := h.handleRespData(); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol state %s", h.protocolState)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleFrames(frames [][]byte) error {
|
||||
for _, frame := range frames {
|
||||
if len(frame) == 0 {
|
||||
func (h *Handler) writeToPort(data []byte) error {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
reader := bytes.NewReader(data)
|
||||
|
||||
n, err := io.Copy(h.port, reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not write data: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Tracef("wrote %d bytes", n)
|
||||
|
||||
if err := h.port.Flush(); err != nil {
|
||||
return fmt.Errorf("could not flush data: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) readFromPort() ([]byte, error) {
|
||||
const bufferSize = 1024
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
count, err := h.port.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read from serial port: %w", err)
|
||||
}
|
||||
|
||||
return buf[:count], nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleCmdAnnounce() error {
|
||||
h.logger.Trace("waiting for command announce")
|
||||
|
||||
select {
|
||||
case frame := <-h.framesIn:
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := cobs.Verify(frame, cobsConfig); err != nil {
|
||||
return fmt.Errorf("could not verify COBS frame: %w", err)
|
||||
}
|
||||
|
||||
// perform COBS decoding
|
||||
decoded := cobs.Decode(frame, cobsConfig)
|
||||
|
||||
if h.protocolState == stAnnounce {
|
||||
if err := h.handleCommandAnnounce(decoded); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if h.protocolState == stCommand {
|
||||
if err := h.handleCommandData(decoded); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.protocolHandler.HandleCommandAnnounce(frame); err != nil {
|
||||
return fmt.Errorf("command announce handling failed: %w", err)
|
||||
}
|
||||
|
||||
if err := h.nextState(); err != nil {
|
||||
|
@ -147,65 +290,72 @@ func (h *Handler) handleFrames(frames [][]byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleCommandData(decoded []byte) error {
|
||||
respAnn, msg, err := h.protocolHandler.HandleCommand(h.currentCommand.Announce, decoded)
|
||||
if err != nil {
|
||||
return fmt.Errorf("command handler for %s failed: %w", h.currentCommand.Announce.Code, err)
|
||||
func (h *Handler) handleCmdData() error {
|
||||
h.logger.Trace("waiting for command data")
|
||||
|
||||
select {
|
||||
case frame := <-h.framesIn:
|
||||
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.protocolHandler.HandleCommand(frame); err != nil {
|
||||
return fmt.Errorf("command handler failed: %w", err)
|
||||
}
|
||||
|
||||
if err := h.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.writeResponse(respAnn, msg, cobsConfig); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleRespAnnounce() error {
|
||||
frame, err := h.protocolHandler.ResponseAnnounce()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get response announcement: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Trace("writing response announce")
|
||||
|
||||
if err := h.writeFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleCommandAnnounce(decoded []byte) error {
|
||||
announce, err := h.protocolHandler.HandleCommandAnnounce(decoded)
|
||||
func (h *Handler) handleRespData() error {
|
||||
frame, err := h.protocolHandler.ResponseData()
|
||||
if err != nil {
|
||||
return fmt.Errorf("command announce handling failed: %w", err)
|
||||
return fmt.Errorf("could not get response data: %w", err)
|
||||
}
|
||||
|
||||
h.currentCommand = &protocol.Command{Announce: announce}
|
||||
h.logger.Trace("writing response data")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) writeResponse(ann, msg []byte, cobsConfig cobs.Config) error {
|
||||
encoded := cobs.Encode(ann, cobsConfig)
|
||||
|
||||
if _, err := h.port.Write(encoded); err != nil {
|
||||
return fmt.Errorf("could not write response announcement: %w", err)
|
||||
if err := h.writeFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoded = cobs.Encode(msg, cobsConfig)
|
||||
|
||||
if _, err := h.port.Write(encoded); err != nil {
|
||||
return fmt.Errorf("could not write response: %w", err)
|
||||
if err := h.nextState(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) nextState() error {
|
||||
var next protocolState
|
||||
|
||||
switch h.protocolState {
|
||||
case stAnnounce:
|
||||
next = stCommand
|
||||
case stCommand:
|
||||
next = stAnnounce
|
||||
default:
|
||||
return fmt.Errorf("illegal protocol state %d", int(h.protocolState))
|
||||
func New(cfg *config.Serial, logger *logrus.Logger, protocolHandler protocol.Handler) (*Handler, error) {
|
||||
h := &Handler{
|
||||
protocolHandler: protocolHandler,
|
||||
logger: logger,
|
||||
framesIn: make(chan []byte, 0),
|
||||
}
|
||||
|
||||
h.protocolState = next
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func New(cfg *config.Serial, protocolHandler protocol.Handler) (*Handler, error) {
|
||||
h := &Handler{protocolHandler: protocolHandler}
|
||||
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}
|
||||
|
||||
err := h.setupConnection()
|
||||
|
|
Loading…
Reference in a new issue