initial release

This commit is contained in:
Bobby DeSimone 2019-01-02 12:13:36 -08:00
commit d56c889224
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
62 changed files with 8229 additions and 0 deletions

View file

@ -0,0 +1,163 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/pomerium/pomerium/internal/aead"
"github.com/pomerium/pomerium/internal/log"
)
// ErrInvalidSession is an error for invalid sessions.
var ErrInvalidSession = errors.New("invalid session")
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
type CSRFStore interface {
SetCSRF(http.ResponseWriter, *http.Request, string)
GetCSRF(*http.Request) (*http.Cookie, error)
ClearCSRF(http.ResponseWriter, *http.Request)
}
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)
LoadSession(*http.Request) (*SessionState, error)
SaveSession(http.ResponseWriter, *http.Request, *SessionState) error
}
// CookieStore represents all the cookie related configurations
type CookieStore struct {
Name string
CSRFCookieName string
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
CookieCipher aead.Cipher
SessionLifetimeTTL time.Duration
}
// CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret
func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
return func(s *CookieStore) error {
cipher, err := aead.NewMiscreantCipher(cookieSecret)
if err != nil {
return fmt.Errorf("miscreant cookie-secret error: %s", err.Error())
}
s.CookieCipher = cipher
return nil
}
}
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) {
c := &CookieStore{
Name: cookieName,
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
}
for _, f := range optFuncs {
err := f(c)
if err != nil {
return nil, err
}
}
domain := c.CookieDomain
if domain == "" {
domain = "<default>"
}
return c, nil
}
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
domain := req.Host
if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h
}
if s.CookieDomain != "" {
if !strings.HasSuffix(domain, s.CookieDomain) {
log.Warn().Str("cookie-domain", s.CookieDomain).Msg("using configured cookie domain")
}
domain = s.CookieDomain
}
return &http.Cookie{
Name: name,
Value: value,
Path: "/",
Domain: domain,
HttpOnly: s.CookieHTTPOnly,
Secure: s.CookieSecure,
Expires: now.Add(expiration),
}
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
func (s *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return s.makeCookie(req, s.Name, value, expiration, now)
}
// makeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time.
func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return s.makeCookie(req, s.CSRFCookieName, value, expiration, now)
}
// ClearCSRF clears the CSRF cookie from the request
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
}
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
return req.Cookie(s.CSRFCookieName)
}
// ClearSession clears the session cookie from a request
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
}
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
}
// LoadSession returns a SessionState from the cookie in the request.
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
c, err := req.Cookie(s.Name)
if err != nil {
// always http.ErrNoCookie
return nil, err
}
session, err := UnmarshalSession(c.Value, s.CookieCipher)
if err != nil {
log.Error().Err(err).Str("remote-host", req.Host).Msg("error unmarshaling session")
return nil, ErrInvalidSession
}
return session, nil
}
// SaveSession saves a session state to a request sessions.
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
value, err := MarshalSession(sessionState, s.CookieCipher)
if err != nil {
return err
}
s.setSessionCookie(rw, req, value)
return nil
}

View file

@ -0,0 +1,348 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/pomerium/pomerium/internal/testutil"
)
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=")
func TestCreateMiscreantCookieCipher(t *testing.T) {
testCases := []struct {
name string
cookieSecret []byte
expectedError bool
}{
{
name: "normal case with base64 encoded secret",
cookieSecret: testEncodedCookieSecret,
},
{
name: "error when not base64 encoded",
cookieSecret: []byte("abcd"),
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewCookieStore("cookieName", CreateMiscreantCookieCipher(tc.cookieSecret))
if !tc.expectedError {
testutil.Ok(t, err)
} else {
testutil.NotEqual(t, err, nil)
}
})
}
}
func TestNewSession(t *testing.T) {
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedError bool
expectedSession *CookieStore
}{
{
name: "default with no opt funcs set",
expectedSession: &CookieStore{
Name: "cookieName",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: "cookieName_csrf",
},
},
{
name: "opt func with an error returns an error",
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
expectedError: true,
},
{
name: "opt func overrides default values",
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
s.CookieExpire = time.Hour
return nil
}},
expectedSession: &CookieStore{
Name: "cookieName",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: time.Hour,
CSRFCookieName: "cookieName_csrf",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore("cookieName", tc.optFuncs...)
if tc.expectedError {
testutil.NotEqual(t, err, nil)
} else {
testutil.Ok(t, err)
}
testutil.Equal(t, tc.expectedSession, session)
})
}
}
func TestMakeSessionCookie(t *testing.T) {
now := time.Now()
cookieValue := "cookieValue"
expiration := time.Hour
cookieName := "cookieName"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
testutil.Equal(t, cookie, tc.expectedCookie)
})
}
}
func TestMakeSessionCSRFCookie(t *testing.T) {
now := time.Now()
cookieValue := "cookieValue"
expiration := time.Hour
cookieName := "cookieName"
csrfName := "cookieName_csrf"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
testutil.Equal(t, tc.expectedCookie, cookie)
})
}
}
func TestSetSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.setSessionCookie(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestSetCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.SetCSRF(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearSession(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("clear csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearCSRF(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestLoadCookiedSession(t *testing.T) {
cookieName := "cookieName"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
expectedError error
sessionState *SessionState
}{
{
name: "no cookie set returns an error",
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {},
expectedError: http.ErrNoCookie,
},
{
name: "cookie set with cipher set",
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
value, err := MarshalSession(sessionState, s.CookieCipher)
testutil.Ok(t, err)
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
},
sessionState: &SessionState{
Email: "example@email.com",
RefreshToken: "abccdddd",
AccessToken: "access",
},
},
{
name: "cookie set with invalid value cipher set",
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
value := "574b776a7c934d6b9fc42ec63a389f79"
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
},
expectedError: ErrInvalidSession,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "https://www.example.com", nil)
tc.setupCookies(t, req, session, tc.sessionState)
s, err := session.LoadSession(req)
testutil.Equal(t, tc.expectedError, err)
testutil.Equal(t, tc.sessionState, s)
})
}
}

View file

@ -0,0 +1,50 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
)
// MockCSRFStore is a mock implementation of the CSRF store interface
type MockCSRFStore struct {
ResponseCSRF string
Cookie *http.Cookie
GetError error
}
// SetCSRF sets the ResponseCSRF string to a val
func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
ms.ResponseCSRF = val
}
// ClearCSRF clears the ResponseCSRF string
func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
ms.ResponseCSRF = ""
}
// GetCSRF returns the cookie and error
func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
return ms.Cookie, ms.GetError
}
// MockSessionStore is a mock implementation of the SessionStore interface
type MockSessionStore struct {
ResponseSession string
Session *SessionState
SaveError error
LoadError error
}
// ClearSession clears the ResponseSession
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms *MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
return ms.SaveError
}

View file

@ -0,0 +1,70 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
"time"
"github.com/pomerium/pomerium/internal/aead"
)
var (
// ErrLifetimeExpired is an error for the lifetime deadline expiring
ErrLifetimeExpired = errors.New("user lifetime expired")
)
// SessionState is our object that keeps track of a user's session state
type SessionState struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"` // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
RefreshDeadline time.Time `json:"refresh_deadline"`
LifetimeDeadline time.Time `json:"lifetime_deadline"`
ValidDeadline time.Time `json:"valid_deadline"`
GracePeriodStart time.Time `json:"grace_period_start"`
Email string `json:"email"`
User string `json:"user"`
Groups []string `json:"groups"`
}
// LifetimePeriodExpired returns true if the lifetime has expired
func (s *SessionState) LifetimePeriodExpired() bool {
return isExpired(s.LifetimeDeadline)
}
// RefreshPeriodExpired returns true if the refresh period has expired
func (s *SessionState) RefreshPeriodExpired() bool {
return isExpired(s.RefreshDeadline)
}
// ValidationPeriodExpired returns true if the validation period has expired
func (s *SessionState) ValidationPeriodExpired() bool {
return isExpired(s.ValidDeadline)
}
func isExpired(t time.Time) bool {
return t.Before(time.Now())
}
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
// given cipher, and base64-encodes the result
func MarshalSession(s *SessionState, c aead.Cipher) (string, error) {
return c.Marshal(s)
}
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice using the pased cipher, and unmarshals the resulting JSON into a session state struct
func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
s := &SessionState{}
err := c.Unmarshal(value, s)
if err != nil {
return nil, err
}
return s, nil
}
// ExtendDeadline returns the time extended by a given duration
func ExtendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}

View file

@ -0,0 +1,71 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"reflect"
"testing"
"time"
"github.com/pomerium/pomerium/internal/aead"
)
func TestSessionStateSerialization(t *testing.T) {
secret := aead.GenerateKey()
c, err := aead.NewMiscreantCipher([]byte(secret))
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)
}
want := &SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
}
ciphertext, err := MarshalSession(want, c)
if err != nil {
t.Fatalf("expected to be encode session: %v", err)
}
got, err := UnmarshalSession(ciphertext, c)
if err != nil {
t.Fatalf("expected to be decode session: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Logf("want: %#v", want)
t.Logf(" got: %#v", got)
t.Errorf("encoding and decoding session resulted in unexpected output")
}
}
func TestSessionStateExpirations(t *testing.T) {
session := &SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
RefreshDeadline: time.Now().Add(-1 * time.Hour),
ValidDeadline: time.Now().Add(-1 * time.Minute),
Email: "user@domain.com",
User: "user",
}
if !session.LifetimePeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
}
if !session.RefreshPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
}
if !session.ValidationPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
}
}