Switch from logrus to log/slog

Use log/slog from the Go standard library to reduce dependencies.
This commit is contained in:
Jan Dittberner 2024-05-12 12:02:27 +02:00
parent 291c1857c6
commit c4724723b6
17 changed files with 194 additions and 172 deletions

View file

@ -23,16 +23,15 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"os" "os"
"time" "time"
"code.cacert.org/cacert/oidc-demo-app/ui"
"github.com/knadh/koanf" "github.com/knadh/koanf"
"github.com/knadh/koanf/parsers/toml" "github.com/knadh/koanf/parsers/toml"
"github.com/knadh/koanf/providers/confmap" "github.com/knadh/koanf/providers/confmap"
log "github.com/sirupsen/logrus"
"code.cacert.org/cacert/oidc-demo-app/ui"
"code.cacert.org/cacert/oidc-demo-app/internal/handlers" "code.cacert.org/cacert/oidc-demo-app/internal/handlers"
"code.cacert.org/cacert/oidc-demo-app/internal/services" "code.cacert.org/cacert/oidc-demo-app/internal/services"
@ -85,44 +84,58 @@ func (f *StaticFileInfoWrapper) ModTime() time.Time {
} }
func main() { func main() {
logger := log.New() var (
logLevel = new(slog.LevelVar)
config, err := services.ConfigureApplication( logHandler slog.Handler
logger, logger *slog.Logger
"RESOURCE_APP",
services.DefaultConfiguration,
) )
logHandler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: logLevel})
logger = slog.New(logHandler)
slog.SetDefault(logger)
config, err := services.ConfigureApplication("RESOURCE_APP", services.DefaultConfiguration)
if err != nil { if err != nil {
log.Fatalf("error loading configuration: %v", err) logger.Error("error loading configuration", "error", err)
os.Exit(1)
} }
oidcServer := config.MustString("oidc.server") oidcServer := config.MustString("oidc.server")
oidcClientID := config.MustString("oidc.client-id") oidcClientID := config.MustString("oidc.client-id")
oidcClientSecret := config.MustString("oidc.client-secret") oidcClientSecret := config.MustString("oidc.client-secret")
if level := config.String("log.level"); level != "" { if level := config.Bytes("log.level"); level != nil {
logLevel, err := log.ParseLevel(level) if err := logLevel.UnmarshalText(level); err != nil {
if err != nil { logger.Error("could not parse log level", "error", err)
logger.WithError(err).Fatal("could not parse log level") os.Exit(1)
} }
logger.SetLevel(logLevel) slog.SetLogLoggerLevel(logLevel.Level())
} }
if config.Bool("log.json") { if config.Bool("log.json") {
logger.SetFormatter(&log.JSONFormatter{}) logHandler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: logLevel})
logger = slog.New(logHandler)
slog.SetDefault(logger)
} }
logger.WithFields(log.Fields{ logLogger := slog.NewLogLogger(logger.Handler(), logLevel.Level())
"version": version, "commit": commit, "date": date,
}).Info("Starting CAcert OpenID Connect demo application") logger.Info(
logger.Infoln("Server is starting") "Starting CAcert OpenID Connect demo application",
"version", version, "commit", commit, "date", date,
)
logger.Info("Server is starting")
bundle, catalog := services.InitI18n(logger, config.Strings("i18n.languages")) bundle, catalog := services.InitI18n(logger, config.Strings("i18n.languages"))
services.AddMessages(catalog) services.AddMessages(catalog)
tlsClientConfig := getTLSConfig(config) tlsClientConfig, err := getTLSConfig(config)
if err != nil {
logger.Error("error loading tls config", "error", err)
os.Exit(1)
}
apiTransport := &http.Transport{TLSClientConfig: tlsClientConfig} apiTransport := &http.Transport{TLSClientConfig: tlsClientConfig}
apiClient := &http.Client{Transport: apiTransport} apiClient := &http.Client{Transport: apiTransport}
@ -134,37 +147,50 @@ func main() {
APIClient: apiClient, APIClient: apiClient,
}) })
if err != nil { if err != nil {
logger.WithError(err).Fatal("OpenID Connect discovery failed") logger.Error("OpenID Connect discovery failed", "error", err)
os.Exit(1)
} }
sessionPath, sessionAuthKey, sessionEncKey := configureSessionParameters(config) sessionPath, sessionAuthKey, sessionEncKey, err := configureSessionParameters(logger, config)
services.InitSessionStore(logger, sessionPath, sessionAuthKey, sessionEncKey) if err := services.InitSessionStore(logger, sessionPath, sessionAuthKey, sessionEncKey); err != nil {
logger.Error("could not initialize session store", "error", err)
authMiddleware := handlers.Authenticate(oidcInfo.OAuth2Config) os.Exit(1)
}
authMiddleware := handlers.Authenticate(logger, oidcInfo.OAuth2Config)
publicURL := buildPublicURL(config.MustString("server.name"), config.MustInt("server.port")) publicURL := buildPublicURL(config.MustString("server.name"), config.MustInt("server.port"))
tokenInfoService, err := services.InitTokenInfoService(logger, oidcInfo) tokenInfoService, err := services.InitTokenInfoService(logger, oidcInfo)
if err != nil { if err != nil {
logger.WithError(err).Fatal("could not initialize token info service") logger.Error("could not initialize token info service", "error", err)
os.Exit(1)
} }
indexHandler, err := handlers.NewIndexHandler(logger, bundle, catalog, oidcInfo, publicURL, tokenInfoService) indexHandler, err := handlers.NewIndexHandler(logger, bundle, catalog, oidcInfo, publicURL, tokenInfoService)
if err != nil { if err != nil {
logger.WithError(err).Fatal("could not initialize index handler") logger.Error("could not initialize index handler", "error", err)
os.Exit(1)
} }
protectedResource, err := handlers.NewProtectedResourceHandler( protectedResource, err := handlers.NewProtectedResourceHandler(
logger, bundle, catalog, oidcInfo, publicURL, tokenInfoService, logger, bundle, catalog, oidcInfo, publicURL, tokenInfoService,
) )
if err != nil { if err != nil {
logger.WithError(err).Fatal("could not initialize protected resource handler") logger.Error("could not initialize protected resource handler", "error", err)
} }
callbackHandler := handlers.NewCallbackHandler(logger, oidcInfo.KeySet, oidcInfo.OAuth2Config) callbackHandler := handlers.NewCallbackHandler(logger, oidcInfo.KeySet, oidcInfo.OAuth2Config)
afterLogoutHandler := handlers.NewAfterLogoutHandler(logger) afterLogoutHandler := handlers.NewAfterLogoutHandler(logger)
staticFiles := staticFileHandler(logger) staticFiles, err := staticFileHandler()
if err != nil {
logger.Error("could not initialize static file handler", "error", err)
os.Exit(1)
}
router := http.NewServeMux() router := http.NewServeMux()
router.Handle("/", indexHandler) router.Handle("/", indexHandler)
@ -182,12 +208,13 @@ func main() {
} }
tracing := handlers.Tracing(nextRequestID) tracing := handlers.Tracing(nextRequestID)
logging := handlers.Logging(logger) logging := handlers.Logging(logLogger)
hsts := handlers.EnableHSTS() hsts := handlers.EnableHSTS()
errorMiddleware, err := handlers.ErrorHandling(logger, bundle, catalog) errorMiddleware, err := handlers.ErrorHandling(logger, bundle, catalog)
if err != nil { if err != nil {
logger.WithError(err).Fatal("could not initialize request error handling") logger.Error("could not initialize request error handling", "error", err)
os.Exit(1)
} }
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
@ -204,13 +231,16 @@ func main() {
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
} }
handlers.StartApplication(context.Background(), logger, server, publicURL, config) if err := handlers.StartApplication(context.Background(), logger, server, publicURL, config); err != nil {
logger.Error("could not start application", "error", err)
os.Exit(1)
}
} }
func staticFileHandler(logger *log.Logger) func(w http.ResponseWriter, r *http.Request) { func staticFileHandler() (func(w http.ResponseWriter, r *http.Request), error) {
stat, err := os.Stat(os.Args[0]) stat, err := os.Stat(os.Args[0])
if err != nil { if err != nil {
logger.WithError(err).Fatal("could not use stat on binary") return nil, fmt.Errorf("could not use stat on binary: %w", err)
} }
fileServer := http.FileServer(&StaticFSWrapper{FileSystem: http.FS(ui.Static), ModTime: stat.ModTime()}) fileServer := http.FileServer(&StaticFSWrapper{FileSystem: http.FS(ui.Static), ModTime: stat.ModTime()})
@ -223,10 +253,10 @@ func staticFileHandler(logger *log.Logger) func(w http.ResponseWriter, r *http.R
fileServer.ServeHTTP(w, r) fileServer.ServeHTTP(w, r)
} }
return staticFiles return staticFiles, nil
} }
func getTLSConfig(config *koanf.Koanf) *tls.Config { func getTLSConfig(config *koanf.Koanf) (*tls.Config, error) {
tlsClientConfig := &tls.Config{ tlsClientConfig := &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
@ -237,14 +267,14 @@ func getTLSConfig(config *koanf.Koanf) *tls.Config {
pemBytes, err := os.ReadFile(rootCAFile) pemBytes, err := os.ReadFile(rootCAFile)
if err != nil { if err != nil {
log.Fatalf("could not read CA certificate file: %v", err) return nil, fmt.Errorf("could not read CA certificate file: %w", err)
} }
caCertPool.AppendCertsFromPEM(pemBytes) caCertPool.AppendCertsFromPEM(pemBytes)
tlsClientConfig.RootCAs = caCertPool tlsClientConfig.RootCAs = caCertPool
} }
return tlsClientConfig return tlsClientConfig, nil
} }
func buildPublicURL(hostname string, port int) string { func buildPublicURL(hostname string, port int) string {
@ -257,28 +287,36 @@ func buildPublicURL(hostname string, port int) string {
return fmt.Sprintf("https://%s", hostname) return fmt.Sprintf("https://%s", hostname)
} }
func configureSessionParameters(config *koanf.Koanf) (string, []byte, []byte) { func configureSessionParameters(logger *slog.Logger, config *koanf.Koanf) (string, []byte, []byte, error) {
sessionPath := config.MustString("session.path") sessionPath := config.MustString("session.path")
sessionAuthKey, err := base64.StdEncoding.DecodeString(config.String("session.auth-key")) sessionAuthKey, err := base64.StdEncoding.DecodeString(config.String("session.auth-key"))
if err != nil { if err != nil {
log.WithError(err).Fatal("could not decode session auth key") return "", nil, nil, fmt.Errorf("could not decode session authentication key: %w", err)
} }
sessionEncKey, err := base64.StdEncoding.DecodeString(config.String("session.enc-key")) sessionEncKey, err := base64.StdEncoding.DecodeString(config.String("session.enc-key"))
if err != nil { if err != nil {
log.WithError(err).Fatal("could not decode session encryption key") return "", nil, nil, fmt.Errorf("could not decode session encryption key: %w", err)
} }
generated := false generated := false
if len(sessionAuthKey) != sessionAuthKeyLength { if len(sessionAuthKey) != sessionAuthKeyLength {
sessionAuthKey = services.GenerateKey(sessionAuthKeyLength) sessionAuthKey, err = services.GenerateKey(sessionAuthKeyLength)
if err != nil {
return "", nil, nil, fmt.Errorf("could not generate session authentication key: %w", err)
}
generated = true generated = true
} }
if len(sessionEncKey) != sessionKeyLength { if len(sessionEncKey) != sessionKeyLength {
sessionEncKey = services.GenerateKey(sessionKeyLength) sessionEncKey, err = services.GenerateKey(sessionKeyLength)
if err != nil {
return "", nil, nil, fmt.Errorf("could not generate session encryption key: %w", err)
}
generated = true generated = true
} }
@ -290,11 +328,12 @@ func configureSessionParameters(config *koanf.Koanf) (string, []byte, []byte) {
tomlData, err := config.Marshal(toml.Parser()) tomlData, err := config.Marshal(toml.Parser())
if err != nil { if err != nil {
log.WithError(err).Fatal("could not encode session config") return "", nil, nil, fmt.Errorf("could not encode session configuration: %w", err)
} }
log.Infof("put the following in your resource_app.toml:\n%s", string(tomlData)) logger.Info("put the following in your resource_app.toml")
fmt.Print(string(tomlData)) //nolint:forbidigo
} }
return sessionPath, sessionAuthKey, sessionEncKey return sessionPath, sessionAuthKey, sessionEncKey, nil
} }

1
go.mod
View file

@ -8,7 +8,6 @@ require (
github.com/knadh/koanf v1.5.0 github.com/knadh/koanf v1.5.0
github.com/lestrrat-go/jwx v1.2.29 github.com/lestrrat-go/jwx v1.2.29
github.com/nicksnyder/go-i18n/v2 v2.4.0 github.com/nicksnyder/go-i18n/v2 v2.4.0
github.com/sirupsen/logrus v1.9.3
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
golang.org/x/oauth2 v0.20.0 golang.org/x/oauth2 v0.20.0
golang.org/x/text v0.15.0 golang.org/x/text v0.15.0

3
go.sum
View file

@ -252,8 +252,6 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -372,7 +370,6 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View file

@ -18,19 +18,18 @@ limitations under the License.
package handlers package handlers
import ( import (
"log/slog"
"net/http" "net/http"
"github.com/sirupsen/logrus"
) )
type AfterLogoutHandler struct { type AfterLogoutHandler struct {
logger *logrus.Logger logger *slog.Logger
} }
func (h *AfterLogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *AfterLogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
session, err := GetSession(r) session, err := GetSession(r)
if err != nil { if err != nil {
h.logger.WithError(err).Error("could not get session") h.logger.Error("could not get session", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -39,13 +38,13 @@ func (h *AfterLogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
session.Options.MaxAge = -1 session.Options.MaxAge = -1
if err = session.Save(r, w); err != nil { if err = session.Save(r, w); err != nil {
h.logger.WithError(err).Error("could not save session") h.logger.Error("could not save session", "error", err)
} }
w.Header().Set("Location", "/") w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }
func NewAfterLogoutHandler(logger *logrus.Logger) *AfterLogoutHandler { func NewAfterLogoutHandler(logger *slog.Logger) *AfterLogoutHandler {
return &AfterLogoutHandler{logger: logger} return &AfterLogoutHandler{logger: logger}
} }

View file

@ -20,6 +20,7 @@ package handlers
import ( import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
@ -34,12 +35,13 @@ const (
sessionName = "resource_app" sessionName = "resource_app"
) )
func Authenticate(oauth2Config *oauth2.Config) func(http.Handler) http.Handler { func Authenticate(logger *slog.Logger, oauth2Config *oauth2.Config) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := GetSession(r) session, err := GetSession(r)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) logger.ErrorContext(r.Context(), "failed to get session", "error", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
@ -53,14 +55,21 @@ func Authenticate(oauth2Config *oauth2.Config) func(http.Handler) http.Handler {
session.Values[services.SessionRedirectTarget] = r.URL.String() session.Values[services.SessionRedirectTarget] = r.URL.String()
if err = session.Save(r, w); err != nil { if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) logger.ErrorContext(r.Context(), "failed to save session", "error", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
authURL := oauth2Config.AuthCodeURL( state, err := services.GenerateKey(oauth2RedirectStateLength)
base64.URLEncoding.EncodeToString(services.GenerateKey(oauth2RedirectStateLength)), if err != nil {
logger.ErrorContext(
r.Context(), "failed to generate state for starting OIDC flow", "error", err,
) )
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
authURL := oauth2Config.AuthCodeURL(base64.URLEncoding.EncodeToString(state))
w.Header().Set("Location", authURL) w.Header().Set("Location", authURL)
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
@ -68,28 +77,9 @@ func Authenticate(oauth2Config *oauth2.Config) func(http.Handler) http.Handler {
} }
} }
/*
func getRequestedClaims(logger *log.Logger) string {
claims := make(models.OIDCClaimsRequest)
claims["userinfo"] = make(models.ClaimElement)
essentialItem := make(models.IndividualClaimRequest)
essentialItem["essential"] = true
claims["userinfo"]["https://auth.cacert.org/groups"] = &essentialItem
target := make([]byte, 0)
buf := bytes.NewBuffer(target)
enc := json.NewEncoder(buf)
if err := enc.Encode(claims); err != nil {
logger.WithError(err).Warn("could not encode claims request parameter")
}
return buf.String()
}
*/
func GetSession(r *http.Request) (*sessions.Session, error) { func GetSession(r *http.Request) (*sessions.Session, error) {
session, err := services.GetSessionStore().Get(r, sessionName) session, err := services.GetSessionStore().Get(r, sessionName)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get session") return nil, fmt.Errorf("could not get session")
} }

View file

@ -21,10 +21,10 @@ import (
"context" "context"
"fmt" "fmt"
"html/template" "html/template"
"log/slog"
"net/http" "net/http"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
log "github.com/sirupsen/logrus"
"code.cacert.org/cacert/oidc-demo-app/ui" "code.cacert.org/cacert/oidc-demo-app/ui"
@ -47,7 +47,7 @@ type ErrorDetails struct {
type ErrorBucket struct { type ErrorBucket struct {
errorDetails *ErrorDetails errorDetails *ErrorDetails
templates *template.Template templates *template.Template
logger *log.Logger logger *slog.Logger
bundle *i18n.Bundle bundle *i18n.Bundle
messageCatalog *services.MessageCatalog messageCatalog *services.MessageCatalog
} }
@ -74,7 +74,7 @@ func (b *ErrorBucket) serveHTTP(w http.ResponseWriter, r *http.Request) {
err := b.templates.Lookup("base").Execute(w, data) err := b.templates.Lookup("base").Execute(w, data)
if err != nil { if err != nil {
log.WithError(err).Error("error rendering error template") b.logger.Error("error rendering error template", "error", err)
http.Error( http.Error(
w, w,
http.StatusText(http.StatusInternalServerError), http.StatusText(http.StatusInternalServerError),
@ -143,7 +143,7 @@ func (w *errorResponseWriter) Write(content []byte) (int, error) {
} }
func ErrorHandling( func ErrorHandling(
logger *log.Logger, logger *slog.Logger,
bundle *i18n.Bundle, bundle *i18n.Bundle,
messageCatalog *services.MessageCatalog, messageCatalog *services.MessageCatalog,
) (func(http.Handler) http.Handler, error) { ) (func(http.Handler) http.Handler, error) {

View file

@ -20,11 +20,11 @@ package handlers
import ( import (
"fmt" "fmt"
"html/template" "html/template"
"log/slog"
"net/http" "net/http"
"net/url" "net/url"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
log "github.com/sirupsen/logrus"
"code.cacert.org/cacert/oidc-demo-app/ui" "code.cacert.org/cacert/oidc-demo-app/ui"
@ -34,7 +34,7 @@ import (
type IndexHandler struct { type IndexHandler struct {
bundle *i18n.Bundle bundle *i18n.Bundle
indexTemplate *template.Template indexTemplate *template.Template
logger *log.Logger logger *slog.Logger
logoutURL string logoutURL string
messageCatalog *services.MessageCatalog messageCatalog *services.MessageCatalog
publicURL string publicURL string
@ -80,7 +80,7 @@ func (h *IndexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tokenInfo, err := h.tokenInfo.GetTokenInfo(session) tokenInfo, err := h.tokenInfo.GetTokenInfo(session)
if err != nil { if err != nil {
h.logger.WithError(err).Error("failed to get token info for request") h.logger.Error("failed to get token info for request", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -88,7 +88,7 @@ func (h *IndexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
if !tokenInfo.Expires.IsZero() { if !tokenInfo.Expires.IsZero() {
h.logger.WithField("expires", tokenInfo.Expires).Info("id token expires at") h.logger.Info("id token expires at", "expires", tokenInfo.Expires)
} }
w.Header().Add("Content-Type", "text/html") w.Header().Add("Content-Type", "text/html")
@ -130,7 +130,7 @@ func (h *IndexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func NewIndexHandler( func NewIndexHandler(
logger *log.Logger, logger *slog.Logger,
bundle *i18n.Bundle, bundle *i18n.Bundle,
catalog *services.MessageCatalog, catalog *services.MessageCatalog,
oidcInfo *services.OIDCInformation, oidcInfo *services.OIDCInformation,

View file

@ -20,10 +20,9 @@ package handlers
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"sync/atomic" "sync/atomic"
log "github.com/sirupsen/logrus"
) )
type key int type key int
@ -64,7 +63,7 @@ func Logging(logger *log.Logger) func(http.Handler) http.Handler {
requestID = "unknown" requestID = "unknown"
} }
logger.Infof( logger.Printf(
"[%s] %s \"%s %s\" %d %d \"%s\"", "[%s] %s \"%s %s\" %d %d \"%s\"",
requestID, requestID,
r.RemoteAddr, r.RemoteAddr,

View file

@ -19,12 +19,12 @@ package handlers
import ( import (
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"time" "time"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwk"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"code.cacert.org/cacert/oidc-demo-app/internal/services" "code.cacert.org/cacert/oidc-demo-app/internal/services"
@ -32,7 +32,7 @@ import (
type OidcCallbackHandler struct { type OidcCallbackHandler struct {
keySet jwk.Set keySet jwk.Set
logger *log.Logger logger *slog.Logger
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
} }
@ -59,7 +59,7 @@ func (c *OidcCallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
tok, err := c.oauth2Config.Exchange(r.Context(), code) tok, err := c.oauth2Config.Exchange(r.Context(), code)
if err != nil { if err != nil {
c.logger.WithError(err).Error("could not perform token exchange") c.logger.Error("could not perform token exchange", "error", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
@ -67,14 +67,14 @@ func (c *OidcCallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
session, err := GetSession(r) session, err := GetSession(r)
if err != nil { if err != nil {
c.logger.WithError(err).Error("could not get session") c.logger.Error("could not get session", "error", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
if err = c.storeTokens(session, tok); err != nil { if err = c.storeTokens(session, tok); err != nil {
c.logger.WithError(err).Error("could not store token in session") c.logger.Error("could not store token in session", "error", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -82,7 +82,7 @@ func (c *OidcCallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
} }
if err = session.Save(r, w); err != nil { if err = session.Save(r, w); err != nil {
c.logger.WithError(err).Error("could not save session") c.logger.Error("could not save session", "error", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -143,19 +143,20 @@ func (c *OidcCallbackHandler) storeTokens(
return fmt.Errorf("could not parse ID token: %w", err) return fmt.Errorf("could not parse ID token: %w", err)
} }
c.logger.WithFields(log.Fields{ c.logger.Debug(
"sub": oidcToken.Subject(), "receive OpenID Connect ID Token",
"aud": oidcToken.Audience(), "sub", oidcToken.Subject(),
"issued_at": oidcToken.IssuedAt(), "aud", oidcToken.Audience(),
"iss": oidcToken.Issuer(), "issued_at", oidcToken.IssuedAt(),
"not_before": oidcToken.NotBefore(), "iss", oidcToken.Issuer(),
"exp": oidcToken.Expiration(), "not_before", oidcToken.NotBefore(),
}).Debug("receive OpenID Connect ID Token") "exp", oidcToken.Expiration(),
)
return nil return nil
} }
func NewCallbackHandler(logger *log.Logger, keySet jwk.Set, oauth2Config *oauth2.Config) *OidcCallbackHandler { func NewCallbackHandler(logger *slog.Logger, keySet jwk.Set, oauth2Config *oauth2.Config) *OidcCallbackHandler {
return &OidcCallbackHandler{ return &OidcCallbackHandler{
keySet: keySet, keySet: keySet,
logger: logger, logger: logger,

View file

@ -20,19 +20,18 @@ package handlers
import ( import (
"fmt" "fmt"
"html/template" "html/template"
"log/slog"
"net/http" "net/http"
"net/url" "net/url"
"github.com/nicksnyder/go-i18n/v2/i18n"
log "github.com/sirupsen/logrus"
"code.cacert.org/cacert/oidc-demo-app/internal/services" "code.cacert.org/cacert/oidc-demo-app/internal/services"
"code.cacert.org/cacert/oidc-demo-app/ui" "code.cacert.org/cacert/oidc-demo-app/ui"
"github.com/nicksnyder/go-i18n/v2/i18n"
) )
type ProtectedResource struct { type ProtectedResource struct {
bundle *i18n.Bundle bundle *i18n.Bundle
logger *log.Logger logger *slog.Logger
protectedTemplate *template.Template protectedTemplate *template.Template
logoutURL string logoutURL string
tokenInfo *services.TokenInfoService tokenInfo *services.TokenInfoService
@ -79,7 +78,7 @@ func (h *ProtectedResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tokenInfo, err := h.tokenInfo.GetTokenInfo(session) tokenInfo, err := h.tokenInfo.GetTokenInfo(session)
if err != nil { if err != nil {
h.logger.WithError(err).Error("failed to get token info for request") h.logger.Error("failed to get token info for request", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -87,7 +86,7 @@ func (h *ProtectedResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
if !tokenInfo.Expires.IsZero() { if !tokenInfo.Expires.IsZero() {
h.logger.WithField("expires", tokenInfo.Expires).Info("id token expires at") h.logger.Info("id token expires at", "expires", tokenInfo.Expires)
} }
if tokenInfo.IDToken == "" { if tokenInfo.IDToken == "" {
@ -130,7 +129,7 @@ func (h *ProtectedResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func NewProtectedResourceHandler( func NewProtectedResourceHandler(
logger *log.Logger, logger *slog.Logger,
bundle *i18n.Bundle, bundle *i18n.Bundle,
catalog *services.MessageCatalog, catalog *services.MessageCatalog,
oidcInfo *services.OIDCInformation, oidcInfo *services.OIDCInformation,

View file

@ -20,6 +20,8 @@ package handlers
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"log/slog"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@ -27,23 +29,16 @@ import (
"time" "time"
"github.com/knadh/koanf" "github.com/knadh/koanf"
"github.com/sirupsen/logrus"
) )
func StartApplication( func StartApplication(ctx context.Context, logger *slog.Logger, server *http.Server, publicURL string, config *koanf.Koanf) error {
ctx context.Context,
logger *logrus.Logger,
server *http.Server,
publicURL string,
config *koanf.Koanf,
) {
done := make(chan bool) done := make(chan bool)
quit := make(chan os.Signal, 1) quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt) signal.Notify(quit, os.Interrupt)
go func() { go func() {
<-quit <-quit
logger.Infoln("Server is shutting down...") logger.Info("Server is shutting down...")
atomic.StoreInt32(&Healthy, 0) atomic.StoreInt32(&Healthy, 0)
const shutdownWaitTime = 30 * time.Second const shutdownWaitTime = 30 * time.Second
@ -55,24 +50,25 @@ func StartApplication(
server.SetKeepAlivesEnabled(false) server.SetKeepAlivesEnabled(false)
if err := server.Shutdown(ctx); err != nil { if err := server.Shutdown(ctx); err != nil {
logger.WithError(err).Fatal("Could not gracefully shutdown the server") logger.Error("Could not gracefully shutdown the server", "error", err)
} }
close(done) close(done)
}() }()
logger.WithField("public_url", publicURL).Info("Server is ready to handle requests") logger.Info("Server is ready to handle requests", "public_url", publicURL)
atomic.StoreInt32(&Healthy, 1) atomic.StoreInt32(&Healthy, 1)
if err := server.ListenAndServeTLS( if err := server.ListenAndServeTLS(
config.String("server.certificate"), config.String("server.key"), config.String("server.certificate"), config.String("server.key"),
); err != nil && !errors.Is(err, http.ErrServerClosed) { ); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.WithError(err).WithField( logger.Error("Could not listen on requested address", "server_address", server.Addr)
"server_address",
server.Addr, return fmt.Errorf("listening failed: %w", err)
).Fatal("Could not listen on requested address")
} }
<-done <-done
logger.Infoln("Server stopped") logger.Info("Server stopped")
return nil
} }

View file

@ -19,6 +19,7 @@ package services
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"strings" "strings"
@ -28,7 +29,6 @@ import (
"github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/providers/posflag" "github.com/knadh/koanf/providers/posflag"
"github.com/sirupsen/logrus"
"github.com/spf13/pflag" "github.com/spf13/pflag"
) )
@ -51,13 +51,12 @@ var DefaultConfiguration = map[string]interface{}{
} }
func ConfigureApplication( func ConfigureApplication(
logger *logrus.Logger,
appName string, appName string,
defaultConfig map[string]interface{}, defaultConfig map[string]interface{},
) (*koanf.Koanf, error) { ) (*koanf.Koanf, error) {
f := pflag.NewFlagSet("config", pflag.ContinueOnError) f := pflag.NewFlagSet("config", pflag.ContinueOnError)
f.Usage = func() { f.Usage = func() {
logger.Info(f.FlagUsages()) log.Print(f.FlagUsages())
os.Exit(0) os.Exit(0)
} }
@ -70,7 +69,7 @@ func ConfigureApplication(
var err error var err error
if err = f.Parse(os.Args[1:]); err != nil { if err = f.Parse(os.Args[1:]); err != nil {
logger.WithError(err).Fatal("could not parse command line arguments") return nil, fmt.Errorf("could not parse command line arguments: %w", err)
} }
config := koanf.New(".") config := koanf.New(".")
@ -78,18 +77,18 @@ func ConfigureApplication(
_ = config.Load(confmap.Provider(defaultConfig, "."), nil) _ = config.Load(confmap.Provider(defaultConfig, "."), nil)
if err = config.Load(file.Provider(defaultFile), toml.Parser()); err != nil && !os.IsNotExist(err) { if err = config.Load(file.Provider(defaultFile), toml.Parser()); err != nil && !os.IsNotExist(err) {
logrus.WithError(err).WithField("file", defaultFile).Fatal("error loading configuration from file") return nil, fmt.Errorf("could not load configuration from file %s: %w", defaultFile, err)
} }
cFiles, _ := f.GetStringSlice("conf") cFiles, _ := f.GetStringSlice("conf")
for _, c := range cFiles { for _, c := range cFiles {
if err = config.Load(file.Provider(c), toml.Parser()); err != nil { if err = config.Load(file.Provider(c), toml.Parser()); err != nil {
logger.WithError(err).WithField("file", c).Fatal("error loading configuration from file") return nil, fmt.Errorf("error loading configuration from file %s: %w", c, err)
} }
} }
if err = config.Load(posflag.Provider(f, ".", config), nil); err != nil { if err = config.Load(posflag.Provider(f, ".", config), nil); err != nil {
logger.WithError(err).Fatal("error loading configuration from command line") return nil, fmt.Errorf("error loading configuration from command line: %w", err)
} }
prefix := fmt.Sprintf("%s_", strings.ToUpper(appName)) prefix := fmt.Sprintf("%s_", strings.ToUpper(appName))
@ -97,7 +96,7 @@ func ConfigureApplication(
if err = config.Load(env.Provider(prefix, ".", func(s string) string { if err = config.Load(env.Provider(prefix, ".", func(s string) string {
return strings.ReplaceAll(strings.ToLower(strings.TrimPrefix(s, prefix)), "_", ".") return strings.ReplaceAll(strings.ToLower(strings.TrimPrefix(s, prefix)), "_", ".")
}), nil); err != nil { }), nil); err != nil {
logrus.WithError(err).Fatal("error loading configuration from environment") return nil, fmt.Errorf("error loading configuration from environment variables: %w", err)
} }
return config, nil return config, nil

View file

@ -20,6 +20,7 @@ package services
import ( import (
"errors" "errors"
"fmt" "fmt"
"log/slog"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
"golang.org/x/text/language" "golang.org/x/text/language"
@ -27,7 +28,6 @@ import (
"code.cacert.org/cacert/oidc-demo-app/translations" "code.cacert.org/cacert/oidc-demo-app/translations"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
log "github.com/sirupsen/logrus"
) )
func AddMessages(catalog *MessageCatalog) { func AddMessages(catalog *MessageCatalog) {
@ -83,7 +83,7 @@ func AddMessages(catalog *MessageCatalog) {
type MessageCatalog struct { type MessageCatalog struct {
messages map[string]*i18n.Message messages map[string]*i18n.Message
logger *log.Logger logger *slog.Logger
} }
func (m *MessageCatalog) AddMessages(messages map[string]*i18n.Message) { func (m *MessageCatalog) AddMessages(messages map[string]*i18n.Message) {
@ -102,11 +102,11 @@ func (m *MessageCatalog) LookupErrorMessage(
message, ok := m.messages[fieldTag] message, ok := m.messages[fieldTag]
if !ok { if !ok {
m.logger.WithField("field_tag", fieldTag).Info("no specific error message for field and tag") m.logger.Info("no specific error message for field and tag", "field_tag", fieldTag)
message, ok = m.messages[tag] message, ok = m.messages[tag]
if !ok { if !ok {
m.logger.WithField("tag", tag).Info("no specific error message for tag") m.logger.Info("no specific error message for tag", "tag", tag)
message, ok = m.messages["unknown"] message, ok = m.messages["unknown"]
if !ok { if !ok {
@ -124,7 +124,7 @@ func (m *MessageCatalog) LookupErrorMessage(
}, },
}) })
if err != nil { if err != nil {
m.logger.WithError(err).Error("localization failed") m.logger.Error("localization failed", "error", err)
return tag return tag
} }
@ -149,7 +149,7 @@ func (m *MessageCatalog) LookupMessage(
return translation return translation
} }
m.logger.WithField("id", id).Warn("no translation found for id") m.logger.Warn("no translation found for id", "id", id)
return id return id
} }
@ -158,19 +158,19 @@ func (m *MessageCatalog) handleLocalizeError(id string, translation string, err
var messageNotFound *i18n.MessageNotFoundErr var messageNotFound *i18n.MessageNotFoundErr
if errors.As(err, &messageNotFound) { if errors.As(err, &messageNotFound) {
m.logger.WithError(err).WithField("message", id).Warn("message not found") m.logger.Warn("message not found", "error", err, "message", id)
if translation != "" { if translation != "" {
return translation return translation
} }
} else { } else {
m.logger.WithError(err).WithField("message", id).Error("translation error") m.logger.Error("translation error", "error", err, "message", id)
} }
return id return id
} }
func InitI18n(logger *log.Logger, languages []string) (*i18n.Bundle, *MessageCatalog) { func InitI18n(logger *slog.Logger, languages []string) (*i18n.Bundle, *MessageCatalog) {
bundle := i18n.NewBundle(language.English) bundle := i18n.NewBundle(language.English)
bundle.RegisterUnmarshalFunc("toml", toml.Unmarshal) bundle.RegisterUnmarshalFunc("toml", toml.Unmarshal)
@ -179,7 +179,7 @@ func InitI18n(logger *log.Logger, languages []string) (*i18n.Bundle, *MessageCat
bundleBytes, err := translations.Bundles.ReadFile(bundleName) bundleBytes, err := translations.Bundles.ReadFile(bundleName)
if err != nil { if err != nil {
logger.WithField("bundle", bundleName).Warn("message bundle not found") logger.Warn("message bundle not found", "bundle", bundleName)
continue continue
} }
@ -192,7 +192,7 @@ func InitI18n(logger *log.Logger, languages []string) (*i18n.Bundle, *MessageCat
return bundle, catalog return bundle, catalog
} }
func initMessageCatalog(logger *log.Logger) *MessageCatalog { func initMessageCatalog(logger *slog.Logger) *MessageCatalog {
messages := make(map[string]*i18n.Message) messages := make(map[string]*i18n.Message)
messages["ErrorTitle"] = &i18n.Message{ messages["ErrorTitle"] = &i18n.Message{
ID: "ErrorTitle", ID: "ErrorTitle",

View file

@ -22,12 +22,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
"github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwk"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"code.cacert.org/cacert/oidc-demo-app/internal/models" "code.cacert.org/cacert/oidc-demo-app/internal/models"
@ -57,13 +57,15 @@ type OIDCInformation struct {
// retrieved by GetOAuth2Config. // retrieved by GetOAuth2Config.
// //
// The JSON Web Key Set can be retrieved by GetJwkSet. // The JSON Web Key Set can be retrieved by GetJwkSet.
func DiscoverOIDC(logger *log.Logger, params *OidcParams) (*OIDCInformation, error) { func DiscoverOIDC(logger *slog.Logger, params *OidcParams) (*OIDCInformation, error) {
discoveryURL, err := url.Parse(params.OidcServer) discoveryURL, err := url.Parse(params.OidcServer)
if err != nil { if err != nil {
logger.WithError(err).WithField( logger.Error(
"oidc.server", "could not parse parameter oidc.server as URL",
params.OidcServer, "oidc.server", params.OidcServer,
).Fatal("could not parse parameter value") )
return nil, fmt.Errorf("could not parse parameter value: %w", err)
} else { } else {
discoveryURL.Path = "/.well-known/openid-configuration" discoveryURL.Path = "/.well-known/openid-configuration"
} }

View file

@ -19,21 +19,20 @@ package services
import ( import (
"crypto/rand" "crypto/rand"
"fmt"
log "github.com/sirupsen/logrus"
) )
func GenerateKey(length int) []byte { func GenerateKey(length int) ([]byte, error) {
key := make([]byte, length) key := make([]byte, length)
read, err := rand.Read(key) read, err := rand.Read(key)
if err != nil { if err != nil {
log.WithError(err).Fatal("could not generate key") return nil, fmt.Errorf("could not generate key", err)
} }
if read != length { if read != length {
log.WithFields(log.Fields{"read": read, "expected": length}).Fatal("read unexpected number of bytes") return nil, fmt.Errorf("read unexpected number of bytes, read %d, expected %d", read, length)
} }
return key return key, nil
} }

View file

@ -18,10 +18,11 @@ limitations under the License.
package services package services
import ( import (
"fmt"
"log/slog"
"os" "os"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
log "github.com/sirupsen/logrus"
) )
var store *sessions.FilesystemStore var store *sessions.FilesystemStore
@ -33,16 +34,18 @@ const (
SessionRedirectTarget SessionRedirectTarget
) )
func InitSessionStore(logger *log.Logger, sessionPath string, keys ...[]byte) { func InitSessionStore(logger *slog.Logger, sessionPath string, keys ...[]byte) error {
store = sessions.NewFilesystemStore(sessionPath, keys...) store = sessions.NewFilesystemStore(sessionPath, keys...)
if _, err := os.Stat(sessionPath); err != nil { if _, err := os.Stat(sessionPath); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
if err = os.MkdirAll(sessionPath, 0700); err != nil { //nolint:mnd if err = os.MkdirAll(sessionPath, 0700); err != nil { //nolint:mnd
logger.WithError(err).Fatal("could not create session store director") return fmt.Errorf("could not create session store director: %w", err)
} }
} }
} }
return nil
} }
func GetSessionStore() *sessions.FilesystemStore { func GetSessionStore() *sessions.FilesystemStore {

View file

@ -20,13 +20,13 @@ package services
import ( import (
"errors" "errors"
"fmt" "fmt"
"log/slog"
"time" "time"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt" "github.com/lestrrat-go/jwx/jwt"
"github.com/lestrrat-go/jwx/jwt/openid" "github.com/lestrrat-go/jwx/jwt/openid"
log "github.com/sirupsen/logrus"
) )
type OIDCTokenInfo struct { type OIDCTokenInfo struct {
@ -39,7 +39,7 @@ type OIDCTokenInfo struct {
} }
type TokenInfoService struct { type TokenInfoService struct {
logger *log.Logger logger *slog.Logger
keySet jwk.Set keySet jwk.Set
} }
@ -49,15 +49,15 @@ func (s *TokenInfoService) GetTokenInfo(session *sessions.Session) (*OIDCTokenIn
var ok bool var ok bool
if tokenInfo.AccessToken, ok = session.Values[SessionAccessToken].(string); ok { if tokenInfo.AccessToken, ok = session.Values[SessionAccessToken].(string); ok {
s.logger.WithField("access_token", tokenInfo.AccessToken).Debug("found access token in session") s.logger.Debug("found access token in session", "access_token", tokenInfo.AccessToken)
} }
if tokenInfo.RefreshToken, ok = session.Values[SessionRefreshToken].(string); ok { if tokenInfo.RefreshToken, ok = session.Values[SessionRefreshToken].(string); ok {
s.logger.WithField("refresh_token", tokenInfo.RefreshToken).Debug("found refresh token in session") s.logger.Debug("found refresh token in session", "refresh_token", tokenInfo.RefreshToken)
} }
if tokenInfo.IDToken, ok = session.Values[SessionIDToken].(string); ok { if tokenInfo.IDToken, ok = session.Values[SessionIDToken].(string); ok {
s.logger.WithField("id_token", tokenInfo.IDToken).Debug("found ID token in session") s.logger.Debug("found ID token in session", "id_token", tokenInfo.IDToken)
} }
if tokenInfo.IDToken == "" { if tokenInfo.IDToken == "" {
@ -78,7 +78,7 @@ func (s *TokenInfoService) GetTokenInfo(session *sessions.Session) (*OIDCTokenIn
return tokenInfo, nil return tokenInfo, nil
} }
func InitTokenInfoService(logger *log.Logger, oidcInfo *OIDCInformation) (*TokenInfoService, error) { func InitTokenInfoService(logger *slog.Logger, oidcInfo *OIDCInformation) (*TokenInfoService, error) {
return &TokenInfoService{logger: logger, keySet: oidcInfo.KeySet}, nil return &TokenInfoService{logger: logger, keySet: oidcInfo.KeySet}, nil
} }