mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
initial release
This commit is contained in:
commit
d56c889224
62 changed files with 8229 additions and 0 deletions
163
internal/sessions/cookie_store.go
Normal file
163
internal/sessions/cookie_store.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// ErrInvalidSession is an error for invalid sessions.
|
||||
var ErrInvalidSession = errors.New("invalid session")
|
||||
|
||||
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
|
||||
type CSRFStore interface {
|
||||
SetCSRF(http.ResponseWriter, *http.Request, string)
|
||||
GetCSRF(*http.Request) (*http.Cookie, error)
|
||||
ClearCSRF(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||
type SessionStore interface {
|
||||
ClearSession(http.ResponseWriter, *http.Request)
|
||||
LoadSession(*http.Request) (*SessionState, error)
|
||||
SaveSession(http.ResponseWriter, *http.Request, *SessionState) error
|
||||
}
|
||||
|
||||
// CookieStore represents all the cookie related configurations
|
||||
type CookieStore struct {
|
||||
Name string
|
||||
CSRFCookieName string
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
CookieCipher aead.Cipher
|
||||
SessionLifetimeTTL time.Duration
|
||||
}
|
||||
|
||||
// CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret
|
||||
func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
|
||||
return func(s *CookieStore) error {
|
||||
cipher, err := aead.NewMiscreantCipher(cookieSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("miscreant cookie-secret error: %s", err.Error())
|
||||
}
|
||||
s.CookieCipher = cipher
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) {
|
||||
c := &CookieStore{
|
||||
Name: cookieName,
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: 168 * time.Hour,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
|
||||
}
|
||||
|
||||
for _, f := range optFuncs {
|
||||
err := f(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
domain := c.CookieDomain
|
||||
if domain == "" {
|
||||
domain = "<default>"
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
domain := req.Host
|
||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||
domain = h
|
||||
}
|
||||
if s.CookieDomain != "" {
|
||||
if !strings.HasSuffix(domain, s.CookieDomain) {
|
||||
log.Warn().Str("cookie-domain", s.CookieDomain).Msg("using configured cookie domain")
|
||||
}
|
||||
domain = s.CookieDomain
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: domain,
|
||||
HttpOnly: s.CookieHTTPOnly,
|
||||
Secure: s.CookieSecure,
|
||||
Expires: now.Add(expiration),
|
||||
}
|
||||
}
|
||||
|
||||
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||
func (s *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return s.makeCookie(req, s.Name, value, expiration, now)
|
||||
}
|
||||
|
||||
// makeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time.
|
||||
func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return s.makeCookie(req, s.CSRFCookieName, value, expiration, now)
|
||||
}
|
||||
|
||||
// ClearCSRF clears the CSRF cookie from the request
|
||||
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
|
||||
return req.Cookie(s.CSRFCookieName)
|
||||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request
|
||||
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// LoadSession returns a SessionState from the cookie in the request.
|
||||
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
|
||||
c, err := req.Cookie(s.Name)
|
||||
if err != nil {
|
||||
// always http.ErrNoCookie
|
||||
return nil, err
|
||||
}
|
||||
session, err := UnmarshalSession(c.Value, s.CookieCipher)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("remote-host", req.Host).Msg("error unmarshaling session")
|
||||
return nil, ErrInvalidSession
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SaveSession saves a session state to a request sessions.
|
||||
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
|
||||
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.setSessionCookie(rw, req, value)
|
||||
return nil
|
||||
}
|
348
internal/sessions/cookie_store_test.go
Normal file
348
internal/sessions/cookie_store_test.go
Normal file
|
@ -0,0 +1,348 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=")
|
||||
|
||||
func TestCreateMiscreantCookieCipher(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
cookieSecret []byte
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "normal case with base64 encoded secret",
|
||||
cookieSecret: testEncodedCookieSecret,
|
||||
},
|
||||
|
||||
{
|
||||
name: "error when not base64 encoded",
|
||||
cookieSecret: []byte("abcd"),
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewCookieStore("cookieName", CreateMiscreantCookieCipher(tc.cookieSecret))
|
||||
if !tc.expectedError {
|
||||
testutil.Ok(t, err)
|
||||
} else {
|
||||
testutil.NotEqual(t, err, nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSession(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedError bool
|
||||
expectedSession *CookieStore
|
||||
}{
|
||||
{
|
||||
name: "default with no opt funcs set",
|
||||
expectedSession: &CookieStore{
|
||||
Name: "cookieName",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: 168 * time.Hour,
|
||||
CSRFCookieName: "cookieName_csrf",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "opt func with an error returns an error",
|
||||
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "opt func overrides default values",
|
||||
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
|
||||
s.CookieExpire = time.Hour
|
||||
return nil
|
||||
}},
|
||||
expectedSession: &CookieStore{
|
||||
Name: "cookieName",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Hour,
|
||||
CSRFCookieName: "cookieName_csrf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore("cookieName", tc.optFuncs...)
|
||||
if tc.expectedError {
|
||||
testutil.NotEqual(t, err, nil)
|
||||
} else {
|
||||
testutil.Ok(t, err)
|
||||
}
|
||||
testutil.Equal(t, tc.expectedSession, session)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeSessionCookie(t *testing.T) {
|
||||
now := time.Now()
|
||||
cookieValue := "cookieValue"
|
||||
expiration := time.Hour
|
||||
cookieName := "cookieName"
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedCookie *http.Cookie
|
||||
}{
|
||||
{
|
||||
name: "default cookie domain",
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "www.example.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom cookie domain set",
|
||||
optFuncs: []func(*CookieStore) error{
|
||||
func(s *CookieStore) error {
|
||||
s.CookieDomain = "buzzfeed.com"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "buzzfeed.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
|
||||
testutil.Equal(t, cookie, tc.expectedCookie)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeSessionCSRFCookie(t *testing.T) {
|
||||
now := time.Now()
|
||||
cookieValue := "cookieValue"
|
||||
expiration := time.Hour
|
||||
cookieName := "cookieName"
|
||||
csrfName := "cookieName_csrf"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedCookie *http.Cookie
|
||||
}{
|
||||
{
|
||||
name: "default cookie domain",
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: csrfName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "www.example.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom cookie domain set",
|
||||
optFuncs: []func(*CookieStore) error{
|
||||
func(s *CookieStore) error {
|
||||
s.CookieDomain = "buzzfeed.com"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: csrfName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "buzzfeed.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
|
||||
testutil.Equal(t, tc.expectedCookie, cookie)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set session cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session.setSessionCookie(rw, req, cookieValue)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == cookieName {
|
||||
found = true
|
||||
testutil.Equal(t, cookieValue, cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
func TestSetCSRFSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set csrf cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session.SetCSRF(rw, req, cookieValue)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
|
||||
found = true
|
||||
testutil.Equal(t, cookieValue, cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set session cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
session.ClearSession(rw, req)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == cookieName {
|
||||
found = true
|
||||
testutil.Equal(t, "", cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearCSRFSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("clear csrf cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
session.ClearCSRF(rw, req)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
|
||||
found = true
|
||||
testutil.Equal(t, "", cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadCookiedSession(t *testing.T) {
|
||||
cookieName := "cookieName"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
|
||||
expectedError error
|
||||
sessionState *SessionState
|
||||
}{
|
||||
{
|
||||
name: "no cookie set returns an error",
|
||||
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {},
|
||||
expectedError: http.ErrNoCookie,
|
||||
},
|
||||
{
|
||||
name: "cookie set with cipher set",
|
||||
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
|
||||
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
|
||||
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||
testutil.Ok(t, err)
|
||||
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
|
||||
},
|
||||
sessionState: &SessionState{
|
||||
Email: "example@email.com",
|
||||
RefreshToken: "abccdddd",
|
||||
AccessToken: "access",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cookie set with invalid value cipher set",
|
||||
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
|
||||
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
|
||||
value := "574b776a7c934d6b9fc42ec63a389f79"
|
||||
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
|
||||
},
|
||||
expectedError: ErrInvalidSession,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "https://www.example.com", nil)
|
||||
tc.setupCookies(t, req, session, tc.sessionState)
|
||||
s, err := session.LoadSession(req)
|
||||
|
||||
testutil.Equal(t, tc.expectedError, err)
|
||||
testutil.Equal(t, tc.sessionState, s)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
50
internal/sessions/mock_store.go
Normal file
50
internal/sessions/mock_store.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// MockCSRFStore is a mock implementation of the CSRF store interface
|
||||
type MockCSRFStore struct {
|
||||
ResponseCSRF string
|
||||
Cookie *http.Cookie
|
||||
GetError error
|
||||
}
|
||||
|
||||
// SetCSRF sets the ResponseCSRF string to a val
|
||||
func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
ms.ResponseCSRF = val
|
||||
}
|
||||
|
||||
// ClearCSRF clears the ResponseCSRF string
|
||||
func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseCSRF = ""
|
||||
}
|
||||
|
||||
// GetCSRF returns the cookie and error
|
||||
func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
||||
return ms.Cookie, ms.GetError
|
||||
}
|
||||
|
||||
// MockSessionStore is a mock implementation of the SessionStore interface
|
||||
type MockSessionStore struct {
|
||||
ResponseSession string
|
||||
Session *SessionState
|
||||
SaveError error
|
||||
LoadError error
|
||||
}
|
||||
|
||||
// ClearSession clears the ResponseSession
|
||||
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseSession = ""
|
||||
}
|
||||
|
||||
// LoadSession returns the session and a error
|
||||
func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
|
||||
return ms.Session, ms.LoadError
|
||||
}
|
||||
|
||||
// SaveSession returns a save error.
|
||||
func (ms *MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
|
||||
return ms.SaveError
|
||||
}
|
70
internal/sessions/session_state.go
Normal file
70
internal/sessions/session_state.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrLifetimeExpired is an error for the lifetime deadline expiring
|
||||
ErrLifetimeExpired = errors.New("user lifetime expired")
|
||||
)
|
||||
|
||||
// SessionState is our object that keeps track of a user's session state
|
||||
type SessionState struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"` // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
|
||||
|
||||
RefreshDeadline time.Time `json:"refresh_deadline"`
|
||||
LifetimeDeadline time.Time `json:"lifetime_deadline"`
|
||||
ValidDeadline time.Time `json:"valid_deadline"`
|
||||
GracePeriodStart time.Time `json:"grace_period_start"`
|
||||
|
||||
Email string `json:"email"`
|
||||
User string `json:"user"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
// LifetimePeriodExpired returns true if the lifetime has expired
|
||||
func (s *SessionState) LifetimePeriodExpired() bool {
|
||||
return isExpired(s.LifetimeDeadline)
|
||||
}
|
||||
|
||||
// RefreshPeriodExpired returns true if the refresh period has expired
|
||||
func (s *SessionState) RefreshPeriodExpired() bool {
|
||||
return isExpired(s.RefreshDeadline)
|
||||
}
|
||||
|
||||
// ValidationPeriodExpired returns true if the validation period has expired
|
||||
func (s *SessionState) ValidationPeriodExpired() bool {
|
||||
return isExpired(s.ValidDeadline)
|
||||
}
|
||||
|
||||
func isExpired(t time.Time) bool {
|
||||
return t.Before(time.Now())
|
||||
}
|
||||
|
||||
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
||||
// given cipher, and base64-encodes the result
|
||||
func MarshalSession(s *SessionState, c aead.Cipher) (string, error) {
|
||||
return c.Marshal(s)
|
||||
}
|
||||
|
||||
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||
// byte slice using the pased cipher, and unmarshals the resulting JSON into a session state struct
|
||||
func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
|
||||
s := &SessionState{}
|
||||
err := c.Unmarshal(value, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ExtendDeadline returns the time extended by a given duration
|
||||
func ExtendDeadline(ttl time.Duration) time.Time {
|
||||
return time.Now().Add(ttl).Truncate(time.Second)
|
||||
}
|
71
internal/sessions/session_state_test.go
Normal file
71
internal/sessions/session_state_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
)
|
||||
|
||||
func TestSessionStateSerialization(t *testing.T) {
|
||||
secret := aead.GenerateKey()
|
||||
c, err := aead.NewMiscreantCipher([]byte(secret))
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
|
||||
want := &SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
|
||||
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
|
||||
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
ciphertext, err := MarshalSession(want, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be encode session: %v", err)
|
||||
}
|
||||
|
||||
got, err := UnmarshalSession(ciphertext, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be decode session: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Logf("want: %#v", want)
|
||||
t.Logf(" got: %#v", got)
|
||||
t.Errorf("encoding and decoding session resulted in unexpected output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStateExpirations(t *testing.T) {
|
||||
session := &SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
|
||||
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
|
||||
RefreshDeadline: time.Now().Add(-1 * time.Hour),
|
||||
ValidDeadline: time.Now().Add(-1 * time.Minute),
|
||||
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
if !session.LifetimePeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
}
|
||||
|
||||
if !session.RefreshPeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
}
|
||||
|
||||
if !session.ValidationPeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue