Implement graceful shutdown on interrupt or SIGTERM
This commit is contained in:
parent
7837164e6e
commit
1374fe58e8
3 changed files with 101 additions and 29 deletions
|
@ -22,9 +22,13 @@ import (
|
|||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"git.cacert.org/cacert-gosigner/pkg/protocol"
|
||||
|
||||
"git.cacert.org/cacert-gosigner/internal/config"
|
||||
"git.cacert.org/cacert-gosigner/internal/handler"
|
||||
"git.cacert.org/cacert-gosigner/internal/health"
|
||||
|
@ -83,6 +87,11 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
framer, err := protocol.NewCOBSFramer(logger)
|
||||
if err != nil {
|
||||
logger.WithError(err).Fatal("could not create framer")
|
||||
}
|
||||
|
||||
healthHandler := health.New(version, access)
|
||||
|
||||
revokingRepositories, err := configureRepositories(caConfig, logger)
|
||||
|
@ -101,7 +110,7 @@ func main() {
|
|||
logger.WithError(err).Fatal("could not setup protocol handler")
|
||||
}
|
||||
|
||||
serialHandler, err := serial.New(caConfig.GetSerial(), logger, proto)
|
||||
serialHandler, err := serial.New(caConfig.GetSerial(), logger, framer, proto)
|
||||
if err != nil {
|
||||
logger.WithError(err).Fatal("could not setup serial link handler")
|
||||
}
|
||||
|
@ -110,11 +119,33 @@ func main() {
|
|||
|
||||
logger.Info("setup complete, starting signer operation")
|
||||
|
||||
if err = serialHandler.Run(context.Background()); err != nil {
|
||||
logger.WithError(err).Fatal("error in serial handler")
|
||||
if err = runSigner(logger, serialHandler); err != nil {
|
||||
logger.WithError(err).Fatal("error running serial handler")
|
||||
}
|
||||
}
|
||||
|
||||
func runSigner(logger *logrus.Logger, serialHandler *serial.Handler) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-c
|
||||
|
||||
logger.Info("received shutdown signal")
|
||||
|
||||
cancel()
|
||||
}()
|
||||
|
||||
if err := serialHandler.Run(ctx); err != nil {
|
||||
return fmt.Errorf("error from serial handler: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func configureRepositories(
|
||||
caConfig *config.SignerConfig,
|
||||
logger *logrus.Logger,
|
||||
|
|
|
@ -21,6 +21,7 @@ package serial
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tarm/serial"
|
||||
|
@ -60,26 +61,60 @@ func (h *Handler) Close() error {
|
|||
}
|
||||
|
||||
func (h *Handler) Run(ctx context.Context) error {
|
||||
const componentCount = 3
|
||||
|
||||
protocolErrors, framerErrors := make(chan error), make(chan error)
|
||||
|
||||
go func() {
|
||||
err := h.framer.ReadFrames(ctx, h.port, h.framesIn)
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(componentCount)
|
||||
|
||||
framerErrors <- err
|
||||
defer func() {
|
||||
cancel()
|
||||
h.logger.Info("context canceled waiting for shutdown of components")
|
||||
wg.Wait()
|
||||
h.logger.Info("shutdown complete")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := h.framer.WriteFrames(ctx, h.port, h.framesOut)
|
||||
defer wg.Done()
|
||||
|
||||
framerErrors <- err
|
||||
err := h.framer.ReadFrames(subCtx, h.port, h.framesIn)
|
||||
|
||||
h.logger.Info("frame reading stopped")
|
||||
|
||||
select {
|
||||
case framerErrors <- err:
|
||||
case <-subCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := h.framer.WriteFrames(subCtx, h.port, h.framesOut)
|
||||
|
||||
h.logger.Info("frame writing stopped")
|
||||
|
||||
select {
|
||||
case framerErrors <- err:
|
||||
case <-subCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
serverProtocol := protocol.NewServer(h.serverHandler, h.framesIn, h.framesOut, h.logger)
|
||||
|
||||
err := serverProtocol.Handle(ctx)
|
||||
err := serverProtocol.Handle(subCtx)
|
||||
|
||||
protocolErrors <- err
|
||||
h.logger.Info("server protocol stopped")
|
||||
|
||||
select {
|
||||
case protocolErrors <- err:
|
||||
case <-subCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
|
@ -105,23 +140,19 @@ func (h *Handler) Run(ctx context.Context) error {
|
|||
func New(
|
||||
cfg *config.Serial,
|
||||
logger *logrus.Logger,
|
||||
framer protocol.Framer,
|
||||
protocolHandler protocol.ServerHandler,
|
||||
) (*Handler, error) {
|
||||
cobsFramer, err := protocol.NewCOBSFramer(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create COBS framer: %w", err)
|
||||
}
|
||||
|
||||
h := &Handler{
|
||||
serverHandler: protocolHandler,
|
||||
logger: logger,
|
||||
framesIn: make(chan []byte),
|
||||
framesOut: make(chan []byte),
|
||||
framer: cobsFramer,
|
||||
framer: framer,
|
||||
}
|
||||
h.config = &serial.Config{Name: cfg.Device, Baud: cfg.Baud, ReadTimeout: cfg.Timeout}
|
||||
|
||||
err = h.setupConnection()
|
||||
err := h.setupConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -421,7 +421,7 @@ func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan
|
|||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
raw, err = c.readRaw(reader)
|
||||
raw, err = c.readRaw(ctx, reader)
|
||||
if err != nil {
|
||||
close(frameChan)
|
||||
|
||||
|
@ -469,21 +469,31 @@ func (c *COBSFramer) ReadFrames(ctx context.Context, reader io.Reader, frameChan
|
|||
}
|
||||
}
|
||||
|
||||
func (c *COBSFramer) readRaw(reader io.Reader) ([]byte, error) {
|
||||
buf := make([]byte, bufferSize)
|
||||
func (c *COBSFramer) readRaw(ctx context.Context, reader io.Reader) ([]byte, error) {
|
||||
result := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
count, err := reader.Read(buf)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return buf[:count], nil
|
||||
go func() {
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
count, err := reader.Read(buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
errChan <- fmt.Errorf("could not read data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("could not read data: %w", err)
|
||||
result <- buf[:count]
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil
|
||||
case raw := <-result:
|
||||
return raw, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
raw := buf[:count]
|
||||
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func (c *COBSFramer) WriteFrames(ctx context.Context, writer io.Writer, frameChan chan []byte) error {
|
||||
|
|
Loading…
Reference in a new issue