Implement graceful shutdown on interrupt or SIGTERM

This commit is contained in:
Jan Dittberner 2022-12-02 12:54:07 +01:00
parent 7837164e6e
commit 1374fe58e8
3 changed files with 101 additions and 29 deletions

View file

@ -22,9 +22,13 @@ import (
"flag"
"fmt"
"os"
"os/signal"
"syscall"
"github.com/sirupsen/logrus"
"git.cacert.org/cacert-gosigner/pkg/protocol"
"git.cacert.org/cacert-gosigner/internal/config"
"git.cacert.org/cacert-gosigner/internal/handler"
"git.cacert.org/cacert-gosigner/internal/health"
@ -83,6 +87,11 @@ func main() {
return
}
framer, err := protocol.NewCOBSFramer(logger)
if err != nil {
logger.WithError(err).Fatal("could not create framer")
}
healthHandler := health.New(version, access)
revokingRepositories, err := configureRepositories(caConfig, logger)
@ -101,7 +110,7 @@ func main() {
logger.WithError(err).Fatal("could not setup protocol handler")
}
serialHandler, err := serial.New(caConfig.GetSerial(), logger, proto)
serialHandler, err := serial.New(caConfig.GetSerial(), logger, framer, proto)
if err != nil {
logger.WithError(err).Fatal("could not setup serial link handler")
}
@ -110,11 +119,33 @@ func main() {
logger.Info("setup complete, starting signer operation")
if err = serialHandler.Run(context.Background()); err != nil {
logger.WithError(err).Fatal("error in serial handler")
if err = runSigner(logger, serialHandler); err != nil {
logger.WithError(err).Fatal("error running serial handler")
}
}
func runSigner(logger *logrus.Logger, serialHandler *serial.Handler) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
logger.Info("received shutdown signal")
cancel()
}()
if err := serialHandler.Run(ctx); err != nil {
return fmt.Errorf("error from serial handler: %w", err)
}
return nil
}
func configureRepositories(
caConfig *config.SignerConfig,
logger *logrus.Logger,

View file

@ -21,6 +21,7 @@ package serial
import (
"context"
"fmt"
"sync"
"github.com/sirupsen/logrus"
"github.com/tarm/serial"
@ -60,26 +61,60 @@ func (h *Handler) Close() error {
}
func (h *Handler) Run(ctx context.Context) error {
const componentCount = 3
protocolErrors, framerErrors := make(chan error), make(chan error)
go func() {
err := h.framer.ReadFrames(ctx, h.port, h.framesIn)
subCtx, cancel := context.WithCancel(ctx)
wg := sync.WaitGroup{}
wg.Add(componentCount)
framerErrors <- err
defer func() {
cancel()
h.logger.Info("context canceled waiting for shutdown of components")
wg.Wait()
h.logger.Info("shutdown complete")
}()
go func() {
err := h.framer.WriteFrames(ctx, h.port, h.framesOut)
defer wg.Done()
framerErrors <- err
err := h.framer.ReadFrames(subCtx, h.port, h.framesIn)
h.logger.Info("frame reading stopped")
select {
case framerErrors <- err:
case <-subCtx.Done():
}
}()
go func() {
defer wg.Done()
err := h.framer.WriteFrames(subCtx, h.port, h.framesOut)
h.logger.Info("frame writing stopped")
select {
case framerErrors <- err:
case <-subCtx.Done():
}
}()
go func() {
defer wg.Done()
serverProtocol := protocol.NewServer(h.serverHandler, h.framesIn, h.framesOut, h.logger)
err := serverProtocol.Handle(ctx)
err := serverProtocol.Handle(subCtx)
protocolErrors <- err
h.logger.Info("server protocol stopped")
select {
case protocolErrors <- err:
case <-subCtx.Done():
}
}()
for {
@ -105,23 +140,19 @@ func (h *Handler) Run(ctx context.Context) error {
func New(
cfg *config.Serial,
logger *logrus.Logger,
framer protocol.Framer,
protocolHandler protocol.ServerHandler,
) (*Handler, error) {
cobsFramer, err := protocol.NewCOBSFramer(logger)
if err != nil {
return nil, fmt.Errorf("could not create COBS framer: %w", err)
}
h := &Handler{
serverHandler: protocolHandler,
logger: logger,
framesIn: make(chan []byte),
framesOut: make(chan []byte),
framer: cobsFramer,
framer: framer,
}
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}
err = h.setupConnection()
err := h.setupConnection()
if err != nil {
return nil, err
}

View file

@ -421,7 +421,7 @@ func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan
case <-ctx.Done():
return nil
default:
raw, err = c.readRaw(reader)
raw, err = c.readRaw(ctx, reader)
if err != nil {
close(frameChan)
@ -469,21 +469,31 @@ func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan
}
}
func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) {
buf := make([]byte, bufferSize)
func (c *COBSFramer) readRaw(ctx context.Context, reader io.Reader) ([]byte, error) {
result := make(chan []byte)
errChan := make(chan error)
count, err := reader.Read(buf)
if err != nil {
if errors.Is(err, io.EOF) {
return buf[:count], nil
go func() {
buf := make([]byte, bufferSize)
count, err := reader.Read(buf)
if err != nil {
if !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("could not read data: %w", err)
}
}
return nil, fmt.Errorf("could not read data: %w", err)
result <- buf[:count]
}()
select {
case <-ctx.Done():
return nil, nil
case raw := <-result:
return raw, nil
case err := <-errChan:
return nil, err
}
raw := buf[:count]
return raw, nil
}
func (c *COBSFramer) WriteFrames(ctx context.Context, writer io.Writer, frameChan chan []byte) error {