diff --git a/CHANGELOG.md b/CHANGELOG.md index 47ffe7adb..06f0be549 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,8 +16,9 @@ - Fixed HEADERS environment variable parsing [GH-188] - Fixed Azure group lookups [GH-190] +- If a session is too large (over 4096 bytes) Pomerium will no longer fail silently. [GH-211] -## v0.0.5 + ## v0.0.5 ### NEW diff --git a/authenticate/handlers.go b/authenticate/handlers.go index f4caced55..7289db81a 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -278,7 +278,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) if err := a.sessionStore.SaveSession(w, r, session); err != nil { log.Error().Err(err).Msg("authenticate: failed saving new session") - return "", httputil.Error{Code: http.StatusInternalServerError, Message: "Internal Error"} + return "", httputil.Error{Code: http.StatusInternalServerError, Message: err.Error()} } return redirect, nil diff --git a/internal/sessions/session_state.go b/internal/sessions/session_state.go index ec94258c7..6750c5dcc 100644 --- a/internal/sessions/session_state.go +++ b/internal/sessions/session_state.go @@ -11,6 +11,8 @@ import ( "github.com/pomerium/pomerium/internal/cryptutil" ) +const MaxCookieSize = 4096 + var ( // ErrLifetimeExpired is an error for the lifetime deadline expiring ErrLifetimeExpired = errors.New("user lifetime expired") @@ -87,7 +89,14 @@ func isExpired(t time.Time) bool { // MarshalSession marshals the session state as JSON, encrypts the JSON using the // given cipher, and base64-encodes the result func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) { - return c.Marshal(s) + v, err := c.Marshal(s) + if err != nil { + return "", err + } + if len(v) >= MaxCookieSize { + return "", fmt.Errorf("session too large, got %d bytes", len(v)) + } + return v, nil } // UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the diff --git a/internal/sessions/session_state_test.go b/internal/sessions/session_state_test.go index 3e725d88d..28144fb1e 100644 --- a/internal/sessions/session_state_test.go +++ b/internal/sessions/session_state_test.go @@ -1,10 +1,12 @@ package sessions import ( + "crypto/rand" "reflect" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/pomerium/pomerium/internal/cryptutil" ) @@ -138,3 +140,41 @@ func TestSessionState_Impersonating(t *testing.T) { }) } } + +func TestMarshalSession(t *testing.T) { + secret := cryptutil.GenerateKey() + c, err := cryptutil.NewCipher([]byte(secret)) + if err != nil { + t.Fatalf("expected to be able to create cipher: %v", err) + } + hugeString := make([]byte, 4097) + if _, err := rand.Read(hugeString); err != nil { + t.Fatal(err) + } + tests := []struct { + name string + s *SessionState + wantErr bool + }{ + {"simple", &SessionState{}, false}, + {"too big", &SessionState{AccessToken: string(hugeString)}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in, err := MarshalSession(tt.s, c) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalSession() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + out, err := UnmarshalSession(in, c) + if err != nil { + t.Fatalf("expected to be decode session: %v", err) + } + if diff := cmp.Diff(tt.s, out); diff != "" { + t.Errorf("MarshalSession() = %s", diff) + } + } + }) + } +}