/* Copyright CAcert Inc. SPDX-License-Identifier: Apache-2.0 Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package client import ( "context" "crypto/x509" "encoding/pem" "fmt" "math/big" "os" "path" "sync" "time" "github.com/balacode/go-delta" "github.com/sirupsen/logrus" "github.com/tarm/serial" "git.cacert.org/cacert-gosigner/pkg/messages" "git.cacert.org/cacert-gosigner/pkg/protocol" "git.cacert.org/cacert-gosignerclient/internal/command" "git.cacert.org/cacert-gosignerclient/internal/config" ) const CallBackBufferSize = 50 const ( worldReadableDirPerm = 0o755 worldReadableFilePerm = 0o644 ) type Profile struct { Name string UseFor string } type CACertificateInfo struct { Name string FetchCert bool FetchCRL bool LastKnownCRL *big.Int Certificate *x509.Certificate Profiles map[string]*Profile } type SignerInfo struct { SignerHealth bool SignerVersion string CACertificates []string } func (i *SignerInfo) containsCA(caName string) bool { for _, name := range i.CACertificates { if name == caName { return true } } return false } type Client struct { port *serial.Port logger *logrus.Logger framer protocol.Framer config *config.ClientConfig signerInfo *SignerInfo knownCACertificates map[string]*CACertificateInfo commandSources []CommandSource responseSinks map[messages.ResponseCode]ResponseSink sync.Mutex } func (c *Client) Run( ctx context.Context, callback <-chan any, handler protocol.ClientHandler, commands chan *protocol.Command, ) error { const componentCount = 4 protocolErrors, framerErrors, sourceErrors := make(chan error), make(chan error), make(chan error) subCtx, cancel := context.WithCancel(ctx) wg := sync.WaitGroup{} wg.Add(componentCount) wg.Add(len(c.commandSources)) fromSigner := make(chan []byte) toSigner := make(chan []byte) defer func() { cancel() c.logger.Info("context canceled, waiting for shutdown of components") wg.Wait() c.logger.Info("shutdown complete") }() c.RunSources(subCtx, &wg, sourceErrors) go func(f protocol.Framer) { defer wg.Done() err := f.ReadFrames(subCtx, c.port, fromSigner) c.logger.Info("frame reading stopped") select { case framerErrors <- err: case <-subCtx.Done(): } }(c.framer) go func(f protocol.Framer) { defer wg.Done() err := f.WriteFrames(subCtx, c.port, toSigner) c.logger.Info("frame writing stopped") select { case framerErrors <- err: case <-subCtx.Done(): } }(c.framer) go func() { defer wg.Done() clientProtocol := protocol.NewClient(handler, commands, fromSigner, toSigner, c.logger) err := clientProtocol.Handle(subCtx) c.logger.Info("client protocol stopped") select { case protocolErrors <- err: case <-subCtx.Done(): } }() go func() { defer wg.Done() c.commandLoop(subCtx, commands, callback) c.logger.Info("client command loop stopped") }() for { select { case <-ctx.Done(): return nil case err := <-framerErrors: if err != nil { return fmt.Errorf("error from framer: %w", err) } return nil case err := <-protocolErrors: if err != nil { return fmt.Errorf("error from protocol: %w", err) } return nil case err := <-sourceErrors: if err != nil { return fmt.Errorf("error from command source: %w", err) } return nil } } } func (c *Client) setupConnection(serialConfig *serial.Config) error { s, err := serial.OpenPort(serialConfig) if err != nil { return fmt.Errorf("could not open serial port: %w", err) } c.port = s err = c.port.Flush() if err != nil { c.logger.WithError(err).Warn("could not flush buffers of port: %w", err) } return nil } func (c *Client) Close() error { if c.port != nil { err := c.port.Close() if err != nil { return fmt.Errorf("could not close serial port: %w", err) } } return nil } type commandGenerator func(context.Context, chan<- *protocol.Command) error func (c *Client) commandLoop(ctx context.Context, commands chan *protocol.Command, callback <-chan any) { healthTimer := time.NewTimer(c.config.HealthStart) fetchCRLTimer := time.NewTimer(c.config.FetchCRLStart) defer func() { close(commands) c.logger.Info("command loop stopped") }() for { select { case <-ctx.Done(): return case callbackData := <-callback: go func() { err := c.handleCallback(ctx, commands, callbackData) if err != nil { c.logger.WithError(err).Error("callback handling failed") } }() case <-fetchCRLTimer.C: go c.scheduleRequiredCRLFetches(ctx, commands) fetchCRLTimer.Reset(c.config.FetchCRLInterval) case <-healthTimer.C: go c.scheduleHealthCheck(ctx, commands) healthTimer.Reset(c.config.HealthInterval) } } } type ErrNoResponseSink struct { msg string } func (e ErrNoResponseSink) Error() string { return fmt.Sprintf("no response sink for %s response found", e.msg) } func (c *Client) handleCallback( ctx context.Context, newCommands chan<- *protocol.Command, data any, ) error { var ( handler commandGenerator err error ) switch d := data.(type) { case SignerInfo: handler = c.updateSignerInfo(d) case *protocol.Response: handler, err = c.handleResponse(d) if err != nil { return err } default: return fmt.Errorf("unknown callback data of type %T", d) } return handler(ctx, newCommands) } func (c *Client) updateSignerInfo( signerInfo SignerInfo, ) commandGenerator { return func(ctx context.Context, newCommands chan<- *protocol.Command) error { c.logger.Debug("update signer info") c.Lock() c.signerInfo = &signerInfo c.Unlock() c.learnNewCACertificates() c.forgetRemovedCACertificates() for _, caName := range c.requiredCertificateInfo() { select { case <-ctx.Done(): case newCommands <- command.CAInfo(caName): } } return nil } } func (c *Client) updateCAInformation( infoResponse *messages.CAInfoResponse, ) commandGenerator { return func(ctx context.Context, newCommands chan<- *protocol.Command) error { var ( caInfo *CACertificateInfo cert *x509.Certificate err error ) if caInfo, err = c.getCACertificate(infoResponse.Name); err != nil { return err } if cert, err = x509.ParseCertificate(infoResponse.Certificate); err != nil { return fmt.Errorf("could not parse CA certificate for %s: %w", infoResponse.Name, err) } if !cert.IsCA { return fmt.Errorf("certificate for %s is not a CA certificate", infoResponse.Name) } if err = c.writeCertificate(caInfo.Name, infoResponse.Certificate); err != nil { c.logger.WithError(err).WithField("certificate", infoResponse.Name).Warn( "could not write CA certificate files", ) } caInfo.Certificate = cert caInfo.FetchCert = false caInfo.Profiles = make(map[string]*Profile) for _, p := range infoResponse.Profiles { caInfo.Profiles[p.Name] = &Profile{ Name: p.Name, UseFor: p.UseFor.String(), } } if len(cert.CRLDistributionPoints) == 0 { caInfo.FetchCRL = false return nil } select { case <-ctx.Done(): case newCommands <- command.FetchCRL(caInfo.Name, c.lastKnownCRL(caInfo)): } return nil } } type CRLInfo struct { Name string LastKnown *big.Int } func (c *Client) scheduleRequiredCRLFetches(ctx context.Context, newCommands chan<- *protocol.Command) { infos := make([]CRLInfo, 0) c.Lock() for _, caInfo := range c.knownCACertificates { if caInfo.FetchCRL { infos = append(infos, CRLInfo{Name: caInfo.Name, LastKnown: c.lastKnownCRL(caInfo)}) } } c.Unlock() for _, crlInfo := range infos { select { case <-ctx.Done(): case newCommands <- command.FetchCRL(crlInfo.Name, crlInfo.LastKnown): } } } func (c *Client) scheduleHealthCheck(ctx context.Context, nextCommands chan<- *protocol.Command) { select { case <-ctx.Done(): case nextCommands <- command.Health(): } } func (c *Client) requiredCertificateInfo() []string { c.Lock() defer c.Unlock() infos := make([]string, 0) for _, caInfo := range c.knownCACertificates { if caInfo.FetchCert { infos = append(infos, caInfo.Name) } } return infos } func (c *Client) lastKnownCRL(caInfo *CACertificateInfo) *big.Int { caName := caInfo.Name crlFileName := c.buildCRLFileName(caName) _, err := os.Stat(crlFileName) if err != nil { c.logger.WithField("crl", crlFileName).Debug("CRL file does not exist") return nil } lastKnown := caInfo.LastKnownCRL if lastKnown == nil { derData, err := os.ReadFile(crlFileName) if err != nil { c.logger.WithError(err).WithField("crl", crlFileName).Error("could not read CRL data") return nil } crl, err := x509.ParseRevocationList(derData) if err != nil { c.logger.WithError(err).WithField("crl", crlFileName).Error("could not parse CRL data") return nil } lastKnown = crl.Number } return lastKnown } func (c *Client) updateCRL(fetchCRLResponse *messages.FetchCRLResponse) commandGenerator { return func(_ context.Context, _ chan<- *protocol.Command) error { var ( crlNumber *big.Int der []byte err error list *x509.RevocationList ) if _, err = c.getCACertificate(fetchCRLResponse.IssuerID); err != nil { return err } if fetchCRLResponse.UnChanged { c.logger.WithField("issuer", fetchCRLResponse.IssuerID).Debug("CRL did not change") return nil } if !fetchCRLResponse.IsDelta { der = fetchCRLResponse.CRLData list, err = x509.ParseRevocationList(der) if err != nil { return fmt.Errorf( "CRL for %s from signer could not be parsed: %w", fetchCRLResponse.IssuerID, err, ) } crlNumber = list.Number } else { crlFileName := c.buildCRLFileName(fetchCRLResponse.IssuerID) if der, err = c.patchCRL(crlFileName, fetchCRLResponse.CRLData); err != nil { return fmt.Errorf("CRL patching failed: %w", err) } if list, err = x509.ParseRevocationList(der); err != nil { return fmt.Errorf("could not parse patched CRL: %w", err) } crlNumber = list.Number } if err = c.writeCRL(fetchCRLResponse.IssuerID, der); err != nil { c.setLastKnownCRL(fetchCRLResponse.IssuerID, nil) return fmt.Errorf("could not store CRL for %s: %w", fetchCRLResponse.IssuerID, err) } c.setLastKnownCRL(fetchCRLResponse.IssuerID, crlNumber) return nil } } func (c *Client) buildCRLFileName(caName string) string { return path.Join(c.config.PublicCRLDirectory, fmt.Sprintf("%s.crl", caName)) } func (c *Client) buildCertificateFileName(caName string, certFormat string) string { return path.Join(c.config.PublicCRLDirectory, fmt.Sprintf("%s.%s", caName, certFormat)) } func (c *Client) writeCertificate(caName string, derBytes []byte) error { if err := os.MkdirAll(c.config.PublicCRLDirectory, worldReadableDirPerm); err != nil { return fmt.Errorf("could not create public CA data directory %s: %w", c.config.PublicCRLDirectory, err) } if err := os.WriteFile( c.buildCertificateFileName(caName, "crt"), derBytes, worldReadableFilePerm, ); err != nil { c.logger.WithError(err).Error("could not write DER encoded certificate file") } pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) if err := os.WriteFile( c.buildCertificateFileName(caName, "pem"), pemBytes, worldReadableFilePerm, ); err != nil { c.logger.WithError(err).Error("could not write PEM encoded certificate file") } return nil } func (c *Client) writeCRL(caName string, crlBytes []byte) error { if err := os.MkdirAll(c.config.PublicCRLDirectory, worldReadableDirPerm); err != nil { return fmt.Errorf("could not create public CA data directory %s: %w", c.config.PublicCRLDirectory, err) } if err := os.WriteFile(c.buildCRLFileName(caName), crlBytes, worldReadableFilePerm); err != nil { c.logger.WithError(err).Error("could not write CRL file") } return nil } func (c *Client) patchCRL(crlFileName string, diff []byte) ([]byte, error) { original, err := os.ReadFile(crlFileName) if err != nil { return nil, fmt.Errorf("could not read existing CRL %s: %w", crlFileName, err) } patch, err := delta.Load(diff) if err != nil { return nil, fmt.Errorf("could not parse CRL delta: %w", err) } der, err := patch.Apply(original) if err != nil { return nil, fmt.Errorf("could not apply CRL delta: %w", err) } return der, nil } func (c *Client) learnNewCACertificates() { c.Lock() defer c.Unlock() for _, caName := range c.signerInfo.CACertificates { if _, ok := c.knownCACertificates[caName]; ok { continue } c.knownCACertificates[caName] = &CACertificateInfo{ Name: caName, FetchCert: true, FetchCRL: true, } } } func (c *Client) forgetRemovedCACertificates() { c.Lock() defer c.Unlock() for knownCA := range c.knownCACertificates { if c.signerInfo.containsCA(knownCA) { continue } c.logger.WithField("certificate", knownCA).Warn("signer did not send status for certificate") delete(c.knownCACertificates, knownCA) } } func (c *Client) getCACertificate(name string) (*CACertificateInfo, error) { c.Lock() defer c.Unlock() caInfo, ok := c.knownCACertificates[name] if !ok { return nil, fmt.Errorf("no known CA certificate for %s", name) } return caInfo, nil } func (c *Client) setLastKnownCRL(caName string, number *big.Int) { c.Lock() defer c.Unlock() caInfo, ok := c.knownCACertificates[caName] if !ok { c.logger.WithField("certificate", caName).Warn( "tried to set last known CRL for unknown CA certificate", ) return } caInfo.LastKnownCRL = number } type CommandSource interface { Run(context.Context) error } type ResponseSink interface { SupportedResponses() []messages.ResponseCode HandleResponse(context.Context, *messages.ResponseAnnounce, any) error NotifyError(ctx context.Context, requestID, message string) error } func (c *Client) RegisterCommandSource(source CommandSource) { c.commandSources = append(c.commandSources, source) } func (c *Client) RegisterResponseSink(sink ResponseSink) { for _, code := range sink.SupportedResponses() { c.responseSinks[code] = sink } } func (c *Client) handleResponse(r *protocol.Response) (commandGenerator, error) { var handler commandGenerator switch payload := r.Response.(type) { case *messages.CAInfoResponse: handler = c.updateCAInformation(payload) case *messages.FetchCRLResponse: handler = c.updateCRL(payload) case *messages.ErrorResponse: handler = func(ctx context.Context, _ chan<- *protocol.Command) error { for _, sink := range c.responseSinks { if err := sink.NotifyError(ctx, r.Announce.ID, payload.Message); err != nil { return fmt.Errorf("error from response sink: %w", err) } } return nil } case *messages.SignCertificateResponse: sink, ok := c.responseSinks[messages.RespSignCertificate] if !ok { return nil, ErrNoResponseSink{"sign certificate"} } handler = func(ctx context.Context, _ chan<- *protocol.Command) error { if err := sink.HandleResponse(ctx, r.Announce, payload); err != nil { return fmt.Errorf("error from response sink: %w", err) } return nil } case *messages.SignOpenPGPResponse: sink, ok := c.responseSinks[messages.RespSignOpenPGP] if !ok { return nil, ErrNoResponseSink{"sign openpgp"} } handler = func(ctx context.Context, _ chan<- *protocol.Command) error { if err := sink.HandleResponse(ctx, r.Announce, payload); err != nil { return fmt.Errorf("error from response sink: %w", err) } return nil } default: return nil, fmt.Errorf("unhandled response %s", payload) } return handler, nil } func (c *Client) RunSources(ctx context.Context, wg *sync.WaitGroup, errorChan chan error) { for _, source := range c.commandSources { go func(s CommandSource) { defer wg.Done() err := s.Run(ctx) if err != nil { c.logger.WithError(err).Error("command source failed") errorChan <- err } c.logger.Info("command source stopped") }(source) } } func New( cfg *config.ClientConfig, logger *logrus.Logger, ) (*Client, error) { cobsFramer, err := protocol.NewCOBSFramer(logger) if err != nil { return nil, fmt.Errorf("could not create COBS framer: %w", err) } client := &Client{ logger: logger, framer: cobsFramer, config: cfg, knownCACertificates: make(map[string]*CACertificateInfo), responseSinks: make(map[messages.ResponseCode]ResponseSink), commandSources: make([]CommandSource, 0), } err = client.setupConnection(&serial.Config{ Name: cfg.Serial.Device, Baud: cfg.Serial.Baud, ReadTimeout: cfg.Serial.Timeout, }) if err != nil { return nil, err } return client, nil }