package fetoken import ( "context" "net/http" "strings" "git.1in9.net/raider/wroofauth/internal/keystore" "git.1in9.net/raider/wroofauth/internal/logger" "github.com/go-chi/chi/v5/middleware" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwt" "github.com/spf13/viper" "go.mongodb.org/mongo-driver/bson/primitive" "go.uber.org/zap" ) type FeToken struct { ID primitive.ObjectID } type feTokenCtxKeyType = string const feTokenCtxKey = feTokenCtxKeyType("feToken") func getFeToken(r *http.Request) *FeToken { header := r.Header.Get("Authorization") if !strings.HasPrefix(header, "Bearer ") { logger.Logger.Info("FeToken - Authorization is not Bearer", zap.String("requestId", middleware.GetReqID(r.Context()))) return nil } tokenString := strings.TrimPrefix(header, "Bearer ") kid := viper.GetString("crypto.use_key.frontend") key, found := keystore.Global.LookupKeyID(kid) if !found { logger.Logger.Error("Keystore doesn't contain key to use for frontend!") return nil } public, err := key.PublicKey() if err != nil { logger.Logger.Error("Failed to make key into public key!", zap.Error(err)) return nil } token, err := jwt.Parse([]byte(tokenString), jwt.WithVerify(jwa.EdDSA, public)) if err != nil { logger.Logger.Info("FeToken - Could not parse token", zap.Error(err), zap.String("requestId", middleware.GetReqID(r.Context()))) return nil } err = jwt.Validate(token) if err != nil { logger.Logger.Info("FeToken - Could not validate token", zap.Error(err), zap.String("requestId", middleware.GetReqID(r.Context()))) return nil } tokenId := token.JwtID() id, err := primitive.ObjectIDFromHex(tokenId) if err != nil { logger.Logger.Info("FeToken - Could not parse token ID", zap.Error(err), zap.String("requestId", middleware.GetReqID(r.Context()))) return nil } feToken := FeToken{ ID: id, } return &feToken } func ForContext(ctx context.Context) *FeToken { raw, _ := ctx.Value(feTokenCtxKey).(*FeToken) return raw } func Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { feToken := getFeToken(r) // feToken is a pointer. ctx := context.WithValue(r.Context(), feTokenCtxKey, feToken) r = r.WithContext(ctx) next.ServeHTTP(w, r) }) } }