From 1374fe58e831efb833615c42d6531431f75c473c Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Fri, 2 Dec 2022 12:54:07 +0100 Subject: [PATCH] Implement graceful shutdown on interrupt or SIGTERM --- cmd/signer/main.go | 37 +++++++++++++++++++++-- internal/serial/seriallink.go | 57 +++++++++++++++++++++++++++-------- pkg/protocol/protocol.go | 34 +++++++++++++-------- 3 files changed, 100 insertions(+), 28 deletions(-) diff --git a/cmd/signer/main.go b/cmd/signer/main.go index 44214cc..d4ec95c 100644 --- a/cmd/signer/main.go +++ b/cmd/signer/main.go @@ -22,9 +22,13 @@ import ( "flag" "fmt" "os" + "os/signal" + "syscall" "github.com/sirupsen/logrus" + "git.cacert.org/cacert-gosigner/pkg/protocol" + "git.cacert.org/cacert-gosigner/internal/config" "git.cacert.org/cacert-gosigner/internal/handler" "git.cacert.org/cacert-gosigner/internal/health" @@ -83,6 +87,11 @@ func main() { return } + framer, err := protocol.NewCOBSFramer(logger) + if err != nil { + logger.WithError(err).Fatal("could not create framer") + } + healthHandler := health.New(version, access) revokingRepositories, err := configureRepositories(caConfig, logger) @@ -101,7 +110,7 @@ func main() { logger.WithError(err).Fatal("could not setup protocol handler") } - serialHandler, err := serial.New(caConfig.GetSerial(), logger, proto) + serialHandler, err := serial.New(caConfig.GetSerial(), logger, framer, proto) if err != nil { logger.WithError(err).Fatal("could not setup serial link handler") } @@ -110,11 +119,33 @@ func main() { logger.Info("setup complete, starting signer operation") - if err = serialHandler.Run(context.Background()); err != nil { - logger.WithError(err).Fatal("error in serial handler") + if err = runSigner(logger, serialHandler); err != nil { + logger.WithError(err).Fatal("error running serial handler") } } +func runSigner(logger *logrus.Logger, serialHandler *serial.Handler) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + go func() { + <-c + + logger.Info("received shutdown signal") + + cancel() + }() + + if err := serialHandler.Run(ctx); err != nil { + return fmt.Errorf("error from serial handler: %w", err) + } + + return nil +} + func configureRepositories( caConfig *config.SignerConfig, logger *logrus.Logger, diff --git a/internal/serial/seriallink.go b/internal/serial/seriallink.go index ea1a2a3..5334a19 100644 --- a/internal/serial/seriallink.go +++ b/internal/serial/seriallink.go @@ -21,6 +21,7 @@ package serial import ( "context" "fmt" + "sync" "github.com/sirupsen/logrus" "github.com/tarm/serial" @@ -60,26 +61,60 @@ func (h *Handler) Close() error { } func (h *Handler) Run(ctx context.Context) error { + const componentCount = 3 + protocolErrors, framerErrors := make(chan error), make(chan error) + subCtx, cancel := context.WithCancel(ctx) + wg := sync.WaitGroup{} + wg.Add(componentCount) + + defer func() { + cancel() + h.logger.Info("context canceled waiting for shutdown of components") + wg.Wait() + h.logger.Info("shutdown complete") + }() + go func() { - err := h.framer.ReadFrames(ctx, h.port, h.framesIn) + defer wg.Done() + + err := h.framer.ReadFrames(subCtx, h.port, h.framesIn) - framerErrors <- err + h.logger.Info("frame reading stopped") + + select { + case framerErrors <- err: + case <-subCtx.Done(): + } }() go func() { - err := h.framer.WriteFrames(ctx, h.port, h.framesOut) + defer wg.Done() + + err := h.framer.WriteFrames(subCtx, h.port, h.framesOut) + + h.logger.Info("frame writing stopped") - framerErrors <- err + select { + case framerErrors <- err: + case <-subCtx.Done(): + } }() go func() { + defer wg.Done() + serverProtocol := protocol.NewServer(h.serverHandler, h.framesIn, h.framesOut, h.logger) - err := serverProtocol.Handle(ctx) + err := serverProtocol.Handle(subCtx) - protocolErrors <- err + h.logger.Info("server protocol stopped") + + select { + case protocolErrors <- err: + case <-subCtx.Done(): + } }() for { @@ -105,23 +140,19 @@ func (h *Handler) Run(ctx context.Context) error { func New( cfg *config.Serial, logger *logrus.Logger, + framer protocol.Framer, protocolHandler protocol.ServerHandler, ) (*Handler, error) { - cobsFramer, err := protocol.NewCOBSFramer(logger) - if err != nil { - return nil, fmt.Errorf("could not create COBS framer: %w", err) - } - h := &Handler{ serverHandler: protocolHandler, logger: logger, framesIn: make(chan []byte), framesOut: make(chan []byte), - framer: cobsFramer, + framer: framer, } h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout} - err = h.setupConnection() + err := h.setupConnection() if err != nil { return nil, err } diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go index cb1c1c7..4f1371c 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/protocol/protocol.go @@ -421,7 +421,7 @@ func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan case <-ctx.Done(): return nil default: - raw, err = c.readRaw(reader) + raw, err = c.readRaw(ctx, reader) if err != nil { close(frameChan) @@ -469,21 +469,31 @@ func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan } } -func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) { - buf := make([]byte, bufferSize) +func (c *COBSFramer) readRaw(ctx context.Context, reader io.Reader) ([]byte, error) { + result := make(chan []byte) + errChan := make(chan error) - count, err := reader.Read(buf) - if err != nil { - if errors.Is(err, io.EOF) { - return buf[:count], nil - } + go func() { + buf := make([]byte, bufferSize) - return nil, fmt.Errorf("could not read data: %w", err) - } + count, err := reader.Read(buf) + if err != nil { + if !errors.Is(err, io.EOF) { + errChan <- fmt.Errorf("could not read data: %w", err) + } + } - raw := buf[:count] + result <- buf[:count] + }() - return raw, nil + select { + case <-ctx.Done(): + return nil, nil + case raw := <-result: + return raw, nil + case err := <-errChan: + return nil, err + } } func (c *COBSFramer) WriteFrames(ctx context.Context, writer io.Writer, frameChan chan []byte) error {