cacert-gosignerclient/internal/client/client.go
Jan Dittberner f3c0e1379f Improve robustness and concurrency handling
- Rename client.CertInfo to CACertificateInfo
- declare commands channel inside client.Run, there is no need to inject it
  from the outside
- let command generating code in client.commandLoop run in goroutines to
  allow parallel handling of queued commands and avoid blocking operations
- pass context to command generating functions to allow cancellation
- guard access to c.knownCACertificates by mutex.Lock and mutex.Unlock
- make command channel capacity configurable
- update to latest cacert-gosigner dependency for channel direction support
- improve handling of closed input channel
- reduce client initialization to serial connection setup, move callback and
  handler parameters to client.Run invocation
2022-12-04 14:20:34 +01:00

638 lines
14 KiB
Go

/*
Copyright 2022 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package client
import (
"context"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"os"
"path"
"sync"
"time"
"github.com/balacode/go-delta"
"github.com/sirupsen/logrus"
"github.com/tarm/serial"
"git.cacert.org/cacert-gosigner/pkg/messages"
"git.cacert.org/cacert-gosigner/pkg/protocol"
"git.cacert.org/cacert-gosignerclient/internal/command"
"git.cacert.org/cacert-gosignerclient/internal/config"
)
const (
	// CallBackBufferSize is an exported buffer size. It is not referenced in
	// this file. NOTE(review): presumably callers use it to size the callback
	// channel handed to Client.Run — confirm at the call sites.
	CallBackBufferSize = 50

	// worldReadableDirPerm is the permission used when creating the public
	// CA data directory.
	worldReadableDirPerm = 0o755
	// worldReadableFilePerm is the permission used for the published
	// certificate and CRL files.
	worldReadableFilePerm = 0o644
)
// Profile describes one certificate profile that the signer reported for a CA
// certificate (filled in by updateCAInformation).
type Profile struct {
	// Name is the profile name as sent by the signer.
	Name string
	// UseFor is the string form of the profile usage sent by the signer.
	UseFor string
}
// CACertificateInfo is the client's view of one CA certificate announced by
// the signer. Instances live in Client.knownCACertificates.
type CACertificateInfo struct {
	// Name is the CA name as listed in SignerInfo.CACertificates.
	Name string
	// FetchCert is true while the CA certificate information still has to be
	// requested from the signer; it is cleared once a CA info response was
	// processed.
	FetchCert bool
	// FetchCRL marks CAs whose CRL is requested periodically; it is cleared
	// when the CA certificate has no CRL distribution points.
	FetchCRL bool
	// LastKnownCRL is the number of the last CRL that was stored, or nil if
	// it is not known.
	LastKnownCRL *big.Int
	// Certificate is the parsed CA certificate, nil until it has been fetched.
	Certificate *x509.Certificate
	// Profiles maps profile names to the profiles reported by the signer.
	Profiles map[string]*Profile
}
// SignerInfo is the status information reported by the signer: its health,
// its version and the names of the CA certificates it manages.
type SignerInfo struct {
	SignerHealth   bool
	SignerVersion  string
	CACertificates []string
}

// containsCA reports whether caName is one of the CA certificate names listed
// in i.
func (i *SignerInfo) containsCA(caName string) bool {
	for idx := range i.CACertificates {
		if i.CACertificates[idx] == caName {
			return true
		}
	}

	return false
}
// Client talks to the signer over a serial port, keeps track of the CA
// certificates the signer announces and writes their certificates and CRLs to
// the public CRL directory.
//
// The embedded sync.Mutex guards signerInfo and knownCACertificates, which are
// accessed from the goroutines spawned by commandLoop.
type Client struct {
	port                *serial.Port
	logger              *logrus.Logger
	framer              protocol.Framer
	config              *config.ClientConfig
	signerInfo          *SignerInfo
	knownCACertificates map[string]*CACertificateInfo
	sync.Mutex
}
// Run starts the components of a client session and blocks until one of them
// stops or ctx is canceled.
//
// Four goroutines are started:
//   - a frame reader on the serial port;
//   - a frame writer on the serial port;
//   - the client protocol handler, built from handler and the commands channel;
//   - the command loop, which turns callback data and timer events into
//     commands.
//
// Run returns nil when ctx is canceled or when the framer or protocol
// component stops without error. It returns a wrapped error when one of them
// fails. Before returning, it cancels all components and waits for them to
// finish.
func (c *Client) Run(
	ctx context.Context, callback <-chan interface{}, handler protocol.ClientHandler,
) error {
	const componentCount = 4

	var wg sync.WaitGroup

	subCtx, cancel := context.WithCancel(ctx)

	framerErrors := make(chan error)
	protocolErrors := make(chan error)

	commands := make(chan *protocol.Command, c.config.CommandChannelCapacity)
	fromSigner, toSigner := make(chan []byte), make(chan []byte)

	// report hands a component's exit status to Run, unless shutdown has
	// already started and nobody is listening any more.
	report := func(target chan<- error, err error) {
		select {
		case target <- err:
		case <-subCtx.Done():
		}
	}

	wg.Add(componentCount)

	defer func() {
		cancel()
		c.logger.Info("context canceled, waiting for shutdown of components")
		wg.Wait()
		c.logger.Info("shutdown complete")
	}()

	go func(framer protocol.Framer) {
		defer wg.Done()

		readErr := framer.ReadFrames(subCtx, c.port, fromSigner)
		c.logger.Info("frame reading stopped")
		report(framerErrors, readErr)
	}(c.framer)

	go func(framer protocol.Framer) {
		defer wg.Done()

		writeErr := framer.WriteFrames(subCtx, c.port, toSigner)
		c.logger.Info("frame writing stopped")
		report(framerErrors, writeErr)
	}(c.framer)

	go func() {
		defer wg.Done()

		clientProtocol := protocol.NewClient(handler, commands, fromSigner, toSigner, c.logger)
		protocolErr := clientProtocol.Handle(subCtx)
		c.logger.Info("client protocol stopped")
		report(protocolErrors, protocolErr)
	}()

	go func() {
		defer wg.Done()

		c.commandLoop(subCtx, commands, callback)
		c.logger.Info("client command loop stopped")
	}()

	select {
	case <-ctx.Done():
		return nil
	case err := <-framerErrors:
		if err == nil {
			return nil
		}

		return fmt.Errorf("error from framer: %w", err)
	case err := <-protocolErrors:
		if err == nil {
			return nil
		}

		return fmt.Errorf("error from protocol: %w", err)
	}
}
// setupConnection opens the serial port described by serialConfig, stores it
// in c.port and flushes the port buffers.
//
// Failing to open the port is returned as an error. Failing to flush is only
// logged and does not abort the setup.
func (c *Client) setupConnection(serialConfig *serial.Config) error {
	port, err := serial.OpenPort(serialConfig)
	if err != nil {
		return fmt.Errorf("could not open serial port: %w", err)
	}

	c.port = port

	if err := c.port.Flush(); err != nil {
		// logrus' Warn is not a printf-style function, so a %w verb would be
		// printed literally. The error is attached through WithError instead.
		c.logger.WithError(err).Warn("could not flush buffers of port")
	}

	return nil
}
// Close closes the serial port if one was opened. On a client without an
// open port it does nothing and returns nil.
func (c *Client) Close() error {
	if c.port == nil {
		return nil
	}

	if err := c.port.Close(); err != nil {
		return fmt.Errorf("could not close serial port: %w", err)
	}

	return nil
}
// commandGenerator produces zero or more protocol commands and sends them to
// the given channel. handleCallback invokes it with the command loop's
// context, so implementations can stop sending once that context is canceled.
type commandGenerator func(context.Context, chan<- *protocol.Command) error
// commandLoop schedules commands for the signer until ctx is canceled.
//
// Commands come from goroutines started for three triggers:
//   - each value read from callback (see handleCallback);
//   - the CRL fetch timer;
//   - the health check timer.
//
// All of them send to the unbuffered nextCommands channel. The loop forwards
// each command to commands, the channel Run passes to the client protocol
// handler.
//
// Forwarding to commands also watches ctx, so the loop cannot block forever
// when the consumer has stopped and the channel buffer is full. A closed
// callback channel is disabled rather than read repeatedly.
func (c *Client) commandLoop(ctx context.Context, commands chan *protocol.Command, callback <-chan interface{}) {
	healthTimer := time.NewTimer(c.config.HealthStart)
	defer healthTimer.Stop()

	fetchCRLTimer := time.NewTimer(c.config.FetchCRLStart)
	defer fetchCRLTimer.Stop()

	nextCommands := make(chan *protocol.Command)

	for {
		select {
		case <-ctx.Done():
			return
		case callbackData, ok := <-callback:
			if !ok {
				// A closed channel is always ready to receive. Setting it to
				// nil removes this case from the select instead of spinning
				// on zero values.
				c.logger.Info("callback channel closed")

				callback = nil

				continue
			}

			go func() {
				if err := c.handleCallback(ctx, nextCommands, callbackData); err != nil {
					c.logger.WithError(err).Error("callback handling failed")
				}
			}()
		case <-fetchCRLTimer.C:
			go c.scheduleRequiredCRLFetches(ctx, nextCommands)

			fetchCRLTimer.Reset(c.config.FetchCRLInterval)
		case <-healthTimer.C:
			go c.scheduleHealthCheck(ctx, nextCommands)

			healthTimer.Reset(c.config.HealthInterval)
		case nextCommand, ok := <-nextCommands:
			if !ok {
				return
			}

			// Do not block forever if the protocol handler stopped consuming
			// and the buffered commands channel is full.
			select {
			case <-ctx.Done():
				return
			case commands <- nextCommand:
			}

			c.logger.WithFields(map[string]interface{}{
				"command":       nextCommand.Announce,
				"buffer length": len(commands),
			}).Trace("sent command")
		}
	}
}
// handleCallback maps data received on the callback channel to the matching
// commandGenerator and runs it. Any resulting commands go to newCommands.
//
// Supported payloads are SignerInfo, *messages.CAInfoResponse and
// *messages.FetchCRLResponse. Any other type yields an error, as does a
// failing generator.
func (c *Client) handleCallback(
	ctx context.Context,
	newCommands chan<- *protocol.Command,
	data interface{},
) error {
	var generate commandGenerator

	switch payload := data.(type) {
	case SignerInfo:
		generate = c.updateSignerInfo(payload)
	case *messages.CAInfoResponse:
		generate = c.updateCAInformation(payload)
	case *messages.FetchCRLResponse:
		generate = c.updateCRL(payload)
	default:
		return fmt.Errorf("unknown callback data of type %T", data)
	}

	return generate(ctx, newCommands)
}
// updateSignerInfo returns a commandGenerator that:
//   - stores signerInfo as the current signer state;
//   - adds CA certificates the signer newly lists and drops those it no
//     longer lists;
//   - requests CA information for every certificate whose data still has to
//     be fetched.
//
// Sending stops as soon as ctx is canceled.
func (c *Client) updateSignerInfo(
	signerInfo SignerInfo,
) commandGenerator {
	return func(ctx context.Context, newCommands chan<- *protocol.Command) error {
		c.logger.Debug("update signer info")

		c.Lock()
		c.signerInfo = &signerInfo
		c.Unlock()

		c.learnNewCACertificates()
		c.forgetRemovedCACertificates()

		for _, caName := range c.requiredCertificateInfo() {
			select {
			case <-ctx.Done():
				// The command loop is shutting down; stop producing commands
				// instead of offering the remaining ones.
				return nil
			case newCommands <- command.CAInfo(caName):
			}
		}

		return nil
	}
}
// updateCAInformation returns a commandGenerator that processes a CA info
// response from the signer. It:
//   - validates the certificate (it must parse and be a CA);
//   - writes the certificate files;
//   - updates the matching CACertificateInfo;
//   - requests a CRL when the certificate has CRL distribution points.
//
// Errors: an unknown CA name, an unparsable certificate, or a non-CA
// certificate. Failing to write the certificate files is only logged.
//
// The shared CACertificateInfo is modified, and its last known CRL read, only
// while holding the client lock. Other generators read and write these fields
// under the same lock.
func (c *Client) updateCAInformation(
	infoResponse *messages.CAInfoResponse,
) commandGenerator {
	return func(ctx context.Context, newCommands chan<- *protocol.Command) error {
		caInfo, err := c.getCACertificate(infoResponse.Name)
		if err != nil {
			return err
		}

		cert, err := x509.ParseCertificate(infoResponse.Certificate)
		if err != nil {
			return fmt.Errorf("could not parse CA certificate for %s: %w", infoResponse.Name, err)
		}

		if !cert.IsCA {
			return fmt.Errorf("certificate for %s is not a CA certificate", infoResponse.Name)
		}

		if err = c.writeCertificate(caInfo.Name, infoResponse.Certificate); err != nil {
			c.logger.WithError(err).WithField("certificate", infoResponse.Name).Warn(
				"could not write CA certificate files",
			)
		}

		// Build the profile map before taking the lock to keep the critical
		// section short.
		profiles := make(map[string]*Profile, len(infoResponse.Profiles))
		for _, p := range infoResponse.Profiles {
			profiles[p.Name] = &Profile{
				Name:   p.Name,
				UseFor: p.UseFor.String(),
			}
		}

		hasCRL := len(cert.CRLDistributionPoints) > 0

		var lastKnown *big.Int

		c.Lock()
		caInfo.Certificate = cert
		caInfo.FetchCert = false
		caInfo.Profiles = profiles

		if hasCRL {
			lastKnown = c.lastKnownCRL(caInfo)
		} else {
			caInfo.FetchCRL = false
		}
		c.Unlock()

		if !hasCRL {
			return nil
		}

		select {
		case <-ctx.Done():
		case newCommands <- command.FetchCRL(caInfo.Name, lastKnown):
		}

		return nil
	}
}
// CRLInfo holds the data needed to request the CRL of one CA certificate.
// scheduleRequiredCRLFetches collects these while holding the client lock and
// sends the requests afterwards.
type CRLInfo struct {
	// Name is the CA certificate name.
	Name string
	// LastKnown is the number of the last known CRL, or nil if none is known.
	LastKnown *big.Int
}
// scheduleRequiredCRLFetches requests a CRL for every known CA certificate
// whose FetchCRL flag is set.
//
// The data for the requests is collected while holding the client lock. The
// potentially blocking sends happen after the lock is released. Sending stops
// as soon as ctx is canceled.
func (c *Client) scheduleRequiredCRLFetches(ctx context.Context, newCommands chan<- *protocol.Command) {
	c.Lock()
	infos := make([]CRLInfo, 0, len(c.knownCACertificates))

	for _, caInfo := range c.knownCACertificates {
		if caInfo.FetchCRL {
			infos = append(infos, CRLInfo{Name: caInfo.Name, LastKnown: c.lastKnownCRL(caInfo)})
		}
	}
	c.Unlock()

	for _, crlInfo := range infos {
		select {
		case <-ctx.Done():
			// Shutting down; do not offer the remaining fetch commands.
			return
		case newCommands <- command.FetchCRL(crlInfo.Name, crlInfo.LastKnown):
		}
	}
}
// scheduleHealthCheck offers a single health check command on nextCommands.
// It gives up if ctx is canceled before the command is taken.
func (c *Client) scheduleHealthCheck(ctx context.Context, nextCommands chan<- *protocol.Command) {
	// The send value of a select case is evaluated when the select is
	// entered, so building the command up front is equivalent.
	healthCommand := command.Health()

	select {
	case nextCommands <- healthCommand:
	case <-ctx.Done():
	}
}
// requiredCertificateInfo returns the names of all known CA certificates whose
// certificate information still has to be requested from the signer. The
// result is never nil, and its order follows map iteration, so it is
// unspecified.
func (c *Client) requiredCertificateInfo() []string {
	names := []string{}

	c.Lock()
	for _, caInfo := range c.knownCACertificates {
		if !caInfo.FetchCert {
			continue
		}

		names = append(names, caInfo.Name)
	}
	c.Unlock()

	return names
}
// lastKnownCRL returns the number of the CRL stored for caInfo.
//
// It returns nil when no CRL file exists for the CA. Otherwise it uses the
// cached caInfo.LastKnownCRL or, when that is nil, the number read from the
// CRL file. Read or parse failures are logged and also yield nil.
//
// NOTE(review): caInfo.LastKnownCRL is written by setLastKnownCRL while it
// holds the client lock, so callers should hold that lock too.
func (c *Client) lastKnownCRL(caInfo *CACertificateInfo) *big.Int {
	crlFileName := c.buildCRLFileName(caInfo.Name)
	crlLog := c.logger.WithField("crl", crlFileName)

	if _, err := os.Stat(crlFileName); err != nil {
		crlLog.Debug("CRL file does not exist")

		return nil
	}

	if known := caInfo.LastKnownCRL; known != nil {
		return known
	}

	derData, err := os.ReadFile(crlFileName)
	if err != nil {
		crlLog.WithError(err).Error("could not read CRL data")

		return nil
	}

	crl, err := x509.ParseRevocationList(derData)
	if err != nil {
		crlLog.WithError(err).Error("could not parse CRL data")

		return nil
	}

	return crl.Number
}
// updateCRL returns a commandGenerator that stores a CRL received from the
// signer and produces no commands.
//
// The response is either a full DER encoded CRL or a delta that is applied to
// the CRL file on disk. Unchanged responses are only logged. After the CRL is
// stored, its number is recorded as the CA's last known CRL. If storing fails,
// the last known CRL is reset to nil and an error is returned.
func (c *Client) updateCRL(fetchCRLResponse *messages.FetchCRLResponse) commandGenerator {
	return func(_ context.Context, _ chan<- *protocol.Command) error {
		issuer := fetchCRLResponse.IssuerID

		if _, err := c.getCACertificate(issuer); err != nil {
			return err
		}

		if fetchCRLResponse.UnChanged {
			c.logger.WithField("issuer", issuer).Debug("CRL did not change")

			return nil
		}

		crlData := fetchCRLResponse.CRLData

		if fetchCRLResponse.IsDelta {
			patched, err := c.patchCRL(c.buildCRLFileName(issuer), crlData)
			if err != nil {
				return fmt.Errorf("CRL patching failed: %w", err)
			}

			crlData = patched
		}

		list, err := x509.ParseRevocationList(crlData)

		switch {
		case err != nil && fetchCRLResponse.IsDelta:
			return fmt.Errorf("could not parse patched CRL: %w", err)
		case err != nil:
			return fmt.Errorf(
				"CRL for %s from signer could not be parsed: %w",
				issuer,
				err,
			)
		}

		if err = c.writeCRL(issuer, crlData); err != nil {
			c.setLastKnownCRL(issuer, nil)

			return fmt.Errorf("could not store CRL for %s: %w", issuer, err)
		}

		c.setLastKnownCRL(issuer, list.Number)

		return nil
	}
}
// buildCRLFileName returns the path of the CRL file for caName, named
// "<caName>.crl" inside the configured public CRL directory.
func (c *Client) buildCRLFileName(caName string) string {
	fileName := caName + ".crl"

	return path.Join(c.config.PublicCRLDirectory, fileName)
}
// buildCertificateFileName returns the path "<caName>.<certFormat>" inside the
// configured public CRL directory, for example "root.pem".
func (c *Client) buildCertificateFileName(caName string, certFormat string) string {
	fileName := caName + "." + certFormat

	return path.Join(c.config.PublicCRLDirectory, fileName)
}
// writeCertificate writes the DER encoded CA certificate as "<caName>.crt" and
// its PEM encoding as "<caName>.pem" to the public CRL directory. It creates
// the directory if needed.
//
// Only a failure to create the directory is returned. Failures to write the
// individual files are logged and the remaining file is still attempted.
func (c *Client) writeCertificate(caName string, derBytes []byte) error {
	targetDir := c.config.PublicCRLDirectory

	if err := os.MkdirAll(targetDir, worldReadableDirPerm); err != nil {
		return fmt.Errorf("could not create public CA data directory %s: %w", targetDir, err)
	}

	outputs := []struct {
		format  string
		content []byte
		failure string
	}{
		{
			format:  "crt",
			content: derBytes,
			failure: "could not write DER encoded certificate file",
		},
		{
			format:  "pem",
			content: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}),
			failure: "could not write PEM encoded certificate file",
		},
	}

	for _, out := range outputs {
		fileName := c.buildCertificateFileName(caName, out.format)

		if err := os.WriteFile(fileName, out.content, worldReadableFilePerm); err != nil {
			c.logger.WithError(err).Error(out.failure)
		}
	}

	return nil
}
// writeCRL stores crlBytes as the CRL file of caName in the public CRL
// directory, creating the directory if needed.
//
// Both a failure to create the directory and a failure to write the file are
// returned. Callers such as updateCRL rely on this so they do not record a CRL
// number for a file that was never written.
func (c *Client) writeCRL(caName string, crlBytes []byte) error {
	if err := os.MkdirAll(c.config.PublicCRLDirectory, worldReadableDirPerm); err != nil {
		return fmt.Errorf("could not create public CA data directory %s: %w", c.config.PublicCRLDirectory, err)
	}

	crlFileName := c.buildCRLFileName(caName)

	if err := os.WriteFile(crlFileName, crlBytes, worldReadableFilePerm); err != nil {
		return fmt.Errorf("could not write CRL file %s: %w", crlFileName, err)
	}

	return nil
}
// patchCRL reads the CRL stored in crlFileName, applies the go-delta encoded
// diff to it and returns the patched CRL bytes. Each failing step (read,
// delta decoding, delta application) is reported as a wrapped error.
func (c *Client) patchCRL(crlFileName string, diff []byte) ([]byte, error) {
	current, readErr := os.ReadFile(crlFileName)
	if readErr != nil {
		return nil, fmt.Errorf("could not read existing CRL %s: %w", crlFileName, readErr)
	}

	crlDelta, loadErr := delta.Load(diff)
	if loadErr != nil {
		return nil, fmt.Errorf("could not parse CRL delta: %w", loadErr)
	}

	patched, applyErr := crlDelta.Apply(current)
	if applyErr != nil {
		return nil, fmt.Errorf("could not apply CRL delta: %w", applyErr)
	}

	return patched, nil
}
// learnNewCACertificates adds an entry to knownCACertificates for every CA
// name in c.signerInfo that is not known yet. New entries are flagged so that
// both their certificate and their CRL get fetched.
//
// It acquires the client lock, and c.signerInfo must be non-nil.
func (c *Client) learnNewCACertificates() {
	c.Lock()
	defer c.Unlock()

	for _, caName := range c.signerInfo.CACertificates {
		if _, known := c.knownCACertificates[caName]; !known {
			c.knownCACertificates[caName] = &CACertificateInfo{
				Name:      caName,
				FetchCert: true,
				FetchCRL:  true,
			}
		}
	}
}
// forgetRemovedCACertificates removes every known CA certificate that
// c.signerInfo no longer lists and logs a warning for each one removed.
//
// It acquires the client lock, and c.signerInfo must be non-nil.
func (c *Client) forgetRemovedCACertificates() {
	c.Lock()
	defer c.Unlock()

	// Deleting entries while ranging over a map is permitted in Go.
	for knownCA := range c.knownCACertificates {
		if !c.signerInfo.containsCA(knownCA) {
			c.logger.WithField("certificate", knownCA).Warn("signer did not send status for certificate")
			delete(c.knownCACertificates, knownCA)
		}
	}
}
// getCACertificate looks up the known CA certificate called name while holding
// the client lock. It returns an error if no such CA is known.
func (c *Client) getCACertificate(name string) (*CACertificateInfo, error) {
	c.Lock()
	caInfo, found := c.knownCACertificates[name]
	c.Unlock()

	if !found {
		return nil, fmt.Errorf("no known CA certificate for %s", name)
	}

	return caInfo, nil
}
// setLastKnownCRL records number as the last known CRL of caName while holding
// the client lock. A nil number clears the value. Unknown CA names are logged
// and otherwise ignored.
func (c *Client) setLastKnownCRL(caName string, number *big.Int) {
	c.Lock()
	defer c.Unlock()

	if caInfo, found := c.knownCACertificates[caName]; found {
		caInfo.LastKnownCRL = number

		return
	}

	c.logger.WithField("certificate", caName).Warn(
		"tried to set last known CRL for unknown CA certificate",
	)
}
// New creates a Client for cfg that uses a COBS framer for the serial protocol.
// It opens the serial device described by cfg.Serial. The port is released by
// Client.Close.
//
// Errors from creating the framer or opening the port are returned.
func New(
	cfg *config.ClientConfig,
	logger *logrus.Logger,
) (*Client, error) {
	framer, err := protocol.NewCOBSFramer(logger)
	if err != nil {
		return nil, fmt.Errorf("could not create COBS framer: %w", err)
	}

	serialConfig := &serial.Config{
		Name:        cfg.Serial.Device,
		Baud:        cfg.Serial.Baud,
		ReadTimeout: cfg.Serial.Timeout,
	}

	client := &Client{
		logger:              logger,
		framer:              framer,
		config:              cfg,
		knownCACertificates: make(map[string]*CACertificateInfo),
	}

	if err = client.setupConnection(serialConfig); err != nil {
		return nil, err
	}

	return client, nil
}