Fix linter warnings, modernize code

This commit is contained in:
Jan Dittberner 2023-05-13 13:27:19 +02:00
parent e828b30b21
commit 2c82ccb324
12 changed files with 763 additions and 504 deletions

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -23,7 +23,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -31,24 +30,33 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"git.cacert.org/oidc_idp/ui"
"github.com/go-openapi/runtime/client" "github.com/go-openapi/runtime/client"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
"github.com/knadh/koanf" "github.com/knadh/koanf"
hydra "github.com/ory/hydra-client-go/client" hydra "github.com/ory/hydra-client-go/client"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"git.cacert.org/oidc_idp/ui"
"git.cacert.org/oidc_idp/handlers" "git.cacert.org/oidc_idp/handlers"
"git.cacert.org/oidc_idp/services" "git.cacert.org/oidc_idp/services"
) )
const (
TimeoutThirty = 30 * time.Second
TimeoutTwenty = 20 * time.Second
DefaultCSRFMaxAge = 600
DefaultServerPort = 3000
)
func main() { func main() {
logger := log.New() logger := log.New()
config, err := services.ConfigureApplication( config, err := services.ConfigureApplication(
logger, logger,
"IDP", "IDP",
map[string]interface{}{ map[string]interface{}{
"server.port": 3000, "server.port": DefaultServerPort,
"server.name": "login.cacert.localhost", "server.name": "login.cacert.localhost",
"server.key": "certs/idp.cacert.localhost.key", "server.key": "certs/idp.cacert.localhost.key",
"server.certificate": "certs/idp.cacert.localhost.crt.pem", "server.certificate": "certs/idp.cacert.localhost.crt.pem",
@ -61,23 +69,28 @@ func main() {
} }
logger.Infoln("Server is starting") logger.Infoln("Server is starting")
ctx := context.Background() bundle, catalog := services.InitI18n(logger, config.Strings("i18n.languages"))
ctx = services.InitI18n(ctx, logger, config.Strings("i18n.languages")) if err = services.AddMessages(catalog); err != nil {
services.AddMessages(ctx) logger.Fatalf("could not add messages for i18n: %v", err)
}
adminURL, err := url.Parse(config.MustString("admin.url")) adminURL, err := url.Parse(config.MustString("admin.url"))
if err != nil { if err != nil {
logger.Fatalf("error parsing admin URL: %v", err) logger.Fatalf("error parsing admin URL: %v", err)
} }
tlsClientConfig := &tls.Config{MinVersion: tls.VersionTLS12} tlsClientConfig := &tls.Config{MinVersion: tls.VersionTLS12}
if config.Exists("api-client.rootCAs") { if config.Exists("api-client.rootCAs") {
rootCAFile := config.MustString("api-client.rootCAs") rootCAFile := config.MustString("api-client.rootCAs")
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
pemBytes, err := ioutil.ReadFile(rootCAFile)
pemBytes, err := os.ReadFile(rootCAFile)
if err != nil { if err != nil {
log.Fatalf("could not read CA certificate file: %v", err) log.Fatalf("could not read CA certificate file: %v", err)
} }
caCertPool.AppendCertsFromPEM(pemBytes) caCertPool.AppendCertsFromPEM(pemBytes)
tlsClientConfig.RootCAs = caCertPool tlsClientConfig.RootCAs = caCertPool
} }
@ -92,16 +105,10 @@ func main() {
) )
adminClient := hydra.New(clientTransport, nil) adminClient := hydra.New(clientTransport, nil)
handlerContext := context.WithValue(ctx, handlers.CtxAdminClient, adminClient.Admin) loginHandler := handlers.NewLoginHandler(logger, bundle, catalog, adminClient.Admin)
loginHandler, err := handlers.NewLoginHandler(handlerContext, logger) consentHandler := handlers.NewConsentHandler(logger, bundle, catalog, adminClient.Admin)
if err != nil { logoutHandler := handlers.NewLogoutHandler(logger, adminClient.Admin)
logger.Fatalf("error initializing login handler: %v", err)
}
consentHandler, err := handlers.NewConsentHandler(handlerContext, logger)
if err != nil {
logger.Fatalf("error initializing consent handler: %v", err)
}
logoutHandler := handlers.NewLogoutHandler(handlerContext, logger)
logoutSuccessHandler := handlers.NewLogoutSuccessHandler() logoutSuccessHandler := handlers.NewLogoutSuccessHandler()
errorHandler := handlers.NewErrorHandler() errorHandler := handlers.NewErrorHandler()
staticFiles := http.FileServer(http.FS(ui.Static)) staticFiles := http.FileServer(http.FS(ui.Static))
@ -126,20 +133,21 @@ func main() {
logger.Fatalf("could not parse CSRF key bytes: %v", err) logger.Fatalf("could not parse CSRF key bytes: %v", err)
} }
nextRequestId := func() string { nextRequestID := func() string {
return fmt.Sprintf("%d", time.Now().UnixNano()) return fmt.Sprintf("%d", time.Now().UnixNano())
} }
tracing := handlers.Tracing(nextRequestId) tracing := handlers.Tracing(nextRequestID)
logging := handlers.Logging(logger) logging := handlers.Logging(logger)
hsts := handlers.EnableHSTS() hsts := handlers.EnableHSTS()
csrfProtect := csrf.Protect( csrfProtect := csrf.Protect(
csrfKey, csrfKey,
csrf.Secure(true), csrf.Secure(true),
csrf.SameSite(csrf.SameSiteStrictMode), csrf.SameSite(csrf.SameSiteStrictMode),
csrf.MaxAge(600)) csrf.MaxAge(DefaultCSRFMaxAge))
errorMiddleware, err := handlers.ErrorHandling( errorMiddleware, err := handlers.ErrorHandling(
ctx, context.Background(),
logger, logger,
ui.Templates, ui.Templates,
) )
@ -149,7 +157,7 @@ func main() {
handlerChain := tracing(logging(hsts(errorMiddleware(csrfProtect(router))))) handlerChain := tracing(logging(hsts(errorMiddleware(csrfProtect(router)))))
startServer(ctx, handlerChain, logger, config) startServer(context.Background(), handlerChain, logger, config)
} }
func startServer(ctx context.Context, handlerChain http.Handler, logger *log.Logger, config *koanf.Koanf) { func startServer(ctx context.Context, handlerChain http.Handler, logger *log.Logger, config *koanf.Koanf) {
@ -158,10 +166,12 @@ func startServer(ctx context.Context, handlerChain http.Handler, logger *log.Log
serverPort := config.Int("server.port") serverPort := config.Int("server.port")
clientCertPool := x509.NewCertPool() clientCertPool := x509.NewCertPool()
pemBytes, err := ioutil.ReadFile(clientCertificateCAFile)
pemBytes, err := os.ReadFile(clientCertificateCAFile)
if err != nil { if err != nil {
logger.Fatalf("could not load client CA certificates: %v", err) logger.Fatalf("could not load client CA certificates: %v", err)
} }
clientCertPool.AppendCertsFromPEM(pemBytes) clientCertPool.AppendCertsFromPEM(pemBytes)
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
@ -173,9 +183,9 @@ func startServer(ctx context.Context, handlerChain http.Handler, logger *log.Log
server := &http.Server{ server := &http.Server{
Addr: fmt.Sprintf("%s:%d", serverName, serverPort), Addr: fmt.Sprintf("%s:%d", serverName, serverPort),
Handler: handlerChain, Handler: handlerChain,
ReadTimeout: 20 * time.Second, ReadTimeout: TimeoutTwenty,
WriteTimeout: 20 * time.Second, WriteTimeout: TimeoutTwenty,
IdleTimeout: 30 * time.Second, IdleTimeout: TimeoutThirty,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
} }
@ -188,18 +198,21 @@ func startServer(ctx context.Context, handlerChain http.Handler, logger *log.Log
logger.Infoln("Server is shutting down...") logger.Infoln("Server is shutting down...")
atomic.StoreInt32(&handlers.Healthy, 0) atomic.StoreInt32(&handlers.Healthy, 0)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second) ctx, cancel := context.WithTimeout(ctx, TimeoutThirty)
defer cancel() defer cancel()
server.SetKeepAlivesEnabled(false) server.SetKeepAlivesEnabled(false)
if err := server.Shutdown(ctx); err != nil { if err := server.Shutdown(ctx); err != nil {
logger.Fatalf("Could not gracefully shutdown the server: %v\n", err) logger.Fatalf("Could not gracefully shutdown the server: %v\n", err)
} }
close(done) close(done)
}() }()
logger.Infof("Server is ready to handle requests at https://%s/", server.Addr) logger.Infof("Server is ready to handle requests at https://%s/", server.Addr)
atomic.StoreInt32(&handlers.Healthy, 1) atomic.StoreInt32(&handlers.Healthy, 1)
if err := server.ListenAndServeTLS( if err := server.ListenAndServeTLS(
config.String("server.certificate"), config.String("server.key"), config.String("server.certificate"), config.String("server.key"),
); err != nil && err != http.ErrServerClosed { ); err != nil && err != http.ErrServerClosed {

View file

@ -1,24 +0,0 @@
/*
Copyright 2020, 2021 Jan Dittberner
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package handlers
type handlerContextKey int
const (
CtxAdminClient handlerContextKey = iota
)

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -18,16 +18,15 @@
package handlers package handlers
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt"
"html/template" "html/template"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
commonModels "git.cacert.org/oidc_idp/models"
"git.cacert.org/oidc_idp/ui"
"github.com/go-playground/form/v4" "github.com/go-playground/form/v4"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
"github.com/lestrrat-go/jwx/jwt/openid" "github.com/lestrrat-go/jwx/jwt/openid"
@ -36,14 +35,16 @@ import (
"github.com/ory/hydra-client-go/models" "github.com/ory/hydra-client-go/models"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
commonModels "git.cacert.org/oidc_idp/models"
"git.cacert.org/oidc_idp/ui"
"git.cacert.org/oidc_idp/services" "git.cacert.org/oidc_idp/services"
) )
type consentHandler struct { type ConsentHandler struct {
adminClient *admin.Client adminClient admin.ClientService
bundle *i18n.Bundle bundle *i18n.Bundle
consentTemplate *template.Template consentTemplate *template.Template
context context.Context
logger *log.Logger logger *log.Logger
messageCatalog *services.MessageCatalog messageCatalog *services.MessageCatalog
} }
@ -70,6 +71,8 @@ const (
ScopeEmail = "email" ScopeEmail = "email"
) )
const OneDayInSeconds = 86400
func init() { func init() {
supportedScopes = make(map[string]*i18n.Message) supportedScopes = make(map[string]*i18n.Message)
supportedScopes[ScopeOpenID] = &i18n.Message{ supportedScopes[ScopeOpenID] = &i18n.Message{
@ -107,9 +110,11 @@ func (i *UserInfo) GetFullName() string {
return i.CommonName return i.CommonName
} }
func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ConsentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
challenge := r.URL.Query().Get("consent_challenge") challenge := r.URL.Query().Get("consent_challenge")
h.logger.Debugf("received consent challenge %s", challenge) h.logger.Debugf("received consent challenge %s", challenge)
accept := r.Header.Get("Accept-Language") accept := r.Header.Get("Accept-Language")
localizer := i18n.NewLocalizer(h.bundle, accept) localizer := i18n.NewLocalizer(h.bundle, accept)
@ -122,8 +127,12 @@ func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
h.renderConsentForm(w, r, consentData, requestedClaims, err, localizer) if err := h.renderConsentForm(w, r, consentData, requestedClaims, localizer); err != nil {
break h.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
case http.MethodPost: case http.MethodPost:
var consentInfo ConsentInformation var consentInfo ConsentInformation
@ -131,11 +140,8 @@ func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
decoder := form.NewDecoder() decoder := form.NewDecoder()
if err := decoder.Decode(&consentInfo, r.Form); err != nil { if err := decoder.Decode(&consentInfo, r.Form); err != nil {
h.logger.Error(err) h.logger.Error(err)
http.Error( http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
w,
http.StatusText(http.StatusInternalServerError),
http.StatusInternalServerError,
)
return return
} }
@ -144,6 +150,7 @@ func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
h.logger.Errorf("could not get session data: %v", err) h.logger.Errorf("could not get session data: %v", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
@ -154,34 +161,38 @@ func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
GrantScope: consentInfo.GrantedScopes, GrantScope: consentInfo.GrantedScopes,
HandledAt: models.NullTime(time.Now()), HandledAt: models.NullTime(time.Now()),
Remember: true, Remember: true,
RememberFor: 86400, RememberFor: OneDayInSeconds,
Session: sessionData, Session: sessionData,
}).WithTimeout(time.Second * 10)) }).WithTimeout(TimeoutTen))
if err != nil { if err != nil {
h.logger.Error(err) h.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo) w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo)
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
return
} else { return
}
consentRequest, err := h.adminClient.RejectConsentRequest( consentRequest, err := h.adminClient.RejectConsentRequest(
admin.NewRejectConsentRequestParams().WithConsentChallenge(challenge).WithBody( admin.NewRejectConsentRequestParams().WithConsentChallenge(challenge).WithBody(
&models.RejectRequest{})) &models.RejectRequest{}))
if err != nil { if err != nil {
h.logger.Error(err) h.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo) w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo)
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }
} }
}
func (h *consentHandler) getRequestedConsentInformation(challenge string, r *http.Request) ( func (h *ConsentHandler) getRequestedConsentInformation(challenge string, r *http.Request) (
*admin.GetConsentRequestOK, *admin.GetConsentRequestOK,
*commonModels.OIDCClaimsRequest, *commonModels.OIDCClaimsRequest,
error, error,
@ -190,20 +201,26 @@ func (h *consentHandler) getRequestedConsentInformation(challenge string, r *htt
admin.NewGetConsentRequestParams().WithConsentChallenge(challenge)) admin.NewGetConsentRequestParams().WithConsentChallenge(challenge))
if err != nil { if err != nil {
h.logger.Errorf("error getting consent information: %v", err) h.logger.Errorf("error getting consent information: %v", err)
var errorDetails *ErrorDetails
errorDetails = &ErrorDetails{ if errorBucket := GetErrorBucket(r); errorBucket != nil {
errorDetails := &ErrorDetails{
ErrorMessage: "could not get consent details", ErrorMessage: "could not get consent details",
ErrorDetails: []string{http.StatusText(http.StatusInternalServerError)}, ErrorDetails: []string{http.StatusText(http.StatusInternalServerError)},
} }
GetErrorBucket(r).AddError(errorDetails)
return nil, nil, err errorBucket.AddError(errorDetails)
} }
return nil, nil, fmt.Errorf("error getting consent information: %w", err)
}
var requestedClaims commonModels.OIDCClaimsRequest var requestedClaims commonModels.OIDCClaimsRequest
requestUrl, err := url.Parse(consentData.Payload.RequestURL)
requestURL, err := url.Parse(consentData.Payload.RequestURL)
if err != nil { if err != nil {
h.logger.Warnf("could not parse original request URL %s: %v", consentData.Payload.RequestURL, err) h.logger.Warnf("could not parse original request URL %s: %v", consentData.Payload.RequestURL, err)
} else { } else {
claimsParameter := requestUrl.Query().Get("claims") claimsParameter := requestURL.Query().Get("claims")
if claimsParameter != "" { if claimsParameter != "" {
decoder := json.NewDecoder(strings.NewReader(claimsParameter)) decoder := json.NewDecoder(strings.NewReader(claimsParameter))
err := decoder.Decode(&requestedClaims) err := decoder.Decode(&requestedClaims)
@ -216,27 +233,28 @@ func (h *consentHandler) getRequestedConsentInformation(challenge string, r *htt
} }
} }
} }
return consentData, &requestedClaims, nil return consentData, &requestedClaims, nil
} }
func (h *consentHandler) renderConsentForm( func (h *ConsentHandler) renderConsentForm(
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
consentData *admin.GetConsentRequestOK, consentData *admin.GetConsentRequestOK,
claims *commonModels.OIDCClaimsRequest, claims *commonModels.OIDCClaimsRequest,
err error,
localizer *i18n.Localizer, localizer *i18n.Localizer,
) { ) error {
trans := func(id string, values ...map[string]interface{}) string { trans := func(id string, values ...map[string]interface{}) string {
if len(values) > 0 { if len(values) > 0 {
return h.messageCatalog.LookupMessage(id, values[0], localizer) return h.messageCatalog.LookupMessage(id, values[0], localizer)
} }
return h.messageCatalog.LookupMessage(id, nil, localizer) return h.messageCatalog.LookupMessage(id, nil, localizer)
} }
// render consent form // render consent form
client := consentData.GetPayload().Client client := consentData.GetPayload().Client
err = h.consentTemplate.Lookup("base").Execute(w, map[string]interface{}{ err := h.consentTemplate.Lookup("base").Execute(w, map[string]interface{}{
"Title": trans("TitleRequestConsent"), "Title": trans("TitleRequestConsent"),
csrf.TemplateTag: csrf.TemplateField(r), csrf.TemplateTag: csrf.TemplateField(r),
"errors": map[string]string{}, "errors": map[string]string{},
@ -245,15 +263,24 @@ func (h *consentHandler) renderConsentForm(
"requestedClaims": h.mapRequestedClaims(claims, localizer), "requestedClaims": h.mapRequestedClaims(claims, localizer),
"LabelSubmit": trans("LabelSubmit"), "LabelSubmit": trans("LabelSubmit"),
"LabelConsent": trans("LabelConsent"), "LabelConsent": trans("LabelConsent"),
"IntroMoreInformation": template.HTML(trans("IntroConsentMoreInformation", map[string]interface{}{ "IntroMoreInformation": template.HTML( //nolint:gosec
trans("IntroConsentMoreInformation", map[string]interface{}{
"client": client.ClientName, "client": client.ClientName,
"clientLink": client.ClientURI, "clientLink": client.ClientURI,
})), })),
"ClaimsInformation": template.HTML(trans("ClaimsInformation", nil)), "ClaimsInformation": template.HTML( //nolint:gosec
"IntroConsentRequested": template.HTML(trans("IntroConsentRequested", map[string]interface{}{ trans("ClaimsInformation", nil)),
"IntroConsentRequested": template.HTML( //nolint:gosec
trans("IntroConsentRequested", map[string]interface{}{
"client": client.ClientName, "client": client.ClientName,
})), })),
}) })
if err != nil {
return fmt.Errorf("rendering failed: %w", err)
}
return nil
} }
type scopeWithLabel struct { type scopeWithLabel struct {
@ -261,13 +288,19 @@ type scopeWithLabel struct {
Label string Label string
} }
func (h *consentHandler) mapRequestedScope(scope models.StringSlicePipeDelimiter, localizer *i18n.Localizer) []*scopeWithLabel { func (h *ConsentHandler) mapRequestedScope(
scope models.StringSlicePipeDelimiter,
localizer *i18n.Localizer,
) []*scopeWithLabel {
result := make([]*scopeWithLabel, 0) result := make([]*scopeWithLabel, 0)
for _, scopeName := range scope { for _, scopeName := range scope {
if _, ok := supportedScopes[scopeName]; !ok { if _, ok := supportedScopes[scopeName]; !ok {
h.logger.Warnf("unsupported scope %s ignored", scopeName) h.logger.Warnf("unsupported scope %s ignored", scopeName)
continue continue
} }
label, err := localizer.Localize(&i18n.LocalizeConfig{ label, err := localizer.Localize(&i18n.LocalizeConfig{
DefaultMessage: supportedScopes[scopeName], DefaultMessage: supportedScopes[scopeName],
}) })
@ -275,8 +308,10 @@ func (h *consentHandler) mapRequestedScope(scope models.StringSlicePipeDelimiter
h.logger.Warnf("could not localize label for scope %s: %v", scopeName, err) h.logger.Warnf("could not localize label for scope %s: %v", scopeName, err)
label = scopeName label = scopeName
} }
result = append(result, &scopeWithLabel{Name: scopeName, Label: label}) result = append(result, &scopeWithLabel{Name: scopeName, Label: label})
} }
return result return result
} }
@ -286,7 +321,10 @@ type claimWithLabel struct {
Essential bool Essential bool
} }
func (h *consentHandler) mapRequestedClaims(claims *commonModels.OIDCClaimsRequest, localizer *i18n.Localizer) []*claimWithLabel { func (h *ConsentHandler) mapRequestedClaims(
claims *commonModels.OIDCClaimsRequest,
localizer *i18n.Localizer,
) []*claimWithLabel {
result := make([]*claimWithLabel, 0) result := make([]*claimWithLabel, 0)
known := make(map[string]bool) known := make(map[string]bool)
@ -295,8 +333,10 @@ func (h *consentHandler) mapRequestedClaims(claims *commonModels.OIDCClaimsReque
for k, v := range *claimElement { for k, v := range *claimElement {
if _, ok := supportedClaims[k]; !ok { if _, ok := supportedClaims[k]; !ok {
h.logger.Warnf("unsupported claim %s ignored", k) h.logger.Warnf("unsupported claim %s ignored", k)
continue continue
} }
label, err := localizer.Localize(&i18n.LocalizeConfig{ label, err := localizer.Localize(&i18n.LocalizeConfig{
DefaultMessage: supportedClaims[k], DefaultMessage: supportedClaims[k],
}) })
@ -304,6 +344,7 @@ func (h *consentHandler) mapRequestedClaims(claims *commonModels.OIDCClaimsReque
h.logger.Warnf("could not localize label for claim %s: %v", k, err) h.logger.Warnf("could not localize label for claim %s: %v", k, err)
label = k label = k
} }
if !known[k] { if !known[k] {
result = append(result, &claimWithLabel{ result = append(result, &claimWithLabel{
Name: k, Name: k,
@ -315,10 +356,11 @@ func (h *consentHandler) mapRequestedClaims(claims *commonModels.OIDCClaimsReque
} }
} }
} }
return result return result
} }
func (h *consentHandler) getSessionData( func (h *ConsentHandler) getSessionData(
r *http.Request, r *http.Request,
info ConsentInformation, info ConsentInformation,
claims *commonModels.OIDCClaimsRequest, claims *commonModels.OIDCClaimsRequest,
@ -329,32 +371,42 @@ func (h *consentHandler) getSessionData(
userInfo := h.GetUserInfoFromClientCertificate(r, payload.Subject) userInfo := h.GetUserInfoFromClientCertificate(r, payload.Subject)
h.fillTokenData(accessTokenData, payload.RequestedScope, claims, info, userInfo) if err := h.fillTokenData(accessTokenData, payload.RequestedScope, claims, info, userInfo); err != nil {
h.fillTokenData(idTokenData, payload.RequestedScope, claims, info, userInfo) return nil, err
}
if err := h.fillTokenData(idTokenData, payload.RequestedScope, claims, info, userInfo); err != nil {
return nil, err
}
return &models.ConsentRequestSession{ return &models.ConsentRequestSession{
AccessToken: accessTokenData, AccessToken: accessTokenData,
IDToken: idTokenData, IDToken: idTokenData,
}, nil }, nil
} }
func (h *consentHandler) fillTokenData( func (h *ConsentHandler) fillTokenData(
m map[string]interface{}, m map[string]interface{},
requestedScope models.StringSlicePipeDelimiter, requestedScope models.StringSlicePipeDelimiter,
claimsRequest *commonModels.OIDCClaimsRequest, claimsRequest *commonModels.OIDCClaimsRequest,
consentInformation ConsentInformation, consentInformation ConsentInformation,
userInfo *UserInfo, userInfo *UserInfo,
) { ) error {
for _, scope := range requestedScope { for _, scope := range requestedScope {
granted := false granted := false
for _, k := range consentInformation.GrantedScopes { for _, k := range consentInformation.GrantedScopes {
if k == scope { if k == scope {
granted = true granted = true
break break
} }
} }
if !granted { if !granted {
continue continue
} }
switch scope { switch scope {
case ScopeEmail: case ScopeEmail:
// email // email
@ -362,7 +414,6 @@ func (h *consentHandler) fillTokenData(
// email_verified Claims. // email_verified Claims.
m[openid.EmailKey] = userInfo.Email m[openid.EmailKey] = userInfo.Email
m[openid.EmailVerifiedKey] = userInfo.EmailVerified m[openid.EmailVerifiedKey] = userInfo.EmailVerified
break
case ScopeProfile: case ScopeProfile:
// profile // profile
// OPTIONAL. This scope value requests access to the // OPTIONAL. This scope value requests access to the
@ -371,25 +422,52 @@ func (h *consentHandler) fillTokenData(
// preferred_username, profile, picture, website, gender, // preferred_username, profile, picture, website, gender,
// birthdate, zoneinfo, locale, and updated_at. // birthdate, zoneinfo, locale, and updated_at.
m[openid.NameKey] = userInfo.GetFullName() m[openid.NameKey] = userInfo.GetFullName()
break
} }
} }
if userInfoClaims := claimsRequest.GetUserInfo(); userInfoClaims != nil { if userInfoClaims := claimsRequest.GetUserInfo(); userInfoClaims != nil {
err := h.parseUserInfoClaims(m, userInfoClaims, consentInformation)
if err != nil {
return err
}
}
return nil
}
func (h *ConsentHandler) parseUserInfoClaims(
m map[string]interface{},
userInfoClaims *commonModels.ClaimElement,
consentInformation ConsentInformation,
) error {
for claimName, claim := range *userInfoClaims { for claimName, claim := range *userInfoClaims {
granted := false granted := false
for _, k := range consentInformation.SelectedClaims { for _, k := range consentInformation.SelectedClaims {
if k == claimName { if k == claimName {
granted = true granted = true
break break
} }
} }
if !granted { if !granted {
continue continue
} }
if claim.WantedValue() != nil {
m[claimName] = *claim.WantedValue() wantedValue, err := claim.WantedValue()
if err != nil {
if !errors.Is(err, commonModels.ErrNoValue) {
return fmt.Errorf("error handling claim: %w", err)
}
}
if wantedValue != "" {
m[claimName] = wantedValue
continue continue
} }
if claim.IsEssential() { if claim.IsEssential() {
h.logger.Warnf( h.logger.Warnf(
"handling for essential claim name %s not implemented", "handling for essential claim name %s not implemented",
@ -402,22 +480,30 @@ func (h *consentHandler) fillTokenData(
) )
} }
} }
}
return nil
} }
func (h *consentHandler) GetUserInfoFromClientCertificate(r *http.Request, subject string) *UserInfo { func (h *ConsentHandler) GetUserInfoFromClientCertificate(r *http.Request, subject string) *UserInfo {
if r.TLS != nil && r.TLS.PeerCertificates != nil && len(r.TLS.PeerCertificates) > 0 { if r.TLS != nil && r.TLS.PeerCertificates != nil && len(r.TLS.PeerCertificates) > 0 {
firstCert := r.TLS.PeerCertificates[0] firstCert := r.TLS.PeerCertificates[0]
var verified bool var verified bool
for _, email := range firstCert.EmailAddresses { for _, email := range firstCert.EmailAddresses {
h.logger.Infof("authenticated with a client certificate for email address %s", email) h.logger.Infof("authenticated with a client certificate for email address %s", email)
if subject == email { if subject == email {
verified = true verified = true
} }
} }
if !verified { if !verified {
h.logger.Warnf("authentication attempt with a wrong certificate that did not contain the requested address %s", subject) h.logger.Warnf(
"authentication attempt with a wrong certificate that did not contain the requested address %s",
subject,
)
return nil return nil
} }
@ -427,10 +513,16 @@ func (h *consentHandler) GetUserInfoFromClientCertificate(r *http.Request, subje
CommonName: firstCert.Subject.CommonName, CommonName: firstCert.Subject.CommonName,
} }
} }
return nil return nil
} }
func NewConsentHandler(ctx context.Context, logger *log.Logger) (*consentHandler, error) { func NewConsentHandler(
logger *log.Logger,
bundle *i18n.Bundle,
messageCatalog *services.MessageCatalog,
adminClient admin.ClientService,
) *ConsentHandler {
consentTemplate := template.Must( consentTemplate := template.Must(
template.ParseFS( template.ParseFS(
ui.Templates, ui.Templates,
@ -438,12 +530,11 @@ func NewConsentHandler(ctx context.Context, logger *log.Logger) (*consentHandler
"templates/consent.gohtml", "templates/consent.gohtml",
)) ))
return &consentHandler{ return &ConsentHandler{
adminClient: ctx.Value(CtxAdminClient).(*admin.Client), adminClient: adminClient,
bundle: services.GetI18nBundle(ctx), bundle: bundle,
consentTemplate: consentTemplate, consentTemplate: consentTemplate,
context: ctx,
logger: logger, logger: logger,
messageCatalog: services.GetMessageCatalog(ctx), messageCatalog: messageCatalog,
}, nil }
} }

19
handlers/doc.go Normal file
View file

@ -0,0 +1,19 @@
/*
Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package handlers provides request handlers.
package handlers

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -24,9 +24,10 @@ import (
"io/fs" "io/fs"
"net/http" "net/http"
"git.cacert.org/oidc_idp/services"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"git.cacert.org/oidc_idp/services"
) )
type errorKey int type errorKey int
@ -62,6 +63,7 @@ func (b *ErrorBucket) serveHTTP(w http.ResponseWriter, r *http.Request) {
), ),
"details": b.errorDetails, "details": b.errorDetails,
}) })
if err != nil { if err != nil {
log.Errorf("error rendering error template: %v", err) log.Errorf("error rendering error template: %v", err)
http.Error( http.Error(
@ -74,92 +76,110 @@ func (b *ErrorBucket) serveHTTP(w http.ResponseWriter, r *http.Request) {
} }
func GetErrorBucket(r *http.Request) *ErrorBucket { func GetErrorBucket(r *http.Request) *ErrorBucket {
return r.Context().Value(errorBucketKey).(*ErrorBucket) if bucket, ok := r.Context().Value(errorBucketKey).(*ErrorBucket); ok {
return bucket
} }
// call this from your application's handler return nil
}
// AddError can be called to add error details from your application's handler.
func (b *ErrorBucket) AddError(details *ErrorDetails) { func (b *ErrorBucket) AddError(details *ErrorDetails) {
b.errorDetails = details b.errorDetails = details
} }
type errorResponseWriter struct { type errorResponseWriter struct {
http.ResponseWriter http.ResponseWriter
ctx context.Context errorBucket *ErrorBucket
statusCode int statusCode int
} }
func (w *errorResponseWriter) WriteHeader(code int) { func (w *errorResponseWriter) WriteHeader(code int) {
w.statusCode = code w.statusCode = code
if code >= 400 {
if code >= http.StatusBadRequest {
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
errorBucket := w.ctx.Value(errorBucketKey).(*ErrorBucket)
if errorBucket != nil && errorBucket.errorDetails == nil { w.errorBucket.AddError(&ErrorDetails{
errorBucket.AddError(&ErrorDetails{
ErrorMessage: http.StatusText(code), ErrorMessage: http.StatusText(code),
}) })
} }
}
w.ResponseWriter.WriteHeader(code) w.ResponseWriter.WriteHeader(code)
} }
func (w *errorResponseWriter) Write(content []byte) (int, error) { func (w *errorResponseWriter) Write(content []byte) (int, error) {
if w.statusCode > 400 { if w.statusCode >= http.StatusBadRequest {
errorBucket := w.ctx.Value(errorBucketKey).(*ErrorBucket) if w.errorBucket.errorDetails.ErrorDetails == nil {
if errorBucket != nil { w.errorBucket.errorDetails.ErrorDetails = make([]string, 0)
if errorBucket.errorDetails.ErrorDetails == nil {
errorBucket.errorDetails.ErrorDetails = make([]string, 0)
}
errorBucket.errorDetails.ErrorDetails = append(
errorBucket.errorDetails.ErrorDetails, string(content),
)
return len(content), nil
}
}
return w.ResponseWriter.Write(content)
} }
func ErrorHandling(handlerContext context.Context, logger *log.Logger, templateFS fs.FS) (func(http.Handler) http.Handler, error) { w.errorBucket.errorDetails.ErrorDetails = append(
w.errorBucket.errorDetails.ErrorDetails, string(content),
)
return len(content), nil
}
code, err := w.ResponseWriter.Write(content)
if err != nil {
return code, fmt.Errorf("error writing response: %w", err)
}
return code, nil
}
func ErrorHandling(
handlerContext context.Context,
logger *log.Logger,
templateFS fs.FS,
) (func(http.Handler) http.Handler, error) {
errorTemplates, err := template.ParseFS( errorTemplates, err := template.ParseFS(
templateFS, templateFS,
"templates/base.gohtml", "templates/base.gohtml",
"templates/errors.gohtml", "templates/errors.gohtml",
) )
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not parse templates: %w", err)
} }
bundle, err := services.GetI18nBundle(handlerContext)
if err != nil {
return nil, fmt.Errorf("could not get i18n bundle: %w", err)
}
messageCatalog, err := services.GetMessageCatalog(handlerContext)
if err != nil {
return nil, fmt.Errorf("could not get message catalog: %w", err)
}
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
errorBucket := &ErrorBucket{ errorBucket := &ErrorBucket{
templates: errorTemplates, templates: errorTemplates,
logger: logger, logger: logger,
bundle: services.GetI18nBundle(handlerContext), bundle: bundle,
messageCatalog: services.GetMessageCatalog(handlerContext), messageCatalog: messageCatalog,
}
ctx := context.WithValue(r.Context(), errorBucketKey, errorBucket)
interCeptingResponseWriter := &errorResponseWriter{
w,
ctx,
http.StatusOK,
} }
next.ServeHTTP( next.ServeHTTP(
interCeptingResponseWriter, &errorResponseWriter{w, errorBucket, http.StatusOK},
r.WithContext(ctx), r.WithContext(context.WithValue(r.Context(), errorBucketKey, errorBucket)),
) )
errorBucket.serveHTTP(w, r) errorBucket.serveHTTP(w, r)
}) })
}, nil }, nil
} }
type errorHandler struct { type ErrorHandler struct {
} }
func (e *errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (e *ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
_, _ = fmt.Fprintf(w, ` _, _ = fmt.Fprintf(w, `
didumm %#v didumm %#v
`, r.URL.Query()) `, r.URL.Query())
} }
func NewErrorHandler() *errorHandler { func NewErrorHandler() *ErrorHandler {
return &errorHandler{} return &ErrorHandler{}
} }

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -19,7 +19,7 @@ package handlers
import ( import (
"bytes" "bytes"
"context" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
@ -27,13 +27,14 @@ import (
"strconv" "strconv"
"time" "time"
"git.cacert.org/oidc_idp/ui"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
"github.com/ory/hydra-client-go/client/admin" "github.com/ory/hydra-client-go/client/admin"
"github.com/ory/hydra-client-go/models" "github.com/ory/hydra-client-go/models"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"git.cacert.org/oidc_idp/ui"
"git.cacert.org/oidc_idp/services" "git.cacert.org/oidc_idp/services"
) )
@ -49,22 +50,30 @@ type templateName string
const ( const (
CertificateLogin templateName = "cert" CertificateLogin templateName = "cert"
NoEmailsInClientCertificate = "no_emails" NoEmailsInClientCertificate templateName = "no_emails"
) )
type loginHandler struct { const TimeoutTen = 10 * time.Second
adminClient *admin.Client
type LoginHandler struct {
adminClient admin.ClientService
bundle *i18n.Bundle bundle *i18n.Bundle
context context.Context
logger *log.Logger logger *log.Logger
templates map[templateName]*template.Template templates map[templateName]*template.Template
messageCatalog *services.MessageCatalog messageCatalog *services.MessageCatalog
} }
func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *LoginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error if r.Method != http.MethodGet && r.Method != http.MethodPost {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
challenge := r.URL.Query().Get("login_challenge") challenge := r.URL.Query().Get("login_challenge")
h.logger.Debugf("received login challenge %s\n", challenge) h.logger.Debugf("received login challenge %s\n", challenge)
accept := r.Header.Get("Accept-Language") accept := r.Header.Get("Accept-Language")
localizer := i18n.NewLocalizer(h.bundle, accept) localizer := i18n.NewLocalizer(h.bundle, accept)
@ -72,11 +81,24 @@ func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if certEmails == nil { if certEmails == nil {
h.renderNoEmailsInClientCertificate(w, localizer) h.renderNoEmailsInClientCertificate(w, localizer)
return return
} }
switch r.Method { if r.Method == http.MethodGet {
case http.MethodGet: h.handleGet(w, r, challenge, certEmails, localizer)
} else {
h.handlePost(w, r, challenge, certEmails, localizer)
}
}
func (h *LoginHandler) handleGet(
w http.ResponseWriter,
r *http.Request,
challenge string,
certEmails []string,
localizer *i18n.Localizer,
) {
loginRequest, err := h.adminClient.GetLoginRequest(admin.NewGetLoginRequestParams().WithLoginChallenge(challenge)) loginRequest, err := h.adminClient.GetLoginRequest(admin.NewGetLoginRequestParams().WithLoginChallenge(challenge))
if err != nil { if err != nil {
h.logger.Warnf("could not get login request for challenge %s: %v", challenge, err) h.logger.Warnf("could not get login request for challenge %s: %v", challenge, err)
@ -85,25 +107,38 @@ func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if errors.As(err, &e) { if errors.As(err, &e) {
w.Header().Set("Location", *e.GetPayload().RedirectTo) w.Header().Set("Location", *e.GetPayload().RedirectTo)
w.WriteHeader(http.StatusGone) w.WriteHeader(http.StatusGone)
return
} return
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }
return
} http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
h.renderRequestForClientCert(w, r, certEmails, localizer, loginRequest)
break return
case http.MethodPost: }
if r.FormValue("use-identity") != "accept" {
h.rejectLogin(w, challenge, localizer) h.renderRequestForClientCert(w, r, certEmails, localizer, loginRequest)
}
func (h *LoginHandler) handlePost(
w http.ResponseWriter,
r *http.Request,
challenge string,
certEmails []string,
localizer *i18n.Localizer,
) {
if r.FormValue("use-identity") != "accept" {
h.rejectLogin(w, challenge, localizer)
return return
} }
var userId string
// perform certificate auth // perform certificate auth
h.logger.Infof("would perform certificate authentication with: %+v", certEmails) h.logger.Infof("would perform certificate authentication with: %+v", certEmails)
userId, err = h.performCertificateLogin(certEmails, r)
userID, err := h.performCertificateLogin(certEmails, r)
if err != nil { if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
@ -114,68 +149,105 @@ func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Acr: string(ClientCertificate), Acr: string(ClientCertificate),
Remember: true, Remember: true,
RememberFor: 0, RememberFor: 0,
Subject: &userId, Subject: &userID,
}).WithTimeout(time.Second * 10)) }).WithTimeout(TimeoutTen))
if err != nil { if err != nil {
h.logger.Errorf("error getting login request: %#v", err) h.logger.Errorf("error getting login request: %#v", err)
var errorDetails *ErrorDetails
switch v := err.(type) { h.fillAcceptLoginRequestErrorBucket(r, err)
case *admin.AcceptLoginRequestNotFound:
payload := v.GetPayload() return
}
w.Header().Add("Location", *loginRequest.GetPayload().RedirectTo)
w.WriteHeader(http.StatusFound)
}
func (h *LoginHandler) fillAcceptLoginRequestErrorBucket(r *http.Request, err error) {
if errorBucket := GetErrorBucket(r); errorBucket != nil {
var (
errorDetails *ErrorDetails
acceptLoginRequestNotFound *admin.AcceptLoginRequestNotFound
)
if errors.As(err, &acceptLoginRequestNotFound) {
payload := acceptLoginRequestNotFound.GetPayload()
errorDetails = &ErrorDetails{ errorDetails = &ErrorDetails{
ErrorMessage: payload.Error, ErrorMessage: payload.Error,
ErrorDetails: []string{payload.ErrorDescription}, ErrorDetails: []string{payload.ErrorDescription},
} }
if v.Payload.StatusCode != 0 {
if acceptLoginRequestNotFound.Payload.StatusCode != 0 {
errorDetails.ErrorCode = strconv.Itoa(int(payload.StatusCode)) errorDetails.ErrorCode = strconv.Itoa(int(payload.StatusCode))
} }
break } else {
default:
errorDetails = &ErrorDetails{ errorDetails = &ErrorDetails{
ErrorMessage: "could not accept login", ErrorMessage: "could not accept login",
ErrorDetails: []string{err.Error()}, ErrorDetails: []string{err.Error()},
} }
} }
GetErrorBucket(r).AddError(errorDetails)
return errorBucket.AddError(errorDetails)
}
w.Header().Add("Location", *loginRequest.GetPayload().RedirectTo)
w.WriteHeader(http.StatusFound)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
} }
} }
func (h *loginHandler) rejectLogin(w http.ResponseWriter, challenge string, localizer *i18n.Localizer) { func (h *LoginHandler) rejectLogin(w http.ResponseWriter, challenge string, localizer *i18n.Localizer) {
rejectLoginRequest, err := h.adminClient.RejectLoginRequest(admin.NewRejectLoginRequestParams().WithLoginChallenge(challenge).WithBody( const Ten = 10 * time.Second
rejectLoginRequest, err := h.adminClient.RejectLoginRequest(
admin.NewRejectLoginRequestParams().WithLoginChallenge(challenge).WithBody(
&models.RejectRequest{ &models.RejectRequest{
ErrorDescription: h.messageCatalog.LookupMessage("LoginDeniedByUser", nil, localizer), ErrorDescription: h.messageCatalog.LookupMessage("LoginDeniedByUser", nil, localizer),
ErrorHint: h.messageCatalog.LookupMessage("HintChooseAnIdentityForAuthentication", nil, localizer), ErrorHint: h.messageCatalog.LookupMessage("HintChooseAnIdentityForAuthentication", nil, localizer),
StatusCode: http.StatusForbidden, StatusCode: http.StatusForbidden,
}, },
).WithTimeout(time.Second * 10)) ).WithTimeout(Ten))
if err != nil { if err != nil {
h.logger.Errorf("error getting reject login request: %#v", err) h.logger.Errorf("error getting reject login request: %#v", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
w.Header().Set("Location", *rejectLoginRequest.GetPayload().RedirectTo) w.Header().Set("Location", *rejectLoginRequest.GetPayload().RedirectTo)
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }
func (h *loginHandler) getEmailAddressesFromClientCertificate(r *http.Request) []string { func (h *LoginHandler) getEmailAddressesFromClientCertificate(r *http.Request) []string {
if r.TLS != nil && r.TLS.PeerCertificates != nil && len(r.TLS.PeerCertificates) > 0 { if r.TLS != nil && r.TLS.PeerCertificates != nil && len(r.TLS.PeerCertificates) > 0 {
firstCert := r.TLS.PeerCertificates[0] firstCert := r.TLS.PeerCertificates[0]
for _, email := range firstCert.EmailAddresses {
h.logger.Infof("authenticated with a client certificate for email address %s", email) if !isClientCertificate(firstCert) {
}
return firstCert.EmailAddresses
}
return nil return nil
} }
func (h *loginHandler) renderRequestForClientCert(w http.ResponseWriter, r *http.Request, emails []string, localizer *i18n.Localizer, loginRequest *admin.GetLoginRequestOK) { for _, email := range firstCert.EmailAddresses {
h.logger.Infof("authenticated with a client certificate for email address %s", email)
}
return firstCert.EmailAddresses
}
return nil
}
func isClientCertificate(cert *x509.Certificate) bool {
for _, ext := range cert.ExtKeyUsage {
if ext == x509.ExtKeyUsageClientAuth {
return true
}
}
return false
}
func (h *LoginHandler) renderRequestForClientCert(
w http.ResponseWriter,
r *http.Request,
emails []string,
localizer *i18n.Localizer,
loginRequest *admin.GetLoginRequestOK,
) {
trans := func(label string) string { trans := func(label string) string {
return h.messageCatalog.LookupMessage(label, nil, localizer) return h.messageCatalog.LookupMessage(label, nil, localizer)
} }
@ -184,7 +256,7 @@ func (h *loginHandler) renderRequestForClientCert(w http.ResponseWriter, r *http
err := h.templates[CertificateLogin].Lookup("base").Execute(rendered, map[string]interface{}{ err := h.templates[CertificateLogin].Lookup("base").Execute(rendered, map[string]interface{}{
"Title": trans("LoginTitle"), "Title": trans("LoginTitle"),
csrf.TemplateTag: csrf.TemplateField(r), csrf.TemplateTag: csrf.TemplateField(r),
"IntroText": template.HTML(h.messageCatalog.LookupMessage( "IntroText": template.HTML(h.messageCatalog.LookupMessage( //nolint:gosec
"CertLoginIntroText", "CertLoginIntroText",
map[string]interface{}{"ClientName": loginRequest.GetPayload().Client.ClientName}, map[string]interface{}{"ClientName": loginRequest.GetPayload().Client.ClientName},
localizer, localizer,
@ -195,27 +267,31 @@ func (h *loginHandler) renderRequestForClientCert(w http.ResponseWriter, r *http
"AcceptLabel": trans("LabelAcceptCertLogin"), "AcceptLabel": trans("LabelAcceptCertLogin"),
"RejectLabel": trans("LabelRejectCertLogin"), "RejectLabel": trans("LabelRejectCertLogin"),
}) })
if err != nil { if err != nil {
h.logger.Error(err) h.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
w.Header().Add("Pragma", "no-cache") w.Header().Add("Pragma", "no-cache")
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
_, _ = w.Write(rendered.Bytes()) _, _ = w.Write(rendered.Bytes())
} }
func (h *loginHandler) performCertificateLogin(emails []string, r *http.Request) (string, error) { func (h *LoginHandler) performCertificateLogin(emails []string, r *http.Request) (string, error) {
requestedEmail := r.PostFormValue("email") requestedEmail := r.PostFormValue("email")
for _, email := range emails { for _, email := range emails {
if email == requestedEmail { if email == requestedEmail {
return email, nil return email, nil
} }
} }
return "", fmt.Errorf("no user found") return "", fmt.Errorf("no user found")
} }
func (h *loginHandler) renderNoEmailsInClientCertificate(w http.ResponseWriter, localizer *i18n.Localizer) { func (h *LoginHandler) renderNoEmailsInClientCertificate(w http.ResponseWriter, localizer *i18n.Localizer) {
trans := func(label string) string { trans := func(label string) string {
return h.messageCatalog.LookupMessage(label, nil, localizer) return h.messageCatalog.LookupMessage(label, nil, localizer)
} }
@ -227,15 +303,20 @@ func (h *loginHandler) renderNoEmailsInClientCertificate(w http.ResponseWriter,
if err != nil { if err != nil {
h.logger.Error(err) h.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
} }
func NewLoginHandler(ctx context.Context, logger *log.Logger) (*loginHandler, error) { func NewLoginHandler(
return &loginHandler{ logger *log.Logger,
adminClient: ctx.Value(CtxAdminClient).(*admin.Client), bundle *i18n.Bundle,
bundle: services.GetI18nBundle(ctx), messageCatalog *services.MessageCatalog,
context: ctx, adminClient admin.ClientService,
) *LoginHandler {
return &LoginHandler{
adminClient: adminClient,
bundle: bundle,
logger: logger, logger: logger,
templates: map[templateName]*template.Template{ templates: map[templateName]*template.Template{
CertificateLogin: template.Must(template.ParseFS( CertificateLogin: template.Must(template.ParseFS(
@ -249,6 +330,6 @@ func NewLoginHandler(ctx context.Context, logger *log.Logger) (*loginHandler, er
"templates/no_email_in_client_certificate.gohtml", "templates/no_email_in_client_certificate.gohtml",
)), )),
}, },
messageCatalog: services.GetMessageCatalog(ctx), messageCatalog: messageCatalog,
}, nil }
} }

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -18,7 +18,6 @@
package handlers package handlers
import ( import (
"context"
"net/http" "net/http"
"time" "time"
@ -26,20 +25,23 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type logoutHandler struct { type LogoutHandler struct {
adminClient *admin.Client adminClient admin.ClientService
logger *log.Logger logger *log.Logger
} }
func (h *logoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
const Ten = 10 * time.Second
challenge := r.URL.Query().Get("logout_challenge") challenge := r.URL.Query().Get("logout_challenge")
h.logger.Debugf("received challenge %s\n", challenge) h.logger.Debugf("received challenge %s\n", challenge)
logoutRequest, err := h.adminClient.GetLogoutRequest( logoutRequest, err := h.adminClient.GetLogoutRequest(
admin.NewGetLogoutRequestParams().WithLogoutChallenge(challenge).WithTimeout(time.Second * 10)) admin.NewGetLogoutRequestParams().WithLogoutChallenge(challenge).WithTimeout(Ten))
if err != nil { if err != nil {
h.logger.Errorf("error getting logout requests: %v", err) h.logger.Errorf("error getting logout requests: %v", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
@ -56,20 +58,20 @@ func (h *logoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }
func NewLogoutHandler(ctx context.Context, logger *log.Logger) *logoutHandler { func NewLogoutHandler(logger *log.Logger, adminClient admin.ClientService) *LogoutHandler {
return &logoutHandler{ return &LogoutHandler{
logger: logger, logger: logger,
adminClient: ctx.Value(CtxAdminClient).(*admin.Client), adminClient: adminClient,
} }
} }
type logoutSuccessHandler struct { type LogoutSuccessHandler struct {
} }
func (l *logoutSuccessHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (l *LogoutSuccessHandler) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
panic("implement me") panic("implement me")
} }
func NewLogoutSuccessHandler() *logoutSuccessHandler { func NewLogoutSuccessHandler() *LogoutSuccessHandler {
return &logoutSuccessHandler{} return &LogoutSuccessHandler{}
} }

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package handlers
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"sync/atomic" "sync/atomic"
@ -28,7 +29,7 @@ import (
type key int type key int
const ( const (
requestIdKey key = iota requestIDKey key = iota
) )
type statusCodeInterceptor struct { type statusCodeInterceptor struct {
@ -45,7 +46,12 @@ func (sci *statusCodeInterceptor) WriteHeader(code int) {
func (sci *statusCodeInterceptor) Write(content []byte) (int, error) { func (sci *statusCodeInterceptor) Write(content []byte) (int, error) {
count, err := sci.ResponseWriter.Write(content) count, err := sci.ResponseWriter.Write(content)
sci.count += count sci.count += count
return count, err
if err != nil {
return count, fmt.Errorf("could not write response: %w", err)
}
return count, nil
} }
func Logging(logger *log.Logger) func(http.Handler) http.Handler { func Logging(logger *log.Logger) func(http.Handler) http.Handler {
@ -53,13 +59,13 @@ func Logging(logger *log.Logger) func(http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
interceptor := &statusCodeInterceptor{w, http.StatusOK, 0} interceptor := &statusCodeInterceptor{w, http.StatusOK, 0}
defer func() { defer func() {
requestId, ok := r.Context().Value(requestIdKey).(string) requestID, ok := r.Context().Value(requestIDKey).(string)
if !ok { if !ok {
requestId = "unknown" requestID = "unknown"
} }
logger.Infof( logger.Infof(
"%s %s \"%s %s\" %d %d \"%s\"", "%s %s \"%s %s\" %d %d \"%s\"",
requestId, requestID,
r.RemoteAddr, r.RemoteAddr,
r.Method, r.Method,
r.URL.Path, r.URL.Path,
@ -73,15 +79,15 @@ func Logging(logger *log.Logger) func(http.Handler) http.Handler {
} }
} }
func Tracing(nextRequestId func() string) func(http.Handler) http.Handler { func Tracing(nextRequestID func() string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestId := r.Header.Get("X-Request-Id") requestID := r.Header.Get("X-Request-Id")
if requestId == "" { if requestID == "" {
requestId = nextRequestId() requestID = nextRequestID()
} }
ctx := context.WithValue(r.Context(), requestIdKey, requestId) ctx := context.WithValue(r.Context(), requestIDKey, requestID)
w.Header().Set("X-Request-Id", requestId) w.Header().Set("X-Request-Id", requestID)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
@ -93,8 +99,10 @@ func NewHealthHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if atomic.LoadInt32(&Healthy) == 1 { if atomic.LoadInt32(&Healthy) == 1 {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
return return
} }
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
}) })
} }

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -26,7 +26,8 @@ import (
func EnableHSTS() func(http.Handler) http.Handler { func EnableHSTS() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Strict-Transport-Security", fmt.Sprintf("max-age=%d", int((time.Hour*24*180).Seconds()))) const Days180 = 180
w.Header().Set("Strict-Transport-Security", fmt.Sprintf("max-age=%d", int((time.Hour*24*Days180).Seconds())))
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -15,31 +15,33 @@
limitations under the License. limitations under the License.
*/ */
/* // Package models contains data models
This package contains data models.
*/
package models package models
// An individual claim request. import "errors"
var ErrNoValue = errors.New("value not found")
// IndividualClaimsRequest represents an individual claim request.
// //
// Specification // # Specification
// //
// https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests // https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests
type IndividualClaimRequest map[string]interface{} type IndividualClaimsRequest map[string]interface{}
// ClaimElement represents a claim element // ClaimElement represents a claim element
type ClaimElement map[string]*IndividualClaimRequest type ClaimElement map[string]*IndividualClaimsRequest
// OIDCClaimsRequest the claims request parameter sent with the authorization request. // OIDCClaimsRequest the claims request parameter sent with the authorization request.
// //
// Specification // # Specification
// //
// https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter // https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter
type OIDCClaimsRequest map[string]ClaimElement type OIDCClaimsRequest map[string]ClaimElement
// GetUserInfo extracts the userinfo claim element from the request. // GetUserInfo extracts the userinfo claim element from the request.
// //
// Specification // # Specification
// //
// https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims // https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims
// //
@ -56,12 +58,13 @@ func (r OIDCClaimsRequest) GetUserInfo() *ClaimElement {
if userInfo, ok := r["userinfo"]; ok { if userInfo, ok := r["userinfo"]; ok {
return &userInfo return &userInfo
} }
return nil return nil
} }
// GetIDToken extracts the id_token claim element from the request. // GetIDToken extracts the id_token claim element from the request.
// //
// Specification // # Specification
// //
// https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims // https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims
// //
@ -75,12 +78,13 @@ func (r OIDCClaimsRequest) GetIDToken() *ClaimElement {
if idToken, ok := r["id_token"]; ok { if idToken, ok := r["id_token"]; ok {
return &idToken return &idToken
} }
return nil return nil
} }
// Checks whether the individual claim is an essential claim. // IsEssential checks whether the individual claim is an essential claim.
// //
// Specification // # Specification
// //
// https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests // https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests
// //
@ -99,20 +103,23 @@ func (r OIDCClaimsRequest) GetIDToken() *ClaimElement {
// specific task requested by the End-User. // specific task requested by the End-User.
// //
// Note that even if the Claims are not available because the End-User did not // Note that even if the Claims are not available because the End-User did not
// authorize their release or they are not present, the Authorization Server // authorize their release, or they are not present, the Authorization Server
// MUST NOT generate an error when Claims are not returned, whether they are // MUST NOT generate an error when Claims are not returned, whether they are
// Essential or Voluntary, unless otherwise specified in the description of // Essential or Voluntary, unless otherwise specified in the description of
// the specific claim. // the specific claim.
func (i IndividualClaimRequest) IsEssential() bool { func (i IndividualClaimsRequest) IsEssential() bool {
if essential, ok := i["essential"]; ok { if essential, ok := i["essential"]; ok {
return essential.(bool) if e, ok := essential.(bool); ok {
return e
} }
}
return false return false
} }
// Returns the wanted value for an individual claim request. // WantedValue returns the wanted value for an individual claim request.
// //
// Specification // # Specification
// //
// https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests // https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests
// //
@ -126,18 +133,20 @@ func (i IndividualClaimRequest) IsEssential() bool {
// value for the Claim being requested. Definitions of individual Claims can // value for the Claim being requested. Definitions of individual Claims can
// include requirements on how and whether the value qualifier is to be used // include requirements on how and whether the value qualifier is to be used
// when requesting that Claim. // when requesting that Claim.
func (i IndividualClaimRequest) WantedValue() *string { func (i IndividualClaimsRequest) WantedValue() (string, error) {
if value, ok := i["value"]; ok { if value, ok := i["value"]; ok {
valueString := value.(string) if valueString, ok := value.(string); ok {
return &valueString return valueString, nil
} }
return nil
} }
// Get the allowed values for an individual claim request that specifies return "", ErrNoValue
}
// AllowedValues gets the allowed values for an individual claim request that specifies
// a values field. // a values field.
// //
// Specification // # Specification
// //
// https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests // https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests
// //
@ -154,17 +163,20 @@ func (i IndividualClaimRequest) WantedValue() *string {
// being requested. Definitions of individual Claims can include requirements // being requested. Definitions of individual Claims can include requirements
// on how and whether the values qualifier is to be used when requesting that // on how and whether the values qualifier is to be used when requesting that
// Claim. // Claim.
func (i IndividualClaimRequest) AllowedValues() []string { func (i IndividualClaimsRequest) AllowedValues() []string {
if values, ok := i["values"]; ok { if values, ok := i["values"]; ok {
return values.([]string) if v, ok := values.([]string); ok {
return v
} }
}
return nil return nil
} }
// OpenIDConfiguration contains the parts of the OpenID discovery information // OpenIDConfiguration contains the parts of the OpenID discovery information
// that are relevant for us. // that are relevant for us.
// //
// Specifications // # Specifications
// //
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
// //
@ -174,7 +186,7 @@ type OpenIDConfiguration struct {
AuthorizationEndpoint string `json:"authorization_endpoint"` AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"` TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"` UserInfoEndpoint string `json:"userinfo_endpoint"`
JwksUri string `json:"jwks_uri"` JwksURI string `json:"jwks_uri"`
RegistrationEndpoint string `json:"registration_endpoint"` RegistrationEndpoint string `json:"registration_endpoint"`
ScopesSupported []string `json:"scopes_supported"` ScopesSupported []string `json:"scopes_supported"`
EndSessionEndpoint string `json:"end_session_endpoint"` EndSessionEndpoint string `json:"end_session_endpoint"`

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -39,44 +39,48 @@ func ConfigureApplication(
) (*koanf.Koanf, error) { ) (*koanf.Koanf, error) {
f := pflag.NewFlagSet("config", pflag.ContinueOnError) f := pflag.NewFlagSet("config", pflag.ContinueOnError)
f.Usage = func() { f.Usage = func() {
fmt.Println(f.FlagUsages()) fmt.Println(f.FlagUsages()) //nolint:forbidigo
os.Exit(0) os.Exit(0)
} }
f.StringSlice( f.StringSlice(
"conf", "conf",
[]string{fmt.Sprintf("%s.toml", strings.ToLower(appName))}, []string{fmt.Sprintf("%s.toml", strings.ToLower(appName))},
"path to one or more .toml files", "path to one or more .toml files",
) )
var err error
if err = f.Parse(os.Args[1:]); err != nil { if err := f.Parse(os.Args[1:]); err != nil {
logger.Fatal(err) logger.Fatal(err)
} }
config := koanf.New(".") config := koanf.New(".")
_ = config.Load(confmap.Provider(defaultConfig, "."), nil) _ = config.Load(confmap.Provider(defaultConfig, "."), nil)
cFiles, _ := f.GetStringSlice("conf") cFiles, _ := f.GetStringSlice("conf")
for _, c := range cFiles { for _, c := range cFiles {
if err := config.Load(file.Provider(c), toml.Parser()); err != nil { if err := config.Load(file.Provider(c), toml.Parser()); err != nil {
logger.Fatalf("error loading config file: %s", err) logger.Fatalf("error loading config file: %s", err)
} }
} }
if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil { if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil {
logger.Fatalf("error loading configuration: %s", err) logger.Fatalf("error loading configuration: %s", err)
} }
if err := config.Load( if err := config.Load(
file.Provider("resource_app.toml"), file.Provider("resource_app.toml"),
toml.Parser(), toml.Parser(),
); err != nil && !os.IsNotExist(err) { ); err != nil && !os.IsNotExist(err) {
logrus.Fatalf("error loading config: %v", err) logger.Fatalf("error loading config: %v", err)
} }
prefix := fmt.Sprintf("%s_", strings.ToUpper(appName)) prefix := fmt.Sprintf("%s_", strings.ToUpper(appName))
if err := config.Load(env.Provider(prefix, ".", func(s string) string { if err := config.Load(env.Provider(prefix, ".", func(s string) string {
return strings.Replace(strings.ToLower( return strings.ReplaceAll(strings.ToLower(strings.TrimPrefix(s, prefix)), "_", ".")
strings.TrimPrefix(s, prefix)), "_", ".", -1)
}), nil); err != nil { }), nil); err != nil {
logrus.Fatalf("error loading config: %v", err) logger.Fatalf("error loading config: %v", err)
} }
return config, err
return config, nil
} }

View file

@ -1,6 +1,6 @@
/* /*
Copyright 2020, 2021 Jan Dittberner Copyright 2020-2023 CAcert Inc.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package services
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -28,7 +29,7 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
) )
func AddMessages(ctx context.Context) { func AddMessages(catalog *MessageCatalog) error {
messages := make(map[string]*i18n.Message) messages := make(map[string]*i18n.Message)
messages["unknown"] = &i18n.Message{ messages["unknown"] = &i18n.Message{
ID: "ErrorUnknown", ID: "ErrorUnknown",
@ -48,11 +49,13 @@ func AddMessages(ctx context.Context) {
} }
messages["IntroConsentRequested"] = &i18n.Message{ messages["IntroConsentRequested"] = &i18n.Message{
ID: "IntroConsentRequested", ID: "IntroConsentRequested",
Other: "The <strong>{{ .client }}</strong> application requested your consent for the following set of permissions:", Other: "The <strong>{{ .client }}</strong> application requested your consent for the following set of " +
"permissions:",
} }
messages["IntroConsentMoreInformation"] = &i18n.Message{ messages["IntroConsentMoreInformation"] = &i18n.Message{
ID: "IntroConsentMoreInformation", ID: "IntroConsentMoreInformation",
Other: "You can find more information about <strong>{{ .client }}</strong> at <a href=\"{{ .clientLink }}\">its description page</a>.", Other: "You can find more information about <strong>{{ .client }}</strong> at " +
"<a href=\"{{ .clientLink }}\">its description page</a>.",
} }
messages["ClaimsInformation"] = &i18n.Message{ messages["ClaimsInformation"] = &i18n.Message{
ID: "ClaimsInformation", ID: "ClaimsInformation",
@ -69,7 +72,8 @@ func AddMessages(ctx context.Context) {
messages["EmailChoiceText"] = &i18n.Message{ messages["EmailChoiceText"] = &i18n.Message{
ID: "EmailChoiceText", ID: "EmailChoiceText",
One: "You have presented a valid client certificate for the following email address:", One: "You have presented a valid client certificate for the following email address:",
Other: "You have presented a valid client certificate for multiple email addresses. Please choose which one you want to present to the application:", Other: "You have presented a valid client certificate for multiple email addresses. " +
"Please choose which one you want to present to the application:",
} }
messages["LoginTitle"] = &i18n.Message{ messages["LoginTitle"] = &i18n.Message{
ID: "LoginTitle", ID: "LoginTitle",
@ -97,7 +101,10 @@ func AddMessages(ctx context.Context) {
ID: "HintChooseAnIdentityForAuthentication", ID: "HintChooseAnIdentityForAuthentication",
Other: "Choose an identity for authentication.", Other: "Choose an identity for authentication.",
} }
GetMessageCatalog(ctx).AddMessages(messages)
catalog.AddMessages(messages)
return nil
} }
type contextKey int type contextKey int
@ -118,17 +125,21 @@ func (m *MessageCatalog) AddMessages(messages map[string]*i18n.Message) {
} }
} }
func (m *MessageCatalog) LookupErrorMessage(tag string, field string, value interface{}, localizer *i18n.Localizer) string { func (m *MessageCatalog) LookupErrorMessage(tag, field string, value interface{}, localizer *i18n.Localizer) string {
var message *i18n.Message var message *i18n.Message
message, ok := m.messages[fmt.Sprintf("%s-%s", field, tag)] message, ok := m.messages[fmt.Sprintf("%s-%s", field, tag)]
if !ok { if !ok {
m.logger.Infof("no specific error message %s-%s", field, tag) m.logger.Infof("no specific error message %s-%s", field, tag)
message, ok = m.messages[tag] message, ok = m.messages[tag]
if !ok { if !ok {
m.logger.Infof("no specific error message %s", tag) m.logger.Infof("no specific error message %s", tag)
message, ok = m.messages["unknown"] message, ok = m.messages["unknown"]
if !ok { if !ok {
m.logger.Warnf("no default translation found") m.logger.Warnf("no default translation found")
return tag return tag
} }
} }
@ -142,38 +153,41 @@ func (m *MessageCatalog) LookupErrorMessage(tag string, field string, value inte
}) })
if err != nil { if err != nil {
m.logger.Error(err) m.logger.Error(err)
return tag return tag
} }
return translation return translation
} }
func (m *MessageCatalog) LookupMessage(id string, templateData map[string]interface{}, localizer *i18n.Localizer) string { func (m *MessageCatalog) LookupMessage(
id string,
templateData map[string]interface{},
localizer *i18n.Localizer,
) string {
if message, ok := m.messages[id]; ok { if message, ok := m.messages[id]; ok {
translation, err := localizer.Localize(&i18n.LocalizeConfig{ translation, err := localizer.Localize(&i18n.LocalizeConfig{
DefaultMessage: message, DefaultMessage: message,
TemplateData: templateData, TemplateData: templateData,
}) })
if err != nil { if err != nil {
switch err.(type) { return m.handleLocalizeError(id, translation, err)
case *i18n.MessageNotFoundErr:
m.logger.Warnf("message %s not found: %v", id, err)
if translation != "" {
return translation
}
break
default:
m.logger.Error(err)
}
return id
}
return translation
} else {
m.logger.Warnf("no translation found for %s", id)
return id
}
} }
func (m *MessageCatalog) LookupMessagePlural(id string, templateData map[string]interface{}, localizer *i18n.Localizer, count int) string { return translation
}
m.logger.Warnf("no translation found for %s", id)
return id
}
func (m *MessageCatalog) LookupMessagePlural(
id string,
templateData map[string]interface{},
localizer *i18n.Localizer,
count int,
) string {
if message, ok := m.messages[id]; ok { if message, ok := m.messages[id]; ok {
translation, err := localizer.Localize(&i18n.LocalizeConfig{ translation, err := localizer.Localize(&i18n.LocalizeConfig{
DefaultMessage: message, DefaultMessage: message,
@ -181,38 +195,47 @@ func (m *MessageCatalog) LookupMessagePlural(id string, templateData map[string]
PluralCount: count, PluralCount: count,
}) })
if err != nil { if err != nil {
switch err.(type) { return m.handleLocalizeError(id, translation, err)
case *i18n.MessageNotFoundErr: }
return translation
}
m.logger.Warnf("no translation found for %s", id)
return id
}
func (m *MessageCatalog) handleLocalizeError(id string, translation string, err error) string {
var messageNotFound *i18n.MessageNotFoundErr
if errors.As(err, &messageNotFound) {
m.logger.Warnf("message %s not found: %v", id, err) m.logger.Warnf("message %s not found: %v", id, err)
if translation != "" { if translation != "" {
return translation return translation
} }
break } else {
default:
m.logger.Error(err) m.logger.Error(err)
} }
return id return id
} }
return translation
} else {
m.logger.Warnf("no translation found for %s", id)
return id
}
}
func InitI18n(ctx context.Context, logger *log.Logger, languages []string) context.Context { func InitI18n(logger *log.Logger, languages []string) (*i18n.Bundle, *MessageCatalog) {
bundle := i18n.NewBundle(language.English) bundle := i18n.NewBundle(language.English)
bundle.RegisterUnmarshalFunc("toml", toml.Unmarshal) bundle.RegisterUnmarshalFunc("toml", toml.Unmarshal)
for _, lang := range languages { for _, lang := range languages {
_, err := bundle.LoadMessageFile(fmt.Sprintf("active.%s.toml", lang)) _, err := bundle.LoadMessageFile(fmt.Sprintf("active.%s.toml", lang))
if err != nil { if err != nil {
logger.Warnln("message bundle de.toml not found") logger.Warnf("message bundle %s.toml not found", lang)
} }
} }
catalog := initMessageCatalog(logger) catalog := initMessageCatalog(logger)
ctx = context.WithValue(ctx, ctxI18nBundle, bundle)
ctx = context.WithValue(ctx, ctxI18nCatalog, catalog) return bundle, catalog
return ctx
} }
func initMessageCatalog(logger *log.Logger) *MessageCatalog { func initMessageCatalog(logger *log.Logger) *MessageCatalog {
@ -221,13 +244,22 @@ func initMessageCatalog(logger *log.Logger) *MessageCatalog {
ID: "ErrorTitle", ID: "ErrorTitle",
Other: "An error has occurred", Other: "An error has occurred",
} }
return &MessageCatalog{messages: messages, logger: logger} return &MessageCatalog{messages: messages, logger: logger}
} }
func GetI18nBundle(ctx context.Context) *i18n.Bundle { func GetI18nBundle(ctx context.Context) (*i18n.Bundle, error) {
return ctx.Value(ctxI18nBundle).(*i18n.Bundle) if b, ok := ctx.Value(ctxI18nBundle).(*i18n.Bundle); ok {
return b, nil
} }
func GetMessageCatalog(ctx context.Context) *MessageCatalog { return nil, errors.New("context value is not a Bundle")
return ctx.Value(ctxI18nCatalog).(*MessageCatalog) }
func GetMessageCatalog(ctx context.Context) (*MessageCatalog, error) {
if c, ok := ctx.Value(ctxI18nCatalog).(*MessageCatalog); ok {
return c, nil
}
return nil, errors.New("context value is not a MessageCatalog")
} }