wroofauth/internal/helpers/fetoken/fetoken.go

93 lines
2.3 KiB
Go

package fetoken
import (
"context"
"net/http"
"strings"
"git.1in9.net/raider/wroofauth/internal/keystore"
"git.1in9.net/raider/wroofauth/internal/logger"
"github.com/go-chi/chi/v5/middleware"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/spf13/viper"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// FeToken is the verified frontend token that Middleware attaches to a
// request's context.
type FeToken struct {
	// ID is the token's "jti" claim, decoded from hex into an ObjectID.
	ID primitive.ObjectID
}
// feTokenCtxKeyType is an unexported *defined* type (not an alias of string),
// so feTokenCtxKey can never compare equal to a context key set by another
// package, even one that happens to use the plain string "feToken".
type feTokenCtxKeyType string

// feTokenCtxKey is the context key under which Middleware stores the *FeToken.
const feTokenCtxKey = feTokenCtxKeyType("feToken")
// getFeToken extracts and verifies the frontend token carried in the
// request's "Authorization: Bearer <jwt>" header.
//
// The JWT must verify with EdDSA against the public half of the keystore key
// whose ID is configured under "crypto.use_key.frontend", must pass
// jwt.Validate with its default options, and its "jti" claim must be a
// hex-encoded ObjectID. On any failure the reason is logged and nil is
// returned; this function never writes a response itself.
func getFeToken(r *http.Request) *FeToken {
	const scheme = "Bearer "
	reqID := middleware.GetReqID(r.Context())

	header := r.Header.Get("Authorization")
	// The auth-scheme is case-insensitive (RFC 7235 §2.1), so accept
	// "bearer ", "BEARER ", etc. in addition to the canonical "Bearer ".
	if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
		logger.Logger.Info("FeToken - Authorization is not Bearer", zap.String("requestId", reqID))
		return nil
	}
	// Tolerate stray whitespace around the token itself.
	tokenString := strings.TrimSpace(header[len(scheme):])

	kid := viper.GetString("crypto.use_key.frontend")
	key, found := keystore.Global.LookupKeyID(kid)
	if !found {
		// Configuration problem rather than a bad request: include the kid
		// so the missing key can be identified from the log line.
		logger.Logger.Error("Keystore doesn't contain key to use for frontend!", zap.String("kid", kid), zap.String("requestId", reqID))
		return nil
	}
	public, err := key.PublicKey()
	if err != nil {
		logger.Logger.Error("Failed to make key into public key!", zap.Error(err), zap.String("kid", kid), zap.String("requestId", reqID))
		return nil
	}

	// Parse only checks the signature; the claim checks happen in Validate below.
	token, err := jwt.Parse([]byte(tokenString), jwt.WithVerify(jwa.EdDSA, public))
	if err != nil {
		logger.Logger.Info("FeToken - Could not parse token", zap.Error(err), zap.String("requestId", reqID))
		return nil
	}
	if err := jwt.Validate(token); err != nil {
		logger.Logger.Info("FeToken - Could not validate token", zap.Error(err), zap.String("requestId", reqID))
		return nil
	}

	id, err := primitive.ObjectIDFromHex(token.JwtID())
	if err != nil {
		logger.Logger.Info("FeToken - Could not parse token ID", zap.Error(err), zap.String("requestId", reqID))
		return nil
	}
	return &FeToken{ID: id}
}
// ForContext returns the *FeToken that Middleware stored in ctx. It returns
// nil when the request carried no valid token, or when ctx did not pass
// through Middleware at all.
func ForContext(ctx context.Context) *FeToken {
	if tok, ok := ctx.Value(feTokenCtxKey).(*FeToken); ok {
		return tok
	}
	return nil
}
func Middleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
feToken := getFeToken(r)
// feToken is a pointer.
ctx := context.WithValue(r.Context(), feTokenCtxKey, feToken)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
}