Refactor HSM setup

- create new type hsm.Access to encapsulate HSM operations
- make setup options operate on hsm.Access instances
- adapt tests and cmd/signer to work with hsm.Access
main
Jan Dittberner 2 years ago
parent 7acec714e3
commit 0d69a9013d

@ -40,11 +40,13 @@ func main() {
var ( var (
showVersion, setupMode, verbose bool showVersion, setupMode, verbose bool
signerConfigFile string signerConfigFile string
infoLog, errorLog *log.Logger
) )
log.SetFlags(log.Ldate | log.Lmicroseconds | log.LUTC) infoLog = log.New(os.Stdout, "INFO ", log.Ldate|log.Lmicroseconds|log.LUTC)
errorLog = log.New(os.Stderr, "ERROR ", log.Ldate|log.Lmicroseconds|log.LUTC)
log.Printf("cacert-gosigner %s (%s) - built %s\n", version, commit, date) infoLog.Printf("cacert-gosigner %s (%s) - built %s\n", version, commit, date)
flag.StringVar(&signerConfigFile, "config", defaultSignerConfigFile, "signer configuration file") flag.StringVar(&signerConfigFile, "config", defaultSignerConfigFile, "signer configuration file")
flag.BoolVar(&showVersion, "version", false, "show version") flag.BoolVar(&showVersion, "version", false, "show version")
@ -59,20 +61,20 @@ func main() {
configFile, err := os.Open(signerConfigFile) configFile, err := os.Open(signerConfigFile)
if err != nil { if err != nil {
log.Fatalf("could not open signer configuration file %s: %v", signerConfigFile, err) errorLog.Fatalf("could not open signer configuration file %s: %v", signerConfigFile, err)
} }
opts := make([]hsm.ConfigOption, 0) opts := make([]hsm.ConfigOption, 0)
caConfig, err := config.LoadConfiguration(configFile) caConfig, err := config.LoadConfiguration(configFile)
if err != nil { if err != nil {
log.Fatalf("could not load CA hierarchy: %v", err) errorLog.Fatalf("could not load CA hierarchy: %v", err)
} }
opts = append(opts, hsm.CaConfigOption(caConfig)) opts = append(opts, hsm.CaConfigOption(caConfig))
if setupMode { if setupMode {
log.Print("running in setup mode") infoLog.Print("running in setup mode")
opts = append(opts, hsm.SetupModeOption()) opts = append(opts, hsm.SetupModeOption())
} }
@ -81,16 +83,19 @@ func main() {
opts = append(opts, hsm.VerboseLoggingOption()) opts = append(opts, hsm.VerboseLoggingOption())
} }
ctx := hsm.SetupContext(opts...) acc, err := hsm.NewAccess(infoLog, opts...)
if err != nil {
errorLog.Fatalf("could not setup HSM access: %v", err)
}
err = hsm.EnsureCAKeysAndCertificates(ctx) err = acc.EnsureCAKeysAndCertificates()
if err != nil { if err != nil {
log.Fatalf("could not ensure CA keys and certificates exist: %v", err) errorLog.Fatalf("could not ensure CA keys and certificates exist: %v", err)
} }
if setupMode { if setupMode {
return return
} }
log.Print("setup complete, starting signer operation") infoLog.Print("setup complete, starting signer operation")
} }

@ -239,7 +239,16 @@ func (c *SignerConfig) GetParentCA(label string) (*CaCertificateEntry, error) {
return nil, fmt.Errorf("CA %s is a root CA and has no parent", label) return nil, fmt.Errorf("CA %s is a root CA and has no parent", label)
} }
return c.caMap[entry.Parent], nil if entry.Parent == "" {
return nil, fmt.Errorf("parent for %s is empty", label)
}
parent, ok := c.caMap[entry.Parent]
if !ok {
return nil, fmt.Errorf("parent %s for %s not found in signer config", entry.Parent, label)
}
return parent, nil
} }
// RootCAs returns the labels of all configured root CAs // RootCAs returns the labels of all configured root CAs

@ -18,8 +18,6 @@ limitations under the License.
package hsm package hsm
import ( import (
"context"
"errors"
"fmt" "fmt"
"github.com/ThalesIgnite/crypto11" "github.com/ThalesIgnite/crypto11"
@ -27,111 +25,72 @@ import (
"git.cacert.org/cacert-gosigner/pkg/config" "git.cacert.org/cacert-gosigner/pkg/config"
) )
type ctxKey int type ConfigOption func(a *Access)
const ( func CADirectoryOption(path string) func(a *Access) {
ctxCADirectory ctxKey = iota return func(a *Access) {
ctxP11Contexts a.caDirectory = path
ctxSetupMode
ctxSignerConfig
ctxVerboseLogging
)
type ConfigOption func(ctx context.Context) context.Context
func CADirectoryOption(path string) func(ctx context.Context) context.Context {
return func(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxCADirectory, path)
} }
} }
func CaConfigOption(signerConfig *config.SignerConfig) func(context.Context) context.Context { func CaConfigOption(signerConfig *config.SignerConfig) func(a *Access) {
return func(ctx context.Context) context.Context { return func(a *Access) {
return context.WithValue(ctx, ctxSignerConfig, signerConfig) a.signerConfig = signerConfig
} }
} }
func SetupModeOption() func(context.Context) context.Context { func SetupModeOption() func(a *Access) {
return func(ctx context.Context) context.Context { return func(a *Access) {
return context.WithValue(ctx, ctxSetupMode, true) a.setupMode = true
} }
} }
func VerboseLoggingOption() func(ctx context.Context) context.Context { func VerboseLoggingOption() func(a *Access) {
return func(ctx context.Context) context.Context { return func(a *Access) {
return context.WithValue(ctx, ctxVerboseLogging, true) a.verbose = true
} }
} }
// SetupContext sets global context for HSM operations. // setupContext sets global context for HSM operations.
func SetupContext(options ...ConfigOption) context.Context { func (a *Access) setupContext(options ...ConfigOption) {
ctx := context.Background() a.p11Contexts = make(map[string]*crypto11.Context)
ctx = context.WithValue(ctx, ctxP11Contexts, make(map[string]*crypto11.Context))
for _, opt := range options { for _, opt := range options {
ctx = opt(ctx) opt(a)
} }
return ctx
} }
func GetSignerConfig(ctx context.Context) *config.SignerConfig { func (a *Access) GetSignerConfig() *config.SignerConfig {
signerConfig, ok := ctx.Value(ctxSignerConfig).(*config.SignerConfig) return a.signerConfig
if !ok {
return nil
}
return signerConfig
} }
func IsSetupMode(ctx context.Context) bool { func (a *Access) IsSetupMode() bool {
setupMode, ok := ctx.Value(ctxSetupMode).(bool) return a.setupMode
if !ok {
return false
}
return setupMode
} }
func IsVerbose(ctx context.Context) bool { func (a *Access) IsVerbose() bool {
verbose, ok := ctx.Value(ctxVerboseLogging).(bool) return a.verbose
if !ok {
return false
}
return verbose
} }
func GetP11Context(ctx context.Context, entry *config.CaCertificateEntry) (*crypto11.Context, error) { func (a *Access) GetP11Context(entry *config.CaCertificateEntry) (*crypto11.Context, error) {
contexts, ok := ctx.Value(ctxP11Contexts).(map[string]*crypto11.Context) if p11Context, ok := a.p11Contexts[entry.Storage]; ok {
if !ok {
return nil, errors.New("type assertion failed, use hsm.SetupContext first")
}
if p11Context, ok := contexts[entry.Storage]; ok {
return p11Context, nil return p11Context, nil
} }
p11Context, err := prepareCrypto11Context(ctx, entry.Storage) p11Context, err := a.prepareCrypto11Context(entry.Storage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
contexts[entry.Storage] = p11Context a.p11Contexts[entry.Storage] = p11Context
return p11Context, nil return p11Context, nil
} }
func CloseP11Contexts(ctx context.Context) error { func (a *Access) CloseP11Contexts() error {
contexts, ok := ctx.Value(ctxP11Contexts).(map[string]*crypto11.Context)
if !ok {
return errors.New("type assertion failed, use hsm.SetupContext first")
}
seen := make(map[*crypto11.Context]struct{}, 0) seen := make(map[*crypto11.Context]struct{}, 0)
for _, p11Context := range contexts { for _, p11Context := range a.p11Contexts {
if _, ok := seen[p11Context]; ok { if _, ok := seen[p11Context]; ok {
continue continue
} }

@ -18,8 +18,8 @@ limitations under the License.
package hsm_test package hsm_test
import ( import (
"context"
"fmt" "fmt"
"log"
"os" "os"
"os/exec" "os/exec"
"path" "path"
@ -36,67 +36,72 @@ import (
func TestCaConfigOption(t *testing.T) { func TestCaConfigOption(t *testing.T) {
testSignerConfig := config.SignerConfig{} testSignerConfig := config.SignerConfig{}
ctx := hsm.SetupContext(hsm.CaConfigOption(&testSignerConfig)) access, err := hsm.NewAccess(log.Default(), hsm.CaConfigOption(&testSignerConfig))
assert.NoError(t, err)
assert.Equal(t, &testSignerConfig, hsm.GetSignerConfig(ctx)) assert.Equal(t, &testSignerConfig, access.GetSignerConfig())
} }
func TestGetSignerConfig_empty(t *testing.T) { func TestGetSignerConfig_empty(t *testing.T) {
ctx := hsm.SetupContext() access, err := hsm.NewAccess(log.Default())
assert.NoError(t, err)
assert.Nil(t, hsm.GetSignerConfig(ctx)) assert.Nil(t, access.GetSignerConfig())
} }
func TestSetupModeOption(t *testing.T) { func TestSetupModeOption(t *testing.T) {
ctx := hsm.SetupContext(hsm.SetupModeOption()) access, err := hsm.NewAccess(log.Default(), hsm.SetupModeOption())
assert.NoError(t, err)
assert.True(t, hsm.IsSetupMode(ctx)) assert.True(t, access.IsSetupMode())
} }
func TestIsSetupMode_not_set(t *testing.T) { func TestIsSetupMode_not_set(t *testing.T) {
ctx := hsm.SetupContext() access, err := hsm.NewAccess(log.Default())
assert.NoError(t, err)
assert.False(t, hsm.IsSetupMode(ctx)) assert.False(t, access.IsSetupMode())
} }
func TestVerboseLoggingOption(t *testing.T) { func TestVerboseLoggingOption(t *testing.T) {
ctx := hsm.SetupContext(hsm.VerboseLoggingOption()) access, err := hsm.NewAccess(log.Default(), hsm.VerboseLoggingOption())
assert.NoError(t, err)
assert.True(t, hsm.IsVerbose(ctx)) assert.True(t, access.IsVerbose())
} }
func TestIsVerbose_not_set(t *testing.T) { func TestIsVerbose_not_set(t *testing.T) {
ctx := hsm.SetupContext() access, err := hsm.NewAccess(log.Default())
assert.NoError(t, err)
assert.False(t, hsm.IsVerbose(ctx)) assert.False(t, access.IsVerbose())
} }
func TestSetupContext(t *testing.T) { func TestSetupContext(t *testing.T) {
testConfig := setupSignerConfig(t) testConfig := setupSignerConfig(t)
ctx := hsm.SetupContext(hsm.SetupModeOption(), hsm.VerboseLoggingOption(), hsm.CaConfigOption(testConfig)) access, err := hsm.NewAccess(
log.Default(),
assert.True(t, hsm.IsSetupMode(ctx)) hsm.SetupModeOption(),
assert.True(t, hsm.IsVerbose(ctx)) hsm.VerboseLoggingOption(),
assert.Equal(t, hsm.GetSignerConfig(ctx), testConfig) hsm.CaConfigOption(testConfig),
} )
assert.NoError(t, err)
func TestGetP11Context_missing_SetupContext(t *testing.T) {
p11Context, err := hsm.GetP11Context(context.Background(), &config.CaCertificateEntry{Storage: "default"})
assert.Error(t, err) assert.True(t, access.IsSetupMode())
assert.ErrorContains(t, err, "type assertion failed, use hsm.SetupContext first") assert.True(t, access.IsVerbose())
assert.Nil(t, p11Context) assert.Equal(t, access.GetSignerConfig(), testConfig)
} }
func TestGetP11Context_unknown_storage(t *testing.T) { func TestGetP11Context_unknown_storage(t *testing.T) {
testConfig := setupSignerConfig(t) testConfig := setupSignerConfig(t)
ctx := hsm.SetupContext(hsm.SetupModeOption(), hsm.CaConfigOption(testConfig)) access, err := hsm.NewAccess(log.Default(), hsm.SetupModeOption(), hsm.CaConfigOption(testConfig))
assert.NoError(t, err)
definition := &config.CaCertificateEntry{Storage: "undefined"} definition := &config.CaCertificateEntry{Storage: "undefined"}
p11Context, err := hsm.GetP11Context(ctx, definition) p11Context, err := access.GetP11Context(definition)
assert.Error(t, err) assert.Error(t, err)
assert.ErrorContains(t, err, "key storage undefined not available") assert.ErrorContains(t, err, "key storage undefined not available")
@ -109,13 +114,14 @@ func TestGetP11Context_wrong_pin(t *testing.T) {
t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "wrongpin") t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "wrongpin")
ctx := hsm.SetupContext(hsm.CaConfigOption(testConfig)) access, err := hsm.NewAccess(log.Default(), hsm.CaConfigOption(testConfig))
assert.NoError(t, err)
definition, err := testConfig.GetCADefinition("root") definition, err := testConfig.GetCADefinition("root")
require.NoError(t, err) require.NoError(t, err)
_, err = hsm.GetP11Context(ctx, definition) _, err = access.GetP11Context(definition)
assert.ErrorContains(t, err, "could not configure PKCS#11 library") assert.ErrorContains(t, err, "could not configure PKCS#11 library")
} }
@ -124,13 +130,14 @@ func TestGetP11Context_no_pin(t *testing.T) {
testConfig := setupSignerConfig(t) testConfig := setupSignerConfig(t)
setupSoftHsm(t) setupSoftHsm(t)
ctx := hsm.SetupContext(hsm.CaConfigOption(testConfig)) access, err := hsm.NewAccess(log.Default(), hsm.CaConfigOption(testConfig))
assert.NoError(t, err)
definition, err := testConfig.GetCADefinition("root") definition, err := testConfig.GetCADefinition("root")
require.NoError(t, err) require.NoError(t, err)
_, err = hsm.GetP11Context(ctx, definition) _, err = access.GetP11Context(definition)
assert.ErrorContains(t, err, "stdin is not a terminal") assert.ErrorContains(t, err, "stdin is not a terminal")
} }
@ -141,25 +148,26 @@ func TestGetP11Context(t *testing.T) {
t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "123456") t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "123456")
ctx := hsm.SetupContext(hsm.CaConfigOption(testConfig)) access, err := hsm.NewAccess(log.Default(), hsm.CaConfigOption(testConfig))
assert.NoError(t, err)
definition, err := testConfig.GetCADefinition("root") definition, err := testConfig.GetCADefinition("root")
require.NoError(t, err) require.NoError(t, err)
p11Context1, err := hsm.GetP11Context(ctx, definition) p11Context1, err := access.GetP11Context(definition)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, p11Context1) assert.NotNil(t, p11Context1)
p11Context2, err := hsm.GetP11Context(ctx, definition) p11Context2, err := access.GetP11Context(definition)
assert.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
err := hsm.CloseP11Contexts(ctx) err := access.CloseP11Contexts()
assert.NoError(t, err) assert.NoError(t, err)
}) })
assert.NoError(t, err)
assert.NotNil(t, p11Context1) assert.NotNil(t, p11Context1)
assert.Equal(t, p11Context1, p11Context2) assert.Equal(t, p11Context1, p11Context2)
} }
@ -212,10 +220,10 @@ func setupSignerConfig(t *testing.T) *config.SignerConfig {
func setupSoftHsm(t *testing.T) { func setupSoftHsm(t *testing.T) {
t.Helper() t.Helper()
tempdir := t.TempDir() tempDir := t.TempDir()
tokenDir := path.Join(tempdir, "tokens") tokenDir := path.Join(tempDir, "tokens")
softhsmConfig := path.Join(tempdir, "softhsm2.conf") softhsmConfig := path.Join(tempDir, "softhsm2.conf")
err := os.Mkdir(tokenDir, 0o700) err := os.Mkdir(tokenDir, 0o700)
@ -241,10 +249,3 @@ func setupSoftHsm(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestCloseP11Contexts_without_setup(t *testing.T) {
ctx := context.Background()
err := hsm.CloseP11Contexts(ctx)
assert.ErrorContains(t, err, "type assertion failed, use hsm.SetupContext first")
}

@ -19,7 +19,6 @@ limitations under the License.
package hsm package hsm
import ( import (
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
@ -54,25 +53,34 @@ type caFile struct {
label string label string
} }
func (c *caFile) buildCertificatePath(ctx context.Context) (string, error) { type Access struct {
fileName := c.sc.CertificateFileName(c.label) infoLog *log.Logger
caDirectory string
signerConfig *config.SignerConfig
p11Contexts map[string]*crypto11.Context
setupMode bool
verbose bool
}
caDir := ctx.Value(ctxCADirectory) func NewAccess(infoLog *log.Logger, options ...ConfigOption) (*Access, error) {
access := &Access{infoLog: infoLog}
access.setupContext(options...)
if caDir != nil { return access, nil
caPath, ok := caDir.(string) }
if !ok {
return "", errors.New("context object CA directory is not a string")
}
return path.Join(caPath, fileName), nil func (c *caFile) buildCertificatePath(caDirectory string) (string, error) {
fileName := c.sc.CertificateFileName(c.label)
if caDirectory == "" {
return "", errors.New("CA directory is not set")
} }
return fileName, nil return path.Join(caDirectory, fileName), nil
} }
func (c *caFile) loadCertificate(ctx context.Context) (*x509.Certificate, error) { func (c *caFile) loadCertificate(caDirectory string) (*x509.Certificate, error) {
certFile, err := c.buildCertificatePath(ctx) certFile, err := c.buildCertificatePath(caDirectory)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,8 +120,8 @@ func (c *caFile) loadCertificate(ctx context.Context) (*x509.Certificate, error)
return certificate, nil return certificate, nil
} }
func (c *caFile) storeCertificate(ctx context.Context, certificate []byte) error { func (c *caFile) storeCertificate(caDirectory string, certificate []byte) error {
certFile, err := c.buildCertificatePath(ctx) certFile, err := c.buildCertificatePath(caDirectory)
if err != nil { if err != nil {
return err return err
} }
@ -126,13 +134,13 @@ func (c *caFile) storeCertificate(ctx context.Context, certificate []byte) error
return nil return nil
} }
func GetRootCACertificate(ctx context.Context, label string) (*x509.Certificate, error) { func (a *Access) GetRootCACertificate(label string) (*x509.Certificate, error) {
var ( var (
certificate *x509.Certificate certificate *x509.Certificate
keyPair crypto.Signer keyPair crypto.Signer
) )
sc := GetSignerConfig(ctx) sc := a.GetSignerConfig()
caCert, err := sc.GetCADefinition(label) caCert, err := sc.GetCADefinition(label)
if err != nil { if err != nil {
@ -145,37 +153,39 @@ func GetRootCACertificate(ctx context.Context, label string) (*x509.Certificate,
caFile := &caFile{sc: sc, label: label} caFile := &caFile{sc: sc, label: label}
certificate, err = caFile.loadCertificate(ctx) certificate, err = caFile.loadCertificate(a.caDirectory)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if certificate != nil && !IsSetupMode(ctx) { if certificate != nil && !a.IsSetupMode() {
caCert.Certificate = certificate caCert.Certificate = certificate
return certificate, nil return certificate, nil
} }
keyPair, err = getKeyPair(ctx, label, caCert.KeyInfo) keyPair, err = a.getKeyPair(label, caCert.KeyInfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if certificate != nil && certificateMatches(certificate, keyPair) { if certificate != nil {
caCert.Certificate, caCert.KeyPair = certificate, keyPair err := certificateMatches(certificate, keyPair)
if err != nil {
return nil, err
}
return certificate, nil return certificate, nil
} }
if !IsSetupMode(ctx) { if !a.IsSetupMode() {
return nil, errCertificateGenerationRefused return nil, errCertificateGenerationRefused
} }
notBefore, notAfter := sc.CalculateValidity(caCert, time.Now()) notBefore, notAfter := sc.CalculateValidity(caCert, time.Now())
subject := sc.CalculateSubject(caCert) subject := sc.CalculateSubject(caCert)
certificate, err = generateRootCACertificate( certificate, err = a.generateRootCACertificate(
ctx,
caFile, caFile,
keyPair, keyPair,
&x509.Certificate{ &x509.Certificate{
@ -194,7 +204,7 @@ func GetRootCACertificate(ctx context.Context, label string) (*x509.Certificate,
return nil, err return nil, err
} }
p11Context, err := GetP11Context(ctx, caCert) p11Context, err := a.GetP11Context(caCert)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -209,13 +219,13 @@ func GetRootCACertificate(ctx context.Context, label string) (*x509.Certificate,
return certificate, nil return certificate, nil
} }
func GetIntermediaryCACertificate(ctx context.Context, certLabel string) (*x509.Certificate, error) { func (a *Access) GetIntermediaryCACertificate(certLabel string) (*x509.Certificate, error) {
var ( var (
certificate *x509.Certificate certificate *x509.Certificate
keyPair crypto.Signer keyPair crypto.Signer
) )
sc := GetSignerConfig(ctx) sc := a.GetSignerConfig()
caCert, err := sc.GetCADefinition(certLabel) caCert, err := sc.GetCADefinition(certLabel)
if err != nil { if err != nil {
@ -229,33 +239,35 @@ func GetIntermediaryCACertificate(ctx context.Context, certLabel string) (*x509.
) )
} }
keyPair, err = getKeyPair(ctx, certLabel, caCert.KeyInfo) keyPair, err = a.getKeyPair(certLabel, caCert.KeyInfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
certFile := &caFile{sc: sc, label: certLabel} certFile := &caFile{sc: sc, label: certLabel}
certificate, err = certFile.loadCertificate(ctx) certificate, err = certFile.loadCertificate(a.caDirectory)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if certificate != nil && certificateMatches(certificate, keyPair) { if certificate != nil {
caCert.Certificate, caCert.KeyPair = certificate, keyPair err := certificateMatches(certificate, keyPair)
if err != nil {
return nil, err
}
return certificate, nil return certificate, nil
} }
if !IsSetupMode(ctx) { if !a.IsSetupMode() {
return nil, errCertificateGenerationRefused return nil, errCertificateGenerationRefused
} }
notBefore, notAfter := sc.CalculateValidity(caCert, time.Now()) notBefore, notAfter := sc.CalculateValidity(caCert, time.Now())
subject := sc.CalculateSubject(caCert) subject := sc.CalculateSubject(caCert)
certificate, err = generateIntermediaryCACertificate( certificate, err = a.generateIntermediaryCACertificate(
ctx,
certFile, certFile,
sc, sc,
certLabel, certLabel,
@ -284,7 +296,7 @@ func GetIntermediaryCACertificate(ctx context.Context, certLabel string) (*x509.
return nil, err return nil, err
} }
p11Context, err := GetP11Context(ctx, caCert) p11Context, err := a.GetP11Context(caCert)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -299,8 +311,7 @@ func GetIntermediaryCACertificate(ctx context.Context, certLabel string) (*x509.
return certificate, nil return certificate, nil
} }
func generateIntermediaryCACertificate( func (a *Access) generateIntermediaryCACertificate(
ctx context.Context,
certFile *caFile, certFile *caFile,
config *config.SignerConfig, config *config.SignerConfig,
certLabel string, certLabel string,
@ -340,7 +351,7 @@ func generateIntermediaryCACertificate(
Bytes: certBytes, Bytes: certBytes,
} }
err = certFile.storeCertificate(ctx, pem.EncodeToMemory(certBlock)) err = certFile.storeCertificate(a.caDirectory, pem.EncodeToMemory(certBlock))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -367,8 +378,8 @@ func addCertificate(p11Context *crypto11.Context, label string, certificate *x50
return nil return nil
} }
func getKeyPair(ctx context.Context, label string, keyInfo *config.PrivateKeyInfo) (crypto.Signer, error) { func (a *Access) getKeyPair(label string, keyInfo *config.PrivateKeyInfo) (crypto.Signer, error) {
sc := GetSignerConfig(ctx) sc := a.GetSignerConfig()
cert, err := sc.GetCADefinition(label) cert, err := sc.GetCADefinition(label)
if err != nil { if err != nil {
@ -379,7 +390,7 @@ func getKeyPair(ctx context.Context, label string, keyInfo *config.PrivateKeyInf
return cert.KeyPair, nil return cert.KeyPair, nil
} }
p11Context, err := GetP11Context(ctx, cert) p11Context, err := a.GetP11Context(cert)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -393,7 +404,7 @@ func getKeyPair(ctx context.Context, label string, keyInfo *config.PrivateKeyInf
return keyPair, nil return keyPair, nil
} }
if !IsSetupMode(ctx) { if !a.IsSetupMode() {
return nil, errKeyGenerationRefused return nil, errKeyGenerationRefused
} }
@ -480,8 +491,7 @@ func randomObjectID() ([]byte, error) {
return result, nil return result, nil
} }
func generateRootCACertificate( func (a *Access) generateRootCACertificate(
ctx context.Context,
certFile *caFile, certFile *caFile,
keyPair crypto.Signer, keyPair crypto.Signer,
template *x509.Certificate, template *x509.Certificate,
@ -514,7 +524,7 @@ func generateRootCACertificate(
Bytes: certBytes, Bytes: certBytes,
} }
if err = certFile.storeCertificate(ctx, pem.EncodeToMemory(certBlock)); err != nil { if err = certFile.storeCertificate(a.caDirectory, pem.EncodeToMemory(certBlock)); err != nil {
return nil, err return nil, err
} }
@ -538,31 +548,29 @@ func determineSignatureAlgorithm(keyPair crypto.Signer) (x509.SignatureAlgorithm
} }
} }
func certificateMatches(certificate *x509.Certificate, key crypto.Signer) bool { func certificateMatches(certificate *x509.Certificate, key crypto.Signer) error {
switch v := certificate.PublicKey.(type) { switch v := certificate.PublicKey.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
if pub, ok := key.Public().(*ecdsa.PublicKey); ok { if pub, ok := key.Public().(*ecdsa.PublicKey); ok {
if v.Equal(pub) { if v.Equal(pub) {
return true return nil
} }
} }
case *rsa.PublicKey: case *rsa.PublicKey:
if pub, ok := key.Public().(*rsa.PublicKey); ok { if pub, ok := key.Public().(*rsa.PublicKey); ok {
if v.Equal(pub) { if v.Equal(pub) {
return true return nil
} }
} }
default: default:
log.Printf("unsupported public key %v", v) return fmt.Errorf("unsupported public key %v", v)
} }
log.Printf( return fmt.Errorf(
"public key from certificate does not match private key: %s != %s", "public key from certificate does not match private key: %s != %s",
certificate.PublicKey, certificate.PublicKey,
key.Public(), key.Public(),
) )
return false
} }
func randomSerialNumber() (*big.Int, error) { func randomSerialNumber() (*big.Int, error) {

@ -18,15 +18,15 @@ limitations under the License.
package hsm_test package hsm_test
import ( import (
"context"
"crypto/x509" "crypto/x509"
"log"
"strings" "strings"
"testing" "testing"
"git.cacert.org/cacert-gosigner/pkg/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"git.cacert.org/cacert-gosigner/pkg/config"
"git.cacert.org/cacert-gosigner/pkg/hsm" "git.cacert.org/cacert-gosigner/pkg/hsm"
) )
@ -36,21 +36,22 @@ func TestEnsureCAKeysAndCertificates_not_in_setup_mode(t *testing.T) {
t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "123456") t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "123456")
ctx := hsm.SetupContext( acc, err := hsm.NewAccess(log.Default(),
hsm.CaConfigOption(testConfig), hsm.CaConfigOption(testConfig),
hsm.CADirectoryOption(t.TempDir())) hsm.CADirectoryOption(t.TempDir()))
assert.NoError(t, err)
err := hsm.EnsureCAKeysAndCertificates(ctx) err = acc.EnsureCAKeysAndCertificates()
t.Cleanup(func() { t.Cleanup(func() {
err := hsm.CloseP11Contexts(ctx) err := acc.CloseP11Contexts()
assert.NoError(t, err) assert.NoError(t, err)
}) })
assert.ErrorContains(t, err, "not in setup mode") assert.ErrorContains(t, err, "not in setup mode")
} }
func prepareSoftHSM(t *testing.T) context.Context { func prepareSoftHSM(t *testing.T) *hsm.Access {
t.Helper() t.Helper()
testConfig := setupSignerConfig(t) testConfig := setupSignerConfig(t)
@ -58,25 +59,26 @@ func prepareSoftHSM(t *testing.T) context.Context {
t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "123456") t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "123456")
ctx := hsm.SetupContext( acc, err := hsm.NewAccess(log.Default(),
hsm.CaConfigOption(testConfig), hsm.CaConfigOption(testConfig),
hsm.SetupModeOption(), hsm.SetupModeOption(),
hsm.CADirectoryOption(t.TempDir())) hsm.CADirectoryOption(t.TempDir()))
assert.NoError(t, err)
err := hsm.EnsureCAKeysAndCertificates(ctx) err = acc.EnsureCAKeysAndCertificates()
t.Cleanup(func() { t.Cleanup(func() {
err := hsm.CloseP11Contexts(ctx) err := acc.CloseP11Contexts()
assert.NoError(t, err) assert.NoError(t, err)
}) })
require.NoError(t, err) require.NoError(t, err)
return ctx return acc
} }
func TestGetRootCACertificate(t *testing.T) { func TestGetRootCACertificate(t *testing.T) {
ctx := prepareSoftHSM(t) acc := prepareSoftHSM(t)
testData := map[string]struct { testData := map[string]struct {
label, errMsg string label, errMsg string
@ -96,7 +98,7 @@ func TestGetRootCACertificate(t *testing.T) {
for name, item := range testData { for name, item := range testData {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
root, err := hsm.GetRootCACertificate(ctx, item.label) root, err := acc.GetRootCACertificate(item.label)
if item.errMsg != "" { if item.errMsg != "" {
assert.ErrorContains(t, err, item.errMsg) assert.ErrorContains(t, err, item.errMsg)
@ -110,7 +112,7 @@ func TestGetRootCACertificate(t *testing.T) {
} }
func TestGetIntermediaryCACertificate(t *testing.T) { func TestGetIntermediaryCACertificate(t *testing.T) {
ctx := prepareSoftHSM(t) acc := prepareSoftHSM(t)
testData := map[string]struct { testData := map[string]struct {
label, errMsg string label, errMsg string
@ -130,7 +132,7 @@ func TestGetIntermediaryCACertificate(t *testing.T) {
for name, item := range testData { for name, item := range testData {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
root, err := hsm.GetIntermediaryCACertificate(ctx, item.label) root, err := acc.GetIntermediaryCACertificate(item.label)
if item.errMsg != "" { if item.errMsg != "" {
assert.ErrorContains(t, err, item.errMsg) assert.ErrorContains(t, err, item.errMsg)
@ -166,6 +168,7 @@ KeyStorage:
type: softhsm type: softhsm
label: acme-test-hsm label: acme-test-hsm
` `
testConfig, err := config.LoadConfiguration(strings.NewReader(testRSASignerConfig)) testConfig, err := config.LoadConfiguration(strings.NewReader(testRSASignerConfig))
require.NoError(t, err) require.NoError(t, err)
@ -174,21 +177,22 @@ KeyStorage:
t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "123456") t.Setenv("TOKEN_PIN_ACME_TEST_HSM", "123456")
ctx := hsm.SetupContext( acc, err := hsm.NewAccess(
log.Default(),
hsm.CaConfigOption(testConfig), hsm.CaConfigOption(testConfig),
hsm.SetupModeOption(), hsm.SetupModeOption(),
hsm.CADirectoryOption(t.TempDir())) hsm.CADirectoryOption(t.TempDir()))
assert.NoError(t, err)
err = hsm.EnsureCAKeysAndCertificates(ctx) err = acc.EnsureCAKeysAndCertificates()
require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
err := hsm.CloseP11Contexts(ctx) err := acc.CloseP11Contexts()
assert.NoError(t, err) assert.NoError(t, err)
}) })
require.NoError(t, err) root, err := acc.GetRootCACertificate("root")
root, err := hsm.GetRootCACertificate(ctx, "root")
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, root) assert.NotNil(t, root)

@ -17,24 +17,19 @@ limitations under the License.
package hsm package hsm
import ( func (a *Access) EnsureCAKeysAndCertificates() error {
"context"
"log"
)
func EnsureCAKeysAndCertificates(ctx context.Context) error {
var label string var label string
conf := GetSignerConfig(ctx) conf := a.GetSignerConfig()
for _, label = range conf.RootCAs() { for _, label = range conf.RootCAs() {
crt, err := GetRootCACertificate(ctx, label) crt, err := a.GetRootCACertificate(label)
if err != nil { if err != nil {
return err return err
} }
if IsVerbose(ctx) { if a.IsVerbose() {
log.Printf( a.infoLog.Printf(
"found root CA certificate %s:\n Subject %s\n Issuer %s\n Valid from %s until %s\n Serial %s", "found root CA certificate %s:\n Subject %s\n Issuer %s\n Valid from %s until %s\n Serial %s",
label, label,
crt.Subject, crt.Subject,
@ -43,18 +38,18 @@ func EnsureCAKeysAndCertificates(ctx context.Context) error {
crt.NotAfter, crt.NotAfter,
crt.SerialNumber) crt.SerialNumber)
} else { } else {
log.Printf("found root CA certificate %s: %s", label, crt.Subject.CommonName) a.infoLog.Printf("found root CA certificate %s: %s", label, crt.Subject.CommonName)
} }
} }
for _, label = range conf.IntermediaryCAs() { for _, label = range conf.IntermediaryCAs() {
crt, err := GetIntermediaryCACertificate(ctx, label) crt, err := a.GetIntermediaryCACertificate(label)
if err != nil { if err != nil {
return err return err
} }
if IsVerbose(ctx) { if a.IsVerbose() {
log.Printf( a.infoLog.Printf(
"found intermediary CA certificate %s:\n Subject %s\n Issuer %s\n Valid from %s until %s\n Serial %s", "found intermediary CA certificate %s:\n Subject %s\n Issuer %s\n Valid from %s until %s\n Serial %s",
label, label,
crt.Subject, crt.Subject,
@ -63,7 +58,7 @@ func EnsureCAKeysAndCertificates(ctx context.Context) error {
crt.NotAfter, crt.NotAfter,
crt.SerialNumber) crt.SerialNumber)
} else { } else {
log.Printf("found intermediary CA certificate %s: %s", label, crt.Subject.CommonName) a.infoLog.Printf("found intermediary CA certificate %s: %s", label, crt.Subject.CommonName)
} }
} }

@ -35,17 +35,17 @@ func TestEnsureCAKeysAndCertificates(t *testing.T) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
log.SetOutput(buf) testLogger := log.New(buf, "TEST ", log.LstdFlags)
ctx := hsm.SetupContext( acc, err := hsm.NewAccess(testLogger, hsm.CaConfigOption(testConfig),
hsm.CaConfigOption(testConfig),
hsm.SetupModeOption(), hsm.SetupModeOption(),
hsm.CADirectoryOption(t.TempDir())) hsm.CADirectoryOption(t.TempDir()))
assert.NoError(t, err)
err := hsm.EnsureCAKeysAndCertificates(ctx) err = acc.EnsureCAKeysAndCertificates()
t.Cleanup(func() { t.Cleanup(func() {
err := hsm.CloseP11Contexts(ctx) err := acc.CloseP11Contexts()
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -66,18 +66,18 @@ func TestEnsureCAKeysAndCertificates_verbose(t *testing.T) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
log.SetOutput(buf) testLogger := log.New(buf, "TEST ", log.LstdFlags)
ctx := hsm.SetupContext( acc, err := hsm.NewAccess(testLogger, hsm.CaConfigOption(testConfig),
hsm.CaConfigOption(testConfig),
hsm.SetupModeOption(), hsm.SetupModeOption(),
hsm.VerboseLoggingOption(), hsm.VerboseLoggingOption(),
hsm.CADirectoryOption(t.TempDir())) hsm.CADirectoryOption(t.TempDir()))
assert.NoError(t, err)
err := hsm.EnsureCAKeysAndCertificates(ctx) err = acc.EnsureCAKeysAndCertificates()
t.Cleanup(func() { t.Cleanup(func() {
err := hsm.CloseP11Contexts(ctx) err := acc.CloseP11Contexts()
assert.NoError(t, err) assert.NoError(t, err)
}) })

@ -18,7 +18,6 @@ limitations under the License.
package hsm package hsm
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@ -30,13 +29,13 @@ import (
"golang.org/x/term" "golang.org/x/term"
) )
func prepareCrypto11Context(ctx context.Context, label string) (*crypto11.Context, error) { func (a *Access) prepareCrypto11Context(label string) (*crypto11.Context, error) {
var ( var (
err error err error
p11Context *crypto11.Context p11Context *crypto11.Context
) )
storage, err := GetSignerConfig(ctx).GetKeyStorage(label) storage, err := a.GetSignerConfig().GetKeyStorage(label)
if err != nil { if err != nil {
return nil, fmt.Errorf("key storage %s not available: %w", label, err) return nil, fmt.Errorf("key storage %s not available: %w", label, err)
} }
@ -46,10 +45,10 @@ func prepareCrypto11Context(ctx context.Context, label string) (*crypto11.Contex
TokenLabel: storage.Label, TokenLabel: storage.Label,
} }
log.Printf("using PKCS#11 module %s", p11Config.Path) a.infoLog.Printf("using PKCS#11 module %s", p11Config.Path)
log.Printf("looking for token with label %s", p11Config.TokenLabel) a.infoLog.Printf("looking for token with label %s", p11Config.TokenLabel)
p11Config.Pin, err = getPin(p11Config) p11Config.Pin, err = getPin(p11Config, a.infoLog)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,17 +61,21 @@ func prepareCrypto11Context(ctx context.Context, label string) (*crypto11.Contex
return p11Context, nil return p11Context, nil
} }
func getPin(p11Config *crypto11.Config) (string, error) { func getPin(p11Config *crypto11.Config, infoLog *log.Logger) (string, error) {
var err error var err error
tokenPinEnv := strings.ReplaceAll(p11Config.TokenLabel, "-", "_") tokenPinEnv := strings.NewReplacer(
tokenPinEnv = strings.ReplaceAll(tokenPinEnv, " ", "_") "-", "_",
" ", "_",
"(", "_",
")", "_",
).Replace(p11Config.TokenLabel)
tokenPinEnv = strings.ToUpper(tokenPinEnv) tokenPinEnv = strings.ToUpper(tokenPinEnv)
tokenPinEnv = fmt.Sprintf("TOKEN_PIN_%s", tokenPinEnv) tokenPinEnv = fmt.Sprintf("TOKEN_PIN_%s", tokenPinEnv)
pin, found := os.LookupEnv(tokenPinEnv) pin, found := os.LookupEnv(tokenPinEnv)
if !found { if !found {
log.Printf("environment variable %s has not been set", tokenPinEnv) infoLog.Printf("environment variable %s has not been set", tokenPinEnv)
if !term.IsTerminal(syscall.Stdin) { if !term.IsTerminal(syscall.Stdin) {
return "", errors.New("stdin is not a terminal") return "", errors.New("stdin is not a terminal")

@ -50,9 +50,7 @@ func (t *testRepo) NextCRLNumber() (*big.Int, error) {
func (t *testRepo) RevokedCertificates() ([]pkix.RevokedCertificate, error) { func (t *testRepo) RevokedCertificates() ([]pkix.RevokedCertificate, error) {
result := make([]pkix.RevokedCertificate, len(t.revoked)) result := make([]pkix.RevokedCertificate, len(t.revoked))
for i, revoked := range t.revoked { copy(result, t.revoked)
result[i] = revoked
}
return result, nil return result, nil
} }

Loading…
Cancel
Save