package config

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/google/uuid"
	"golang.org/x/sync/singleflight"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/pomerium/pomerium/internal/encoding"
	"github.com/pomerium/pomerium/internal/encoding/jws"
	"github.com/pomerium/pomerium/internal/httputil"
	"github.com/pomerium/pomerium/internal/jwtutil"
	"github.com/pomerium/pomerium/internal/sessions"
	"github.com/pomerium/pomerium/internal/sessions/cookie"
	"github.com/pomerium/pomerium/internal/sessions/header"
	"github.com/pomerium/pomerium/internal/sessions/queryparam"
	"github.com/pomerium/pomerium/internal/urlutil"
	"github.com/pomerium/pomerium/pkg/authenticateapi"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
	"github.com/pomerium/pomerium/pkg/grpc/session"
	"github.com/pomerium/pomerium/pkg/grpc/user"
	"github.com/pomerium/pomerium/pkg/grpcutil"
	"github.com/pomerium/pomerium/pkg/identity"
	"github.com/pomerium/pomerium/pkg/protoutil"
	"github.com/pomerium/pomerium/pkg/storage"
)

// A SessionStore saves and loads sessions based on the options.
type SessionStore struct {
	store   sessions.SessionStore
	loader  sessions.SessionLoader
	options *Options
	encoder encoding.MarshalUnmarshaler
}

var _ sessions.SessionStore = (*SessionStore)(nil)

// NewSessionStore creates a new SessionStore from the Options.
func NewSessionStore(options *Options) (*SessionStore, error) {
	store := &SessionStore{
		options: options,
	}

	sharedKey, err := options.GetSharedKey()
	if err != nil {
		return nil, fmt.Errorf("config/sessions: shared_key is required: %w", err)
	}

	store.encoder, err = jws.NewHS256Signer(sharedKey)
	if err != nil {
		return nil, fmt.Errorf("config/sessions: invalid session encoder: %w", err)
	}

	store.store, err = cookie.NewStore(func() cookie.Options {
		return cookie.Options{
			Name:     options.CookieName,
			Domain:   options.CookieDomain,
			Secure:   true,
			HTTPOnly: options.CookieHTTPOnly,
			Expire:   options.CookieExpire,
			SameSite: options.GetCookieSameSite(),
		}
	}, store.encoder)
	if err != nil {
		return nil, err
	}
	headerStore := header.NewStore(store.encoder)
	queryParamStore := queryparam.NewStore(store.encoder, urlutil.QuerySession)
	store.loader = sessions.MultiSessionLoader(store.store, headerStore, queryParamStore)

	return store, nil
}

// ClearSession clears the session.
func (store *SessionStore) ClearSession(w http.ResponseWriter, r *http.Request) {
	store.store.ClearSession(w, r)
}

// LoadSession loads the session.
func (store *SessionStore) LoadSession(r *http.Request) (string, error) {
	return store.loader.LoadSession(r)
}

// LoadSessionState loads the session state from a request.
func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, error) {
	rawJWT, err := store.loader.LoadSession(r)
	if err != nil {
		return nil, err
	}

	var state sessions.State
	err = store.encoder.Unmarshal([]byte(rawJWT), &state)
	if err != nil {
		return nil, err
	}

	return &state, nil
}

// LoadSessionStateAndCheckIDP loads the session state from a request and checks that the idp id matches.
func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessions.State, error) {
	state, err := store.LoadSessionState(r)
	if err != nil {
		return nil, err
	}

	// confirm that the identity provider id matches the state
	if state.IdentityProviderID != "" {
		idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
		if err != nil {
			return nil, err
		}

		if idp.GetId() != state.IdentityProviderID {
			return nil, fmt.Errorf("unexpected session state identity provider id: %s != %s",
				idp.GetId(), state.IdentityProviderID)
		}
	}

	return state, nil
}

// SaveSession saves the session.
func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v any) error {
	return store.store.SaveSession(w, r, v)
}

type IncomingIDPTokenSessionCreator interface {
	CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error)
}

type incomingIDPTokenSessionCreator struct {
	timeNow      func() time.Time
	getRecord    func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
	putRecords   func(ctx context.Context, records []*databroker.Record) error
	singleflight singleflight.Group
}

func NewIncomingIDPTokenSessionCreator(
	getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
	putRecords func(ctx context.Context, records []*databroker.Record) error,
) IncomingIDPTokenSessionCreator {
	return &incomingIDPTokenSessionCreator{timeNow: time.Now, getRecord: getRecord, putRecords: putRecords}
}

// CreateSession attempts to create a session for incoming idp access and
// identity tokens. If no access or identity token is passed ErrNoSessionFound will be returned.
// If the tokens are not valid an error will be returned.
func (c *incomingIDPTokenSessionCreator) CreateSession(
	ctx context.Context,
	cfg *Config,
	policy *Policy,
	r *http.Request,
) (session *session.Session, err error) {
	if rawAccessToken, ok := cfg.GetIncomingIDPAccessTokenForPolicy(policy, r); ok {
		return c.createSessionAccessToken(ctx, cfg, policy, rawAccessToken)
	}

	if rawIdentityToken, ok := cfg.GetIncomingIDPIdentityTokenForPolicy(policy, r); ok {
		return c.createSessionForIdentityToken(ctx, cfg, policy, rawIdentityToken)
	}

	return nil, sessions.ErrNoSessionFound
}

func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
	ctx context.Context,
	cfg *Config,
	policy *Policy,
	rawAccessToken string,
) (*session.Session, error) {
	idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
	if err != nil {
		return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err)
	}

	sessionID := getAccessTokenSessionID(idp, rawAccessToken)
	res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
		s, err := c.getSession(ctx, sessionID)
		if err == nil {
			return s, nil
		} else if !storage.IsNotFound(err) {
			return nil, err
		}

		authenticateURL, transport, err := cfg.resolveAuthenticateURL()
		if err != nil {
			return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
		}

		res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
			AccessToken:        rawAccessToken,
			IdentityProviderID: idp.GetId(),
		})
		if err != nil {
			return nil, fmt.Errorf("error verifying access token: %w", err)
		} else if !res.Valid {
			return nil, fmt.Errorf("%w: invalid access token", sessions.ErrInvalidSession)
		}

		s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
		s.OauthToken = &session.OAuthToken{
			TokenType:   "Bearer",
			AccessToken: rawAccessToken,
			ExpiresAt:   s.ExpiresAt,
		}

		u, err := c.getUser(ctx, s.GetUserId())
		if storage.IsNotFound(err) {
			u = &user.User{Id: s.GetUserId()}
		} else if err != nil {
			return nil, fmt.Errorf("error retrieving existing user: %w", err)
		}
		c.fillUserFromIDPClaims(u, res.Claims)

		err = c.putSessionAndUser(ctx, s, u)
		if err != nil {
			return nil, fmt.Errorf("error saving session and user: %w", err)
		}

		return s, nil
	})
	if err != nil {
		return nil, err
	}
	return res.(*session.Session), nil
}

func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
	ctx context.Context,
	cfg *Config,
	policy *Policy,
	rawIdentityToken string,
) (*session.Session, error) {
	idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
	if err != nil {
		return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err)
	}

	sessionID := getIdentityTokenSessionID(idp, rawIdentityToken)
	res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
		s, err := c.getSession(ctx, sessionID)
		if err == nil {
			return s, nil
		} else if !storage.IsNotFound(err) {
			return nil, err
		}

		authenticateURL, transport, err := cfg.resolveAuthenticateURL()
		if err != nil {
			return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
		}

		res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
			IdentityToken:      rawIdentityToken,
			IdentityProviderID: idp.GetId(),
		})
		if err != nil {
			return nil, fmt.Errorf("error verifying identity token: %w", err)
		} else if !res.Valid {
			return nil, fmt.Errorf("%w: invalid identity token", sessions.ErrInvalidSession)
		}

		s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
		s.SetRawIDToken(rawIdentityToken)

		u, err := c.getUser(ctx, s.GetUserId())
		if errors.Is(err, storage.ErrNotFound) {
			u = &user.User{Id: s.GetUserId()}
		} else if err != nil {
			return nil, fmt.Errorf("error retrieving existing user: %w", err)
		}
		c.fillUserFromIDPClaims(u, res.Claims)

		err = c.putSessionAndUser(ctx, s, u)
		if err != nil {
			return nil, fmt.Errorf("error saving session and user: %w", err)
		}

		return s, nil
	})
	if err != nil {
		return nil, err
	}
	return res.(*session.Session), nil
}

func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
	cfg *Config,
	sessionID string,
	claims jwtutil.Claims,
) *session.Session {
	now := c.timeNow()
	s := new(session.Session)
	s.Id = sessionID
	if userID, ok := claims.GetUserID(); ok {
		s.UserId = userID
	}
	if issuedAt, ok := claims.GetIssuedAt(); ok {
		s.IssuedAt = timestamppb.New(issuedAt)
	} else {
		s.IssuedAt = timestamppb.New(now)
	}
	if expiresAt, ok := claims.GetExpirationTime(); ok {
		s.ExpiresAt = timestamppb.New(expiresAt)
	} else {
		s.ExpiresAt = timestamppb.New(now.Add(cfg.Options.CookieExpire))
	}
	s.AccessedAt = timestamppb.New(now)
	s.AddClaims(identity.Claims(claims).Flatten())
	if aud, ok := claims.GetAudience(); ok {
		s.Audience = aud
	}
	s.RefreshDisabled = true
	return s
}

func (c *incomingIDPTokenSessionCreator) fillUserFromIDPClaims(
	u *user.User,
	claims jwtutil.Claims,
) {
	if userID, ok := claims.GetUserID(); ok {
		u.Id = userID
	}
	if name, ok := claims.GetString("name"); ok {
		u.Name = name
	}
	if email, ok := claims.GetString("email"); ok {
		u.Email = email
	}
	u.AddClaims(identity.Claims(claims).Flatten())
}

func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
	record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
	if err != nil {
		return nil, err
	}

	msg, err := record.GetData().UnmarshalNew()
	if err != nil {
		return nil, storage.ErrNotFound
	}

	s, ok := msg.(*session.Session)
	if !ok {
		return nil, storage.ErrNotFound
	}

	return s, nil
}

func (c *incomingIDPTokenSessionCreator) getUser(ctx context.Context, userID string) (*user.User, error) {
	record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID)
	if err != nil {
		return nil, err
	}

	msg, err := record.GetData().UnmarshalNew()
	if err != nil {
		return nil, storage.ErrNotFound
	}

	u, ok := msg.(*user.User)
	if !ok {
		return nil, storage.ErrNotFound
	}

	return u, nil
}

func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
	var records []*databroker.Record
	if id := s.GetId(); id != "" {
		records = append(records, &databroker.Record{
			Type: grpcutil.GetTypeURL(s),
			Id:   id,
			Data: protoutil.NewAny(s),
		})
	}
	if id := u.GetId(); id != "" {
		records = append(records, &databroker.Record{
			Type: grpcutil.GetTypeURL(u),
			Id:   id,
			Data: protoutil.NewAny(u),
		})
	}
	return c.putRecords(ctx, records)
}

// GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one.
func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
	bearerTokenFormat := BearerTokenFormatUnknown
	if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
		bearerTokenFormat = *cfg.Options.BearerTokenFormat
	}
	if policy != nil && policy.BearerTokenFormat != nil {
		bearerTokenFormat = *policy.BearerTokenFormat
	}

	if token := r.Header.Get(httputil.HeaderPomeriumIDPAccessToken); token != "" {
		return token, true
	}

	if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
		prefix := httputil.AuthorizationTypePomeriumIDPAccessToken + " "
		if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
			return auth[len(prefix):], true
		}

		prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPAccessToken + "-"
		if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
			return auth[len(prefix):], true
		}

		prefix = "Bearer "
		if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
			bearerTokenFormat == BearerTokenFormatIDPAccessToken {
			return auth[len(prefix):], true
		}
	}

	return "", false
}

// GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one.
func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
	bearerTokenFormat := BearerTokenFormatDefault
	if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
		bearerTokenFormat = *cfg.Options.BearerTokenFormat
	}
	if policy != nil && policy.BearerTokenFormat != nil {
		bearerTokenFormat = *policy.BearerTokenFormat
	}

	if token := r.Header.Get(httputil.HeaderPomeriumIDPIdentityToken); token != "" {
		return token, true
	}

	if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
		prefix := httputil.AuthorizationTypePomeriumIDPIdentityToken + " "
		if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
			return auth[len(prefix):], true
		}

		prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPIdentityToken + "-"
		if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
			return auth[len(prefix):], true
		}

		prefix = "Bearer "
		if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
			bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
			return auth[len(prefix):], true
		}
	}

	return "", false
}

var accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d")

func getAccessTokenSessionID(idp *identitypb.Provider, rawAccessToken string) string {
	namespace := accessTokenUUIDNamespace
	// make the session ID per-idp settings
	if idp != nil {
		namespace = uuid.NewSHA1(namespace, []byte(idp.GetId()))
	}
	return uuid.NewSHA1(namespace, []byte(rawAccessToken)).String()
}

var identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17")

func getIdentityTokenSessionID(idp *identitypb.Provider, rawIdentityToken string) string {
	namespace := identityTokenUUIDNamespace
	// make the session ID per-idp settings
	if idp != nil {
		namespace = uuid.NewSHA1(namespace, []byte(idp.GetId()))
	}
	return uuid.NewSHA1(namespace, []byte(rawIdentityToken)).String()
}