diff --git a/go.mod b/go.mod index 4555b68..8ffe5c5 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module git.cacert.org/cacert-gosignerclient go 1.19 require ( - git.cacert.org/cacert-gosigner v0.0.0-20221201103407-51afebf2c12f + git.cacert.org/cacert-gosigner v0.0.0-20221201203610-19436c06c2cf github.com/balacode/go-delta v0.1.0 github.com/shamaton/msgpackgen v0.3.0 github.com/sirupsen/logrus v1.9.0 diff --git a/go.sum b/go.sum index d3ce837..e722e7d 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ -git.cacert.org/cacert-gosigner v0.0.0-20221130191226-de7e716a8274 h1:lGaIVUyXCtmDZ3ZhCYE44rpbvDF/JMDA/zrPgCZKMvc= -git.cacert.org/cacert-gosigner v0.0.0-20221130191226-de7e716a8274/go.mod h1:mb8oBdxQ26GI3xT4b8B7hXYWGED9vvjPGxehmbicyc4= -git.cacert.org/cacert-gosigner v0.0.0-20221201103407-51afebf2c12f h1:OLe/r/dK4WtQXjLCP3Naq/MUkkVkClvw2JIzLxNR2sI= -git.cacert.org/cacert-gosigner v0.0.0-20221201103407-51afebf2c12f/go.mod h1:mb8oBdxQ26GI3xT4b8B7hXYWGED9vvjPGxehmbicyc4= +git.cacert.org/cacert-gosigner v0.0.0-20221201203610-19436c06c2cf h1:dV0Y485b/HvtrEVC9Yflla8pVAnMoRvfKZI4pozOftY= +git.cacert.org/cacert-gosigner v0.0.0-20221201203610-19436c06c2cf/go.mod h1:mb8oBdxQ26GI3xT4b8B7hXYWGED9vvjPGxehmbicyc4= github.com/balacode/go-delta v0.1.0 h1:pwz4CMn06P2bIaIfAx3GSabMPwJp/Ww4if+7SgPYa3I= github.com/balacode/go-delta v0.1.0/go.mod h1:wLNrwTI3lHbPBvnLzqbHmA7HVVlm1u22XLvhbeA6t3o= github.com/balacode/zr v1.0.0 h1:MCupkEoXvrnCljc4KddiDOhR04ZLUAACgtKuo3o+9vc= diff --git a/internal/client/client.go b/internal/client/client.go index 1ac3d9d..1802102 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -80,17 +80,17 @@ func (c *Client) Run(ctx context.Context) error { framerErrors := make(chan error) go func(f protocol.Framer) { - framerErrors <- f.ReadFrames(c.port, c.in) + framerErrors <- f.ReadFrames(ctx, c.port, c.in) }(c.framer) go func(f protocol.Framer) { - framerErrors <- f.WriteFrames(c.port, c.out) + framerErrors <- f.WriteFrames(ctx, c.port, c.out) }(c.framer) go func() { clientProtocol := protocol.NewClient(c.handler, c.commands, c.in, c.out, c.logger) - protocolErrors <- clientProtocol.Handle() + protocolErrors <- clientProtocol.Handle(ctx) }() ctx, cancelCommandLoop := context.WithCancel(ctx) diff --git a/internal/handler/handler.go b/internal/handler/handler.go index df82a5d..937359c 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -18,6 +18,7 @@ limitations under the License. package handler import ( + "context" "fmt" "time" @@ -39,7 +40,7 @@ type SignerClientHandler struct { clientCallback chan interface{} } -func (s *SignerClientHandler) Send(command *protocol.Command, out chan []byte) error { +func (s *SignerClientHandler) Send(ctx context.Context, command *protocol.Command, out chan []byte) error { var ( frame []byte err error @@ -54,7 +55,12 @@ func (s *SignerClientHandler) Send(command *protocol.Command, out chan []byte) e s.logger.Trace("writing command announcement") - out <- frame + select { + case <-ctx.Done(): + return nil + case out <- frame: + break + } frame, err = msgpack.Marshal(command.Command) if err != nil { @@ -63,17 +69,22 @@ func (s *SignerClientHandler) Send(command *protocol.Command, out chan []byte) e s.logger.WithField("command", command.Command).Debug("write command data") - out <- frame - - return nil + select { + case <-ctx.Done(): + return nil + case out <- frame: + return nil + } } -func (s *SignerClientHandler) ResponseAnnounce(in chan []byte) (*protocol.Response, error) { +func (s *SignerClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*protocol.Response, error) { response := &protocol.Response{} var announce messages.ResponseAnnounce select { + case <-ctx.Done(): + return nil, nil case frame := <-in: if err := msgpack.Unmarshal(frame, &announce); err != nil { return nil, fmt.Errorf("could not unmarshal response announcement: %w", err) @@ -89,8 +100,10 @@ func (s *SignerClientHandler) ResponseAnnounce(in chan []byte) (*protocol.Respon } } -func (s *SignerClientHandler) ResponseData(in chan []byte, response *protocol.Response) error { +func (s *SignerClientHandler) ResponseData(ctx context.Context, in chan []byte, response *protocol.Response) error { select { + case <-ctx.Done(): + return nil case frame := <-in: switch response.Announce.Code { case messages.RespHealth: @@ -124,7 +137,7 @@ func (s *SignerClientHandler) ResponseData(in chan []byte, response *protocol.Re return nil } -func (s *SignerClientHandler) HandleResponse(response *protocol.Response) error { +func (s *SignerClientHandler) HandleResponse(ctx context.Context, response *protocol.Response) error { s.logger.WithField("response", response.Announce).Info("handled response") s.logger.WithField("response", response).Debug("full response") @@ -132,9 +145,9 @@ func (s *SignerClientHandler) HandleResponse(response *protocol.Response) error case *messages.ErrorResponse: s.logger.WithField("message", r.Message).Error("error from signer") case *messages.HealthResponse: - s.handleHealthResponse(r) + s.handleHealthResponse(ctx, r) case *messages.FetchCRLResponse: - s.handleFetchCRLResponse(r) + s.handleFetchCRLResponse(ctx, r) default: s.logger.WithField("response", response).Warnf("unhandled response of type %T", response.Response) } @@ -142,7 +155,7 @@ func (s *SignerClientHandler) HandleResponse(response *protocol.Response) error return nil } -func (s *SignerClientHandler) handleHealthResponse(r *messages.HealthResponse) { +func (s *SignerClientHandler) handleHealthResponse(ctx context.Context, r *messages.HealthResponse) { signerInfo := client.SignerInfo{} signerInfo.SignerHealth = r.Healthy @@ -198,11 +211,21 @@ func (s *SignerClientHandler) handleHealthResponse(r *messages.HealthResponse) { } } - s.clientCallback <- signerInfo + select { + case <-ctx.Done(): + return + case s.clientCallback <- signerInfo: + break + } } -func (s *SignerClientHandler) handleFetchCRLResponse(r *messages.FetchCRLResponse) { - s.clientCallback <- r +func (s *SignerClientHandler) handleFetchCRLResponse(ctx context.Context, r *messages.FetchCRLResponse) { + select { + case <-ctx.Done(): + return + case s.clientCallback <- r: + break + } } func New(