From 9c608ed81fed7a0b388b2c149452bd701a08861c Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Sat, 3 Dec 2022 13:15:44 +0100 Subject: [PATCH] Fix potential race condition in client Synchronize go routines in client.Run to make sure to avoid access to the common context before use. --- internal/client/client.go | 61 ++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/internal/client/client.go b/internal/client/client.go index 3a5ba5b..dabdd11 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -91,44 +91,79 @@ type Client struct { } func (c *Client) Run(ctx context.Context) error { - protocolErrors := make(chan error) - framerErrors := make(chan error) + const componentCount = 4 + + protocolErrors, framerErrors := make(chan error), make(chan error) + + subCtx, cancel := context.WithCancel(ctx) + wg := sync.WaitGroup{} + wg.Add(componentCount) + + defer func() { + cancel() + c.logger.Info("context canceled, waiting for shutdown of components") + wg.Wait() + c.logger.Info("shutdown complete") + }() go func(f protocol.Framer) { - framerErrors <- f.ReadFrames(ctx, c.port, c.in) + defer wg.Done() + + err := f.ReadFrames(subCtx, c.port, c.in) + + c.logger.Info("frame reading stopped") + + select { + case framerErrors <- err: + case <-subCtx.Done(): + } }(c.framer) go func(f protocol.Framer) { - framerErrors <- f.WriteFrames(ctx, c.port, c.out) + defer wg.Done() + + err := f.WriteFrames(subCtx, c.port, c.out) + + c.logger.Info("frame writing stopped") + + select { + case framerErrors <- err: + case <-subCtx.Done(): + } }(c.framer) go func() { clientProtocol := protocol.NewClient(c.handler, c.commands, c.in, c.out, c.logger) - protocolErrors <- clientProtocol.Handle(ctx) + err := clientProtocol.Handle(subCtx) + + c.logger.Info("client protocol stopped") + + select { + case protocolErrors <- err: + case <-subCtx.Done(): + } }() - ctx, cancelCommandLoop := context.WithCancel(ctx) + go func() { + defer wg.Done() - go c.commandLoop(ctx) + c.commandLoop(subCtx) + + c.logger.Info("client command loop stopped") + }() for { select { case <-ctx.Done(): - cancelCommandLoop() - return nil case err := <-framerErrors: - cancelCommandLoop() - if err != nil { return fmt.Errorf("error from framer: %w", err) } return nil case err := <-protocolErrors: - cancelCommandLoop() - if err != nil { return fmt.Errorf("error from protocol: %w", err) }