Fix potential race condition in client

Synchronize go routines in client.Run to make sure to avoid access to the
common context before use.
This commit is contained in:
Jan Dittberner 2022-12-03 13:15:44 +01:00
parent f4a1958307
commit 9c608ed81f

View file

@ -91,44 +91,79 @@ type Client struct {
} }
func (c *Client) Run(ctx context.Context) error { func (c *Client) Run(ctx context.Context) error {
protocolErrors := make(chan error) const componentCount = 4
framerErrors := make(chan error)
protocolErrors, framerErrors := make(chan error), make(chan error)
subCtx, cancel := context.WithCancel(ctx)
wg := sync.WaitGroup{}
wg.Add(componentCount)
defer func() {
cancel()
c.logger.Info("context canceled, waiting for shutdown of components")
wg.Wait()
c.logger.Info("shutdown complete")
}()
go func(f protocol.Framer) { go func(f protocol.Framer) {
framerErrors <- f.ReadFrames(ctx, c.port, c.in) defer wg.Done()
err := f.ReadFrames(subCtx, c.port, c.in)
c.logger.Info("frame reading stopped")
select {
case framerErrors <- err:
case <-subCtx.Done():
}
}(c.framer) }(c.framer)
go func(f protocol.Framer) { go func(f protocol.Framer) {
framerErrors <- f.WriteFrames(ctx, c.port, c.out) defer wg.Done()
err := f.WriteFrames(subCtx, c.port, c.out)
c.logger.Info("frame writing stopped")
select {
case framerErrors <- err:
case <-subCtx.Done():
}
}(c.framer) }(c.framer)
go func() { go func() {
clientProtocol := protocol.NewClient(c.handler, c.commands, c.in, c.out, c.logger) clientProtocol := protocol.NewClient(c.handler, c.commands, c.in, c.out, c.logger)
protocolErrors <- clientProtocol.Handle(ctx) err := clientProtocol.Handle(subCtx)
c.logger.Info("client protocol stopped")
select {
case protocolErrors <- err:
case <-subCtx.Done():
}
}() }()
ctx, cancelCommandLoop := context.WithCancel(ctx) go func() {
defer wg.Done()
go c.commandLoop(ctx) c.commandLoop(subCtx)
c.logger.Info("client command loop stopped")
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
cancelCommandLoop()
return nil return nil
case err := <-framerErrors: case err := <-framerErrors:
cancelCommandLoop()
if err != nil { if err != nil {
return fmt.Errorf("error from framer: %w", err) return fmt.Errorf("error from framer: %w", err)
} }
return nil return nil
case err := <-protocolErrors: case err := <-protocolErrors:
cancelCommandLoop()
if err != nil { if err != nil {
return fmt.Errorf("error from protocol: %w", err) return fmt.Errorf("error from protocol: %w", err)
} }