package handlers

import (
	"bytes"
	"crypto/sha256"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"os/exec"
	"strings"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"
)

// SigningRequestRegistry accepts certificate signing requests, signs them in
// background goroutines using the openssl command line tool and hands the
// results out through per-request response channels.
type SigningRequestRegistry struct {
	caCertificates []*x509.Certificate

	// mu guards caChainMap and requests. Both maps are accessed by callers of
	// the exported methods and by the background signing goroutines, and Go
	// maps are not safe for unsynchronized concurrent use.
	mu         sync.Mutex
	caChainMap map[string][]string
	requests   map[string]chan *responseData
}

// NewSigningRequestRegistry creates an empty registry that builds CA chains
// from the given CA certificates.
func NewSigningRequestRegistry(caCertificates []*x509.Certificate) *SigningRequestRegistry {
	return &SigningRequestRegistry{
		caCertificates: caCertificates,
		caChainMap:     make(map[string][]string),
		requests:       make(map[string]chan *responseData),
	}
}

// SigningRequestAttributes holds the validated data of a single signing request.
type SigningRequestAttributes struct {
	CommonName   string
	CSRBytes     []byte
	RequestToken string
}

// AddSigningRequest validates the CSR contained in request and starts signing
// it in the background. It returns the request token that callers pass to
// GetResponseChannel to obtain the result, or the validation error.
//
// The response channel is registered before this method returns, so the
// token can be redeemed immediately.
func (registry *SigningRequestRegistry) AddSigningRequest(request *requestData) (string, error) {
	requestToken, csrBytes, err := validateCsr(request.Csr)
	if err != nil {
		return "", err
	}

	requestAttributes := &SigningRequestAttributes{
		CommonName:   request.CommonName,
		CSRBytes:     csrBytes,
		RequestToken: requestToken,
	}

	// Buffered so the signing goroutine (which sends at most one value) never
	// blocks, even if nobody collects the result.
	responseChannel := make(chan *responseData, 1)

	registry.mu.Lock()
	registry.requests[requestToken] = responseChannel
	registry.mu.Unlock()

	go registry.signCertificate(responseChannel, requestAttributes)

	return requestToken, nil
}

// validateCsr decodes a PEM encoded certificate signing request and verifies
// its signature.
//
// It returns a request token (the current UTC time with minute precision
// followed by the hex encoded SHA-256 digest of the DER CSR) and the DER
// encoded CSR.
func validateCsr(csr string) (string, []byte, error) {
	csrBlock, _ := pem.Decode([]byte(csr))
	if csrBlock == nil {
		return "", nil, errors.New("request data did not contain valid PEM data")
	}
	if csrBlock.Type != "CERTIFICATE REQUEST" {
		return "", nil, fmt.Errorf("request is not valid, type in PEM data is %s", csrBlock.Type)
	}

	csrContent, err := x509.ParseCertificateRequest(csrBlock.Bytes)
	if err != nil {
		return "", nil, err
	}
	if err = csrContent.CheckSignature(); err != nil {
		log.Errorf("invalid CSR signature %v", err)
		return "", nil, err
	}

	// generate request token as defined in CAB Baseline Requirements 1.7.3 Request Token definition
	requestToken := fmt.Sprintf(
		"%s%x",
		time.Now().UTC().Format("200601021504"),
		sha256.Sum256(csrContent.Raw),
	)
	log.Debugf("generated request token %s", requestToken)

	return requestToken, csrContent.Raw, nil
}

// signCertificate signs request and delivers the result on channel. On
// failure the error is logged and channel is closed without a value, so a
// receiver observes nil.
func (registry *SigningRequestRegistry) signCertificate(channel chan *responseData, request *SigningRequestAttributes) {
	response, err := registry.sign(request)
	if err != nil {
		log.Error(err)
		close(channel)
		return
	}
	channel <- response
}

// subjectValueEscaper backslash-escapes the characters openssl treats
// specially when parsing a -subj argument: "/" separates RDNs, "+" joins
// multi-valued RDNs (with -multivalue-rdn) and "\" is the escape character.
var subjectValueEscaper = strings.NewReplacer(`\`, `\\`, `/`, `\/`, `+`, `\+`)

// escapeSubjectValue makes value safe to embed as a single attribute value in
// an openssl -subj argument.
func escapeSubjectValue(value string) string {
	return subjectValueEscaper.Replace(value)
}

// writeCSRFile PEM encodes the DER CSR into a new temporary file and returns
// the file name. The caller must remove the file. On error no file is left
// behind.
func writeCSRFile(csrBytes []byte) (string, error) {
	csrFile, err := ioutil.TempFile("", "*.csr.pem")
	if err != nil {
		log.Errorf("could not open temporary file: %s", err)
		return "", err
	}

	discard := func() {
		// The file may already be closed; that error is irrelevant here.
		_ = csrFile.Close()
		if removeErr := os.Remove(csrFile.Name()); removeErr != nil {
			log.Errorf("could not remove temporary file: %s", removeErr)
		}
	}

	if err = pem.Encode(csrFile, &pem.Block{
		Type:  "CERTIFICATE REQUEST",
		Bytes: csrBytes,
	}); err != nil {
		log.Errorf("could not write CSR to file: %s", err)
		discard()
		return "", err
	}
	if err = csrFile.Close(); err != nil {
		log.Errorf("could not close CSR file: %s", err)
		discard()
		return "", err
	}

	return csrFile.Name(), nil
}

// sign writes the CSR to a temporary file and runs "openssl ca" with
// ca.cnf (resolved relative to the process working directory) to issue a
// certificate with the requested common name. It returns the PEM encoded
// certificate together with its PEM encoded CA chain.
func (registry *SigningRequestRegistry) sign(request *SigningRequestAttributes) (*responseData, error) {
	log.Infof("handling signing request %s", request.RequestToken)

	// The common name comes from the requester; escape it so it cannot add
	// further RDNs to the subject passed to openssl.
	subjectDN := fmt.Sprintf("/CN=%s", escapeSubjectValue(request.CommonName))

	csrFileName, err := writeCSRFile(request.CSRBytes)
	if err != nil {
		return nil, err
	}
	defer func() {
		if removeErr := os.Remove(csrFileName); removeErr != nil {
			log.Errorf("could not remove temporary file: %s", removeErr)
		}
	}()

	// simulate a delay during certificate creation
	time.Sleep(5 * time.Second)

	opensslCommand := exec.Command(
		"openssl", "ca", "-config", "ca.cnf",
		"-policy", "policy_match", "-extensions", "client_ext",
		"-batch", "-subj", subjectDN, "-utf8", "-rand_serial",
		"-in", csrFileName)
	var out, cmdErr bytes.Buffer
	opensslCommand.Stdout = &out
	opensslCommand.Stderr = &cmdErr

	if err = opensslCommand.Run(); err != nil {
		log.Error(err)
		log.Error(cmdErr.String())
		return nil, err
	}

	// pem.Decode skips any text openssl prints before the PEM block.
	block, _ := pem.Decode(out.Bytes())
	if block == nil {
		return nil, errors.New("could not decode pem")
	}

	certificate, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, err
	}

	caChain, err := registry.getCAChain(certificate)
	if err != nil {
		return nil, err
	}

	return &responseData{
		Certificate: string(pem.EncodeToMemory(&pem.Block{
			Type:  "CERTIFICATE",
			Bytes: certificate.Raw,
		})),
		CAChain: caChain,
	}, nil
}

// GetResponseChannel returns the response channel registered for
// requestToken and removes it from the registry, so each token can be
// redeemed only once.
func (registry *SigningRequestRegistry) GetResponseChannel(requestToken string) (chan *responseData, error) {
	registry.mu.Lock()
	defer registry.mu.Unlock()

	responseChannel, exists := registry.requests[requestToken]
	if !exists {
		return nil, errors.New("no request found")
	}
	delete(registry.requests, requestToken)
	return responseChannel, nil
}

// findCACertificate returns the first loaded CA certificate whose raw subject
// equals subject, or nil if there is none.
func (registry *SigningRequestRegistry) findCACertificate(subject []byte) *x509.Certificate {
	for _, caCert := range registry.caCertificates {
		if bytes.Equal(caCert.RawSubject, subject) {
			return caCert
		}
	}
	return nil
}

// getCAChain returns the PEM encoded CA chain for certificate, starting with
// its direct issuer and ending with a self-signed certificate. The chain is
// built from registry.caCertificates and cached per issuer.
func (registry *SigningRequestRegistry) getCAChain(certificate *x509.Certificate) ([]string, error) {
	issuerString := string(certificate.RawIssuer)

	registry.mu.Lock()
	defer registry.mu.Unlock()

	if value, exists := registry.caChainMap[issuerString]; exists {
		return value, nil
	}

	if len(registry.caCertificates) == 0 {
		return nil, errors.New("no CA certificates loaded")
	}

	result := make([]string, 0, len(registry.caCertificates))
	wantSubject := certificate.RawIssuer

	// A chain cannot need more hops than there are loaded certificates; the
	// bound also terminates issuer cycles that never reach a self-signed root.
	for range registry.caCertificates {
		issuer := registry.findCACertificate(wantSubject)
		if issuer == nil {
			break
		}

		result = append(
			result,
			string(pem.EncodeToMemory(&pem.Block{Bytes: issuer.Raw, Type: "CERTIFICATE"})))
		log.Debugf("added %s to cachain", result[len(result)-1])

		if bytes.Equal(issuer.RawSubject, issuer.RawIssuer) {
			registry.caChainMap[issuerString] = result
			return result, nil
		}
		wantSubject = issuer.RawIssuer
	}

	return nil, errors.New("could not construct certificate chain")
}