Improve robustness and concurrency handling

- Rename client.CertInfo to CACertificateInfo
- declare commands channel inside client.Run, there is no need to inject it
  from the outside
- let command generating code in client.commandLoop run in goroutines to
  allow parallel handling of queued commands and avoid blocking operations
- pass context to command generating functions to allow cancellation
- guard access to c.knownCACertificates by mutex.Lock and mutex.Unlock
- make command channel capacity configurable
- update to latest cacert-gosigner dependency for channel direction support
- improve handling of closed input channel
- reduce client initialization to serial connection setup, move callback and
  handler parameters to client.Run invocation
This commit is contained in:
Jan Dittberner 2022-12-04 14:20:34 +01:00
parent ef1ac1950b
commit f3c0e1379f
6 changed files with 261 additions and 218 deletions

View file

@ -31,7 +31,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"git.cacert.org/cacert-gosigner/pkg/protocol"
"git.cacert.org/cacert-gosignerclient/internal/client" "git.cacert.org/cacert-gosignerclient/internal/client"
"git.cacert.org/cacert-gosignerclient/internal/config" "git.cacert.org/cacert-gosignerclient/internal/config"
"git.cacert.org/cacert-gosignerclient/internal/handler" "git.cacert.org/cacert-gosignerclient/internal/handler"
@ -93,7 +92,7 @@ func main() {
logger.SetLevel(parsedLevel) logger.SetLevel(parsedLevel)
if err := startClient(configFile, logger); err != nil { if err := startClient(configFile, logger); err != nil {
logger.WithError(err).Fatal("client failure") logger.WithError(err).Error("client failure")
os.Exit(1) os.Exit(1)
} }
@ -135,15 +134,7 @@ func startClient(configFile string, logger *logrus.Logger) error {
return fmt.Errorf("could not configure client: %w", err) return fmt.Errorf("could not configure client: %w", err)
} }
commands := make(chan *protocol.Command) signerClient, err := client.New(clientConfig, logger)
callbacks := make(chan interface{}, client.CallBackBufferSize)
clientHandler, err := handler.New(clientConfig, logger, commands, callbacks)
if err != nil {
return fmt.Errorf("could not setup client handler: %w", err)
}
signerClient, err := client.New(clientConfig, logger, clientHandler, commands, callbacks)
if err != nil { if err != nil {
return fmt.Errorf("could not setup client: %w", err) return fmt.Errorf("could not setup client: %w", err)
} }
@ -166,7 +157,14 @@ func startClient(configFile string, logger *logrus.Logger) error {
cancel() cancel()
}() }()
if err = signerClient.Run(ctx); err != nil { callbacks := make(chan interface{}, client.CallBackBufferSize)
clientHandler, err := handler.New(clientConfig, logger, callbacks)
if err != nil {
return fmt.Errorf("could not setup client handler: %w", err)
}
if err = signerClient.Run(ctx, callbacks, clientHandler); err != nil {
return fmt.Errorf("error in client: %w", err) return fmt.Errorf("error in client: %w", err)
} }

2
go.mod
View file

@ -3,7 +3,7 @@ module git.cacert.org/cacert-gosignerclient
go 1.19 go 1.19
require ( require (
git.cacert.org/cacert-gosigner v0.0.0-20221203123337-46407b368528 git.cacert.org/cacert-gosigner v0.0.0-20221204124751-7852c4d3df8c
github.com/balacode/go-delta v0.1.0 github.com/balacode/go-delta v0.1.0
github.com/shamaton/msgpackgen v0.3.0 github.com/shamaton/msgpackgen v0.3.0
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0

12
go.sum
View file

@ -1,13 +1,5 @@
git.cacert.org/cacert-gosigner v0.0.0-20221202080952-37d3b1e02146 h1:vbm3fIRxNKD4jahqVnIvvU7jc57JfHz5KijalJFlHJ4= git.cacert.org/cacert-gosigner v0.0.0-20221204124751-7852c4d3df8c h1:Awd3z2rKzBHiSuY/hpUBTr2Vnm08V8XxNibUkqaa0Ow=
git.cacert.org/cacert-gosigner v0.0.0-20221202080952-37d3b1e02146/go.mod h1:OGIB5wLUhJiBhTzSXReOhGxuy7sT5VvyOyT8Ux8EGyw= git.cacert.org/cacert-gosigner v0.0.0-20221204124751-7852c4d3df8c/go.mod h1:OGIB5wLUhJiBhTzSXReOhGxuy7sT5VvyOyT8Ux8EGyw=
git.cacert.org/cacert-gosigner v0.0.0-20221202122810-6f8ac9818cd1 h1:HRtgcV6tRM+jN8NxPx7DkuwdH2prZOTdydMBCFg/CWM=
git.cacert.org/cacert-gosigner v0.0.0-20221202122810-6f8ac9818cd1/go.mod h1:OGIB5wLUhJiBhTzSXReOhGxuy7sT5VvyOyT8Ux8EGyw=
git.cacert.org/cacert-gosigner v0.0.0-20221202173159-afe7d23c9b6f h1:VcIwyogvdmYDpDwE7U0+S2P+xU5zwquppAVp2q4eI9k=
git.cacert.org/cacert-gosigner v0.0.0-20221202173159-afe7d23c9b6f/go.mod h1:OGIB5wLUhJiBhTzSXReOhGxuy7sT5VvyOyT8Ux8EGyw=
git.cacert.org/cacert-gosigner v0.0.0-20221203104439-bc81ab84cb4a h1:yX3lhEoBQkUKu23xggAzAeYWuziCkRYktSjsAOfNGHY=
git.cacert.org/cacert-gosigner v0.0.0-20221203104439-bc81ab84cb4a/go.mod h1:OGIB5wLUhJiBhTzSXReOhGxuy7sT5VvyOyT8Ux8EGyw=
git.cacert.org/cacert-gosigner v0.0.0-20221203123337-46407b368528 h1:W1K/YiNp8ganr2GuOpoNC+ZcaVsG15tczpNuvAjHn+U=
git.cacert.org/cacert-gosigner v0.0.0-20221203123337-46407b368528/go.mod h1:OGIB5wLUhJiBhTzSXReOhGxuy7sT5VvyOyT8Ux8EGyw=
github.com/balacode/go-delta v0.1.0 h1:pwz4CMn06P2bIaIfAx3GSabMPwJp/Ww4if+7SgPYa3I= github.com/balacode/go-delta v0.1.0 h1:pwz4CMn06P2bIaIfAx3GSabMPwJp/Ww4if+7SgPYa3I=
github.com/balacode/go-delta v0.1.0/go.mod h1:wLNrwTI3lHbPBvnLzqbHmA7HVVlm1u22XLvhbeA6t3o= github.com/balacode/go-delta v0.1.0/go.mod h1:wLNrwTI3lHbPBvnLzqbHmA7HVVlm1u22XLvhbeA6t3o=
github.com/balacode/zr v1.0.0/go.mod h1:pLeSAL3DhZ9L0JuiRkUtIX3mLOCtzBLnDhfmykbSmkE= github.com/balacode/zr v1.0.0/go.mod h1:pLeSAL3DhZ9L0JuiRkUtIX3mLOCtzBLnDhfmykbSmkE=

View file

@ -50,7 +50,7 @@ type Profile struct {
UseFor string UseFor string
} }
type CertInfo struct { type CACertificateInfo struct {
Name string Name string
FetchCert bool FetchCert bool
FetchCRL bool FetchCRL bool
@ -76,21 +76,18 @@ func (i *SignerInfo) containsCA(caName string) bool {
} }
type Client struct { type Client struct {
port *serial.Port port *serial.Port
logger *logrus.Logger logger *logrus.Logger
framer protocol.Framer framer protocol.Framer
in chan []byte config *config.ClientConfig
out chan []byte signerInfo *SignerInfo
commands chan *protocol.Command knownCACertificates map[string]*CACertificateInfo
handler protocol.ClientHandler
config *config.ClientConfig
signerInfo *SignerInfo
knownCertificates map[string]*CertInfo
callback chan interface{}
sync.Mutex sync.Mutex
} }
func (c *Client) Run(ctx context.Context) error { func (c *Client) Run(
ctx context.Context, callback <-chan interface{}, handler protocol.ClientHandler,
) error {
const componentCount = 4 const componentCount = 4
protocolErrors, framerErrors := make(chan error), make(chan error) protocolErrors, framerErrors := make(chan error), make(chan error)
@ -99,6 +96,10 @@ func (c *Client) Run(ctx context.Context) error {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(componentCount) wg.Add(componentCount)
commands := make(chan *protocol.Command, c.config.CommandChannelCapacity)
fromSigner := make(chan []byte)
toSigner := make(chan []byte)
defer func() { defer func() {
cancel() cancel()
c.logger.Info("context canceled, waiting for shutdown of components") c.logger.Info("context canceled, waiting for shutdown of components")
@ -109,7 +110,7 @@ func (c *Client) Run(ctx context.Context) error {
go func(f protocol.Framer) { go func(f protocol.Framer) {
defer wg.Done() defer wg.Done()
err := f.ReadFrames(subCtx, c.port, c.in) err := f.ReadFrames(subCtx, c.port, fromSigner)
c.logger.Info("frame reading stopped") c.logger.Info("frame reading stopped")
@ -122,7 +123,7 @@ func (c *Client) Run(ctx context.Context) error {
go func(f protocol.Framer) { go func(f protocol.Framer) {
defer wg.Done() defer wg.Done()
err := f.WriteFrames(subCtx, c.port, c.out) err := f.WriteFrames(subCtx, c.port, toSigner)
c.logger.Info("frame writing stopped") c.logger.Info("frame writing stopped")
@ -135,7 +136,7 @@ func (c *Client) Run(ctx context.Context) error {
go func() { go func() {
defer wg.Done() defer wg.Done()
clientProtocol := protocol.NewClient(c.handler, c.commands, c.in, c.out, c.logger) clientProtocol := protocol.NewClient(handler, commands, fromSigner, toSigner, c.logger)
err := clientProtocol.Handle(subCtx) err := clientProtocol.Handle(subCtx)
@ -150,7 +151,7 @@ func (c *Client) Run(ctx context.Context) error {
go func() { go func() {
defer wg.Done() defer wg.Done()
c.commandLoop(subCtx) c.commandLoop(subCtx, commands, callback)
c.logger.Info("client command loop stopped") c.logger.Info("client command loop stopped")
}() }()
@ -192,9 +193,6 @@ func (c *Client) setupConnection(serialConfig *serial.Config) error {
} }
func (c *Client) Close() error { func (c *Client) Close() error {
close(c.in)
close(c.out)
if c.port != nil { if c.port != nil {
err := c.port.Close() err := c.port.Close()
if err != nil { if err != nil {
@ -205,123 +203,150 @@ func (c *Client) Close() error {
return nil return nil
} }
func (c *Client) commandLoop(ctx context.Context) { type commandGenerator func(context.Context, chan<- *protocol.Command) error
func (c *Client) commandLoop(ctx context.Context, commands chan *protocol.Command, callback <-chan interface{}) {
healthTimer := time.NewTimer(c.config.HealthStart) healthTimer := time.NewTimer(c.config.HealthStart)
fetchCRLTimer := time.NewTimer(c.config.FetchCRLStart) fetchCRLTimer := time.NewTimer(c.config.FetchCRLStart)
nextCommands := make(chan *protocol.Command)
for { for {
newCommands := make([]*protocol.Command, 0)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case callbackData := <-c.callback: case callbackData := <-callback:
addCommands, err := c.handleCallback(callbackData) go func() {
if err != nil { err := c.handleCallback(ctx, nextCommands, callbackData)
c.logger.WithError(err).Error("callback handling failed") if err != nil {
} c.logger.WithError(err).Error("callback handling failed")
}
newCommands = append(newCommands, addCommands...) }()
case <-fetchCRLTimer.C: case <-fetchCRLTimer.C:
for _, crlInfo := range c.requiredCRLs() { go c.scheduleRequiredCRLFetches(ctx, nextCommands)
newCommands = append(newCommands, command.FetchCRL(crlInfo.Name, crlInfo.LastKnown))
}
fetchCRLTimer.Reset(c.config.FetchCRLInterval) fetchCRLTimer.Reset(c.config.FetchCRLInterval)
case <-healthTimer.C: case <-healthTimer.C:
newCommands = append(newCommands, command.Health()) go c.scheduleHealthCheck(ctx, nextCommands)
healthTimer.Reset(c.config.HealthInterval) healthTimer.Reset(c.config.HealthInterval)
} case nextCommand, ok := <-nextCommands:
if !ok {
for _, nextCommand := range newCommands {
select {
case <-ctx.Done():
return return
case c.commands <- nextCommand:
c.logger.WithField("command", nextCommand.Announce).Trace("sent command")
} }
commands <- nextCommand
c.logger.WithFields(map[string]interface{}{
"command": nextCommand.Announce,
"buffer length": len(commands),
}).Trace("sent command")
} }
} }
} }
func (c *Client) handleCallback(data interface{}) ([]*protocol.Command, error) { func (c *Client) handleCallback(
ctx context.Context,
newCommands chan<- *protocol.Command,
data interface{},
) error {
var handler commandGenerator
switch d := data.(type) { switch d := data.(type) {
case SignerInfo: case SignerInfo:
return c.updateSignerInfo(d) handler = c.updateSignerInfo(d)
case *messages.CAInfoResponse: case *messages.CAInfoResponse:
return c.updateCAInformation(d) handler = c.updateCAInformation(d)
case *messages.FetchCRLResponse: case *messages.FetchCRLResponse:
return c.updateCRL(d) handler = c.updateCRL(d)
default: default:
return nil, fmt.Errorf("unknown callback data of type %T", data) return fmt.Errorf("unknown callback data of type %T", data)
} }
if err := handler(ctx, newCommands); err != nil {
return err
}
return nil
} }
func (c *Client) updateSignerInfo(signerInfo SignerInfo) ([]*protocol.Command, error) { func (c *Client) updateSignerInfo(
c.logger.Debug("update signer info") signerInfo SignerInfo,
) commandGenerator {
return func(ctx context.Context, newCommands chan<- *protocol.Command) error {
c.logger.Debug("update signer info")
c.Lock() c.Lock()
c.signerInfo = &signerInfo c.signerInfo = &signerInfo
c.Unlock() c.Unlock()
c.learnNewCACertificates() c.learnNewCACertificates()
c.forgetRemovedCACertificates() c.forgetRemovedCACertificates()
newCommands := make([]*protocol.Command, 0) for _, caName := range c.requiredCertificateInfo() {
select {
for _, caName := range c.requiredCertificateInfo() { case <-ctx.Done():
newCommands = append(newCommands, command.CAInfo(caName)) case newCommands <- command.CAInfo(caName):
} }
return newCommands, nil
}
func (c *Client) updateCAInformation(d *messages.CAInfoResponse) ([]*protocol.Command, error) {
c.Lock()
defer c.Unlock()
caInfo, ok := c.knownCertificates[d.Name]
if !ok {
c.logger.WithField("certificate", d.Name).Warn("unknown CA certificate")
return nil, nil
}
cert, err := x509.ParseCertificate(d.Certificate)
if err != nil {
return nil, fmt.Errorf("could not parse CA certificate for %s: %w", d.Name, err)
}
if !cert.IsCA {
return nil, fmt.Errorf("certificate for %s is not a CA certificate", d.Name)
}
err = c.writeCertificate(caInfo.Name, d.Certificate)
if err != nil {
c.logger.WithError(err).WithField("certificate", d.Name).Warn("could not write CA certificate files")
}
caInfo.Certificate = cert
caInfo.FetchCert = false
caInfo.Profiles = make(map[string]*Profile)
for _, p := range d.Profiles {
caInfo.Profiles[p.Name] = &Profile{
Name: p.Name,
UseFor: p.UseFor.String(),
} }
return nil
} }
}
if len(cert.CRLDistributionPoints) == 0 { func (c *Client) updateCAInformation(
caInfo.FetchCRL = false infoResponse *messages.CAInfoResponse,
) commandGenerator {
return func(ctx context.Context, newCommands chan<- *protocol.Command) error {
var (
caInfo *CACertificateInfo
cert *x509.Certificate
err error
)
return nil, nil if caInfo, err = c.getCACertificate(infoResponse.Name); err != nil {
return err
}
if cert, err = x509.ParseCertificate(infoResponse.Certificate); err != nil {
return fmt.Errorf("could not parse CA certificate for %s: %w", infoResponse.Name, err)
}
if !cert.IsCA {
return fmt.Errorf("certificate for %s is not a CA certificate", infoResponse.Name)
}
if err = c.writeCertificate(caInfo.Name, infoResponse.Certificate); err != nil {
c.logger.WithError(err).WithField("certificate", infoResponse.Name).Warn(
"could not write CA certificate files",
)
}
caInfo.Certificate = cert
caInfo.FetchCert = false
caInfo.Profiles = make(map[string]*Profile)
for _, p := range infoResponse.Profiles {
caInfo.Profiles[p.Name] = &Profile{
Name: p.Name,
UseFor: p.UseFor.String(),
}
}
if len(cert.CRLDistributionPoints) == 0 {
caInfo.FetchCRL = false
return nil
}
select {
case <-ctx.Done():
case newCommands <- command.FetchCRL(caInfo.Name, c.lastKnownCRL(caInfo)):
}
return nil
} }
return []*protocol.Command{command.FetchCRL(caInfo.Name, c.lastKnownCRL(caInfo))}, nil
} }
type CRLInfo struct { type CRLInfo struct {
@ -329,40 +354,39 @@ type CRLInfo struct {
LastKnown *big.Int LastKnown *big.Int
} }
func (c *Client) requiredCRLs() []CRLInfo { func (c *Client) scheduleRequiredCRLFetches(ctx context.Context, newCommands chan<- *protocol.Command) {
c.Lock()
defer c.Unlock()
if c.knownCertificates == nil {
c.logger.Warn("no certificates known")
return nil
}
infos := make([]CRLInfo, 0) infos := make([]CRLInfo, 0)
for _, caInfo := range c.knownCertificates { c.Lock()
for _, caInfo := range c.knownCACertificates {
if caInfo.FetchCRL { if caInfo.FetchCRL {
infos = append(infos, CRLInfo{Name: caInfo.Name, LastKnown: c.lastKnownCRL(caInfo)}) infos = append(infos, CRLInfo{Name: caInfo.Name, LastKnown: c.lastKnownCRL(caInfo)})
} }
} }
c.Unlock()
return infos for _, crlInfo := range infos {
select {
case <-ctx.Done():
case newCommands <- command.FetchCRL(crlInfo.Name, crlInfo.LastKnown):
}
}
}
func (c *Client) scheduleHealthCheck(ctx context.Context, nextCommands chan<- *protocol.Command) {
select {
case <-ctx.Done():
case nextCommands <- command.Health():
}
} }
func (c *Client) requiredCertificateInfo() []string { func (c *Client) requiredCertificateInfo() []string {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
if c.knownCertificates == nil {
c.logger.Warn("no certificates known")
return nil
}
infos := make([]string, 0) infos := make([]string, 0)
for _, caInfo := range c.knownCertificates { for _, caInfo := range c.knownCACertificates {
if caInfo.FetchCert { if caInfo.FetchCert {
infos = append(infos, caInfo.Name) infos = append(infos, caInfo.Name)
} }
@ -371,7 +395,7 @@ func (c *Client) requiredCertificateInfo() []string {
return infos return infos
} }
func (c *Client) lastKnownCRL(caInfo *CertInfo) *big.Int { func (c *Client) lastKnownCRL(caInfo *CACertificateInfo) *big.Int {
caName := caInfo.Name caName := caInfo.Name
crlFileName := c.buildCRLFileName(caName) crlFileName := c.buildCRLFileName(caName)
@ -406,66 +430,62 @@ func (c *Client) lastKnownCRL(caInfo *CertInfo) *big.Int {
return lastKnown return lastKnown
} }
func (c *Client) updateCRL(d *messages.FetchCRLResponse) ([]*protocol.Command, error) { func (c *Client) updateCRL(fetchCRLResponse *messages.FetchCRLResponse) commandGenerator {
var ( return func(_ context.Context, _ chan<- *protocol.Command) error {
crlNumber *big.Int var (
der []byte crlNumber *big.Int
err error der []byte
) err error
list *x509.RevocationList
)
caInfo, ok := c.knownCertificates[d.IssuerID] if _, err = c.getCACertificate(fetchCRLResponse.IssuerID); err != nil {
if !ok { return err
c.logger.WithField("certificate", d.IssuerID).Warn("unknown CA certificate")
}
if d.UnChanged {
c.logger.WithField("issuer", d.IssuerID).Debug("CRL did not change")
return nil, nil
}
if !d.IsDelta {
der = d.CRLData
list, err := x509.ParseRevocationList(der)
if err != nil {
c.logger.WithError(err).Error("CRL from signer could not be parsed")
return nil, nil
} }
crlNumber = list.Number if fetchCRLResponse.UnChanged {
} else { c.logger.WithField("issuer", fetchCRLResponse.IssuerID).Debug("CRL did not change")
crlFileName := c.buildCRLFileName(d.IssuerID)
der, err = c.patchCRL(crlFileName, d.CRLData) return nil
if err != nil {
c.logger.WithError(err).Error("CRL patching failed")
return nil, nil
} }
list, err := x509.ParseRevocationList(der) if !fetchCRLResponse.IsDelta {
if err != nil { der = fetchCRLResponse.CRLData
c.logger.WithError(err).Error("could not parse patched CRL")
return nil, nil list, err = x509.ParseRevocationList(der)
if err != nil {
return fmt.Errorf(
"CRL for %s from signer could not be parsed: %w",
fetchCRLResponse.IssuerID,
err,
)
}
crlNumber = list.Number
} else {
crlFileName := c.buildCRLFileName(fetchCRLResponse.IssuerID)
if der, err = c.patchCRL(crlFileName, fetchCRLResponse.CRLData); err != nil {
return fmt.Errorf("CRL patching failed: %w", err)
}
if list, err = x509.ParseRevocationList(der); err != nil {
return fmt.Errorf("could not parse patched CRL: %w", err)
}
crlNumber = list.Number
} }
crlNumber = list.Number if err = c.writeCRL(fetchCRLResponse.IssuerID, der); err != nil {
c.setLastKnownCRL(fetchCRLResponse.IssuerID, nil)
return fmt.Errorf("could not store CRL for %s: %w", fetchCRLResponse.IssuerID, err)
}
c.setLastKnownCRL(fetchCRLResponse.IssuerID, crlNumber)
return nil
} }
if err := c.writeCRL(d.IssuerID, der); err != nil {
c.logger.WithError(err).Error("could not store CRL")
caInfo.LastKnownCRL = nil
return nil, nil
}
caInfo.LastKnownCRL = crlNumber
return nil, nil
} }
func (c *Client) buildCRLFileName(caName string) string { func (c *Client) buildCRLFileName(caName string) string {
@ -534,11 +554,11 @@ func (c *Client) learnNewCACertificates() {
defer c.Unlock() defer c.Unlock()
for _, caName := range c.signerInfo.CACertificates { for _, caName := range c.signerInfo.CACertificates {
if _, ok := c.knownCertificates[caName]; ok { if _, ok := c.knownCACertificates[caName]; ok {
continue continue
} }
c.knownCertificates[caName] = &CertInfo{ c.knownCACertificates[caName] = &CACertificateInfo{
Name: caName, Name: caName,
FetchCert: true, FetchCert: true,
FetchCRL: true, FetchCRL: true,
@ -550,23 +570,48 @@ func (c *Client) forgetRemovedCACertificates() {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
for knownCA := range c.knownCertificates { for knownCA := range c.knownCACertificates {
if c.signerInfo.containsCA(knownCA) { if c.signerInfo.containsCA(knownCA) {
continue continue
} }
c.logger.WithField("certificate", knownCA).Warn("signer did not send status for certificate") c.logger.WithField("certificate", knownCA).Warn("signer did not send status for certificate")
delete(c.knownCertificates, knownCA) delete(c.knownCACertificates, knownCA)
} }
} }
func (c *Client) getCACertificate(name string) (*CACertificateInfo, error) {
c.Lock()
defer c.Unlock()
caInfo, ok := c.knownCACertificates[name]
if !ok {
return nil, fmt.Errorf("no known CA certificate for %s", name)
}
return caInfo, nil
}
func (c *Client) setLastKnownCRL(caName string, number *big.Int) {
c.Lock()
defer c.Unlock()
caInfo, ok := c.knownCACertificates[caName]
if !ok {
c.logger.WithField("certificate", caName).Warn(
"tried to set last known CRL for unknown CA certificate",
)
return
}
caInfo.LastKnownCRL = number
}
func New( func New(
cfg *config.ClientConfig, cfg *config.ClientConfig,
logger *logrus.Logger, logger *logrus.Logger,
handler protocol.ClientHandler,
commands chan *protocol.Command,
callback chan interface{},
) (*Client, error) { ) (*Client, error) {
cobsFramer, err := protocol.NewCOBSFramer(logger) cobsFramer, err := protocol.NewCOBSFramer(logger)
if err != nil { if err != nil {
@ -574,15 +619,10 @@ func New(
} }
client := &Client{ client := &Client{
logger: logger, logger: logger,
framer: cobsFramer, framer: cobsFramer,
in: make(chan []byte), config: cfg,
out: make(chan []byte), knownCACertificates: make(map[string]*CACertificateInfo),
commands: commands,
handler: handler,
config: cfg,
callback: callback,
knownCertificates: make(map[string]*CertInfo),
} }
err = client.setupConnection(&serial.Config{ err = client.setupConnection(&serial.Config{

View file

@ -35,6 +35,7 @@ const (
defaultResponseAnnounceTimeout = 30 * time.Second defaultResponseAnnounceTimeout = 30 * time.Second
defaultResponseDataTimeout = 2 * time.Second defaultResponseDataTimeout = 2 * time.Second
defaultFilesDirectory = "public" defaultFilesDirectory = "public"
defaultCommandChannelCapacity = 100
) )
type SettingsError struct { type SettingsError struct {
@ -61,6 +62,7 @@ type ClientConfig struct {
ResponseDataTimeout time.Duration `yaml:"response-data-timeout"` ResponseDataTimeout time.Duration `yaml:"response-data-timeout"`
PublicCRLDirectory string `yaml:"public-crl-directory"` PublicCRLDirectory string `yaml:"public-crl-directory"`
PublicCertificateDirectory string `yaml:"public-certificate-directory"` PublicCertificateDirectory string `yaml:"public-certificate-directory"`
CommandChannelCapacity int `yaml:"command-channel-capacity"`
} }
func (c *ClientConfig) UnmarshalYAML(n *yaml.Node) error { func (c *ClientConfig) UnmarshalYAML(n *yaml.Node) error {
@ -74,6 +76,7 @@ func (c *ClientConfig) UnmarshalYAML(n *yaml.Node) error {
ResponseDataTimeout time.Duration `yaml:"response-data-timeout"` ResponseDataTimeout time.Duration `yaml:"response-data-timeout"`
PublicCRLDirectory string `yaml:"public-crl-directory"` PublicCRLDirectory string `yaml:"public-crl-directory"`
PublicCertificateDirectory string `yaml:"public-certificate-directory"` PublicCertificateDirectory string `yaml:"public-certificate-directory"`
CommandChannelCapacity int `yaml:"command-channel-capacity"`
}{} }{}
err := n.Decode(&data) err := n.Decode(&data)
@ -141,6 +144,12 @@ func (c *ClientConfig) UnmarshalYAML(n *yaml.Node) error {
data.PublicCertificateDirectory = defaultFilesDirectory data.PublicCertificateDirectory = defaultFilesDirectory
} }
if data.CommandChannelCapacity == 0 {
data.CommandChannelCapacity = defaultCommandChannelCapacity
}
c.CommandChannelCapacity = data.CommandChannelCapacity
c.PublicCertificateDirectory = data.PublicCRLDirectory c.PublicCertificateDirectory = data.PublicCRLDirectory
return nil return nil

View file

@ -19,6 +19,7 @@ package handler
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
@ -35,12 +36,13 @@ import (
type SignerClientHandler struct { type SignerClientHandler struct {
logger *logrus.Logger logger *logrus.Logger
commands chan *protocol.Command
config *config.ClientConfig config *config.ClientConfig
clientCallback chan interface{} clientCallback chan<- interface{}
} }
func (s *SignerClientHandler) Send(ctx context.Context, command *protocol.Command, out chan []byte) error { var errInputClosed = errors.New("input channel has been closed")
func (s *SignerClientHandler) Send(ctx context.Context, command *protocol.Command, out chan<- []byte) error {
var ( var (
frame []byte frame []byte
err error err error
@ -77,7 +79,7 @@ func (s *SignerClientHandler) Send(ctx context.Context, command *protocol.Comman
} }
} }
func (s *SignerClientHandler) ResponseAnnounce(ctx context.Context, in chan []byte) (*protocol.Response, error) { func (s *SignerClientHandler) ResponseAnnounce(ctx context.Context, in <-chan []byte) (*protocol.Response, error) {
response := &protocol.Response{} response := &protocol.Response{}
var announce messages.ResponseAnnounce var announce messages.ResponseAnnounce
@ -85,7 +87,11 @@ func (s *SignerClientHandler) ResponseAnnounce(ctx context.Context, in chan []by
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, nil return nil, nil
case frame := <-in: case frame, ok := <-in:
if !ok {
return nil, errInputClosed
}
if err := msgpack.Unmarshal(frame, &announce); err != nil { if err := msgpack.Unmarshal(frame, &announce); err != nil {
return nil, fmt.Errorf("could not unmarshal response announcement: %w", err) return nil, fmt.Errorf("could not unmarshal response announcement: %w", err)
} }
@ -100,7 +106,7 @@ func (s *SignerClientHandler) ResponseAnnounce(ctx context.Context, in chan []by
} }
} }
func (s *SignerClientHandler) ResponseData(ctx context.Context, in chan []byte, response *protocol.Response) error { func (s *SignerClientHandler) ResponseData(ctx context.Context, in <-chan []byte, response *protocol.Response) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
@ -230,13 +236,11 @@ func (s *SignerClientHandler) handleFetchCRLResponse(ctx context.Context, r *mes
func New( func New(
config *config.ClientConfig, config *config.ClientConfig,
logger *logrus.Logger, logger *logrus.Logger,
commands chan *protocol.Command,
clientCallback chan interface{}, clientCallback chan interface{},
) (protocol.ClientHandler, error) { ) (protocol.ClientHandler, error) {
return &SignerClientHandler{ return &SignerClientHandler{
logger: logger, logger: logger,
config: config, config: config,
commands: commands,
clientCallback: clientCallback, clientCallback: clientCallback,
}, nil }, nil
} }