mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
authenticator: support groups (#57)
- authenticate/providers: add group support to azure - authenticate/providers: add group support to google - authenticate/providers: add group support to okta - authenticate/providers: add group support to onelogin - {authenticate/proxy}: change default cookie lifetime timeout to 14 hours - proxy: sign group membership - proxy: add group header - deployment: add CHANGELOG - deployment: fix where make release wasn’t including version
This commit is contained in:
parent
a2d647ee5b
commit
1187be2bf3
54 changed files with 1757 additions and 1706 deletions
|
@ -11,8 +11,8 @@ import (
|
|||
|
||||
"github.com/pomerium/envconfig"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/providers"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
)
|
||||
|
@ -21,7 +21,7 @@ var defaultOptions = &Options{
|
|||
CookieName: "_pomerium_authenticate",
|
||||
CookieHTTPOnly: true,
|
||||
CookieSecure: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieExpire: time.Duration(14) * time.Hour,
|
||||
CookieRefresh: time.Duration(30) * time.Minute,
|
||||
}
|
||||
|
||||
|
@ -50,11 +50,12 @@ type Options struct {
|
|||
|
||||
// IdentityProvider provider configuration variables as specified by RFC6749
|
||||
// https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
|
||||
ClientID string `envconfig:"IDP_CLIENT_ID"`
|
||||
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"`
|
||||
Provider string `envconfig:"IDP_PROVIDER"`
|
||||
ProviderURL string `envconfig:"IDP_PROVIDER_URL"`
|
||||
Scopes []string `envconfig:"IDP_SCOPES"`
|
||||
ClientID string `envconfig:"IDP_CLIENT_ID"`
|
||||
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"`
|
||||
Provider string `envconfig:"IDP_PROVIDER"`
|
||||
ProviderURL string `envconfig:"IDP_PROVIDER_URL"`
|
||||
Scopes []string `envconfig:"IDP_SCOPES"`
|
||||
ServiceAccount string `envconfig:"IDP_SERVICE_ACCOUNT"`
|
||||
}
|
||||
|
||||
// OptionsFromEnvConfig builds the authenticate service's configuration environmental variables
|
||||
|
@ -117,7 +118,7 @@ type Authenticate struct {
|
|||
sessionStore sessions.SessionStore
|
||||
cipher cryptutil.Cipher
|
||||
|
||||
provider providers.Provider
|
||||
provider identity.Authenticator
|
||||
}
|
||||
|
||||
// New validates and creates a new authenticate service from a set of Options
|
||||
|
@ -147,15 +148,16 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
|
|||
return nil, err
|
||||
}
|
||||
|
||||
provider, err := providers.New(
|
||||
provider, err := identity.New(
|
||||
opts.Provider,
|
||||
&providers.IdentityProvider{
|
||||
RedirectURL: opts.RedirectURL,
|
||||
ProviderName: opts.Provider,
|
||||
ProviderURL: opts.ProviderURL,
|
||||
ClientID: opts.ClientID,
|
||||
ClientSecret: opts.ClientSecret,
|
||||
Scopes: opts.Scopes,
|
||||
&identity.Provider{
|
||||
RedirectURL: opts.RedirectURL,
|
||||
ProviderName: opts.Provider,
|
||||
ProviderURL: opts.ProviderURL,
|
||||
ClientID: opts.ClientID,
|
||||
ClientSecret: opts.ClientSecret,
|
||||
Scopes: opts.Scopes,
|
||||
ServiceAccount: opts.ServiceAccount,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -19,9 +19,8 @@ func testOptions() *Options {
|
|||
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieRefresh: time.Duration(1) * time.Hour,
|
||||
// CookieLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieName: "pomerium",
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieName: "pomerium",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,329 +0,0 @@
|
|||
// Package circuit implements the Circuit Breaker pattern.
|
||||
// https://docs.microsoft.com/en-us/azure/architecture/patterns/circuit-breaker
|
||||
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
// State is a type that represents a state of Breaker.
|
||||
type State int
|
||||
|
||||
// These constants are states of Breaker.
|
||||
const (
|
||||
StateClosed State = iota
|
||||
StateHalfOpen
|
||||
StateOpen
|
||||
)
|
||||
|
||||
type (
|
||||
// ShouldTripFunc is a function that takes in a Counts and returns true if the circuit breaker should be tripped.
|
||||
ShouldTripFunc func(Counts) bool
|
||||
// ShouldResetFunc is a function that takes in a Counts and returns true if the circuit breaker should be reset.
|
||||
ShouldResetFunc func(Counts) bool
|
||||
// BackoffDurationFunc is a function that takes in a Counts and returns the backoff duration
|
||||
BackoffDurationFunc func(Counts) time.Duration
|
||||
|
||||
// StateChangeHook is a function that represents a state change.
|
||||
StateChangeHook func(prev, to State)
|
||||
// BackoffHook is a function that represents backoff.
|
||||
BackoffHook func(duration time.Duration, reset time.Time)
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultShouldTripFunc is a default ShouldTripFunc.
|
||||
DefaultShouldTripFunc = func(counts Counts) bool {
|
||||
// Trip into Open after three consecutive failures
|
||||
return counts.ConsecutiveFailures >= 3
|
||||
}
|
||||
// DefaultShouldResetFunc is a default ShouldResetFunc.
|
||||
DefaultShouldResetFunc = func(counts Counts) bool {
|
||||
// Reset after three consecutive successes
|
||||
return counts.ConsecutiveSuccesses >= 3
|
||||
}
|
||||
// DefaultBackoffDurationFunc is an exponential backoff function
|
||||
DefaultBackoffDurationFunc = ExponentialBackoffDuration(time.Duration(100)*time.Second, time.Duration(500)*time.Millisecond)
|
||||
)
|
||||
|
||||
// ErrOpenState is returned when the b state is open
|
||||
type ErrOpenState struct{}
|
||||
|
||||
func (e *ErrOpenState) Error() string { return "circuit breaker is open" }
|
||||
|
||||
// ExponentialBackoffDuration returns a function that uses exponential backoff and full jitter
|
||||
func ExponentialBackoffDuration(maxBackoff, baseTimeout time.Duration) func(Counts) time.Duration {
|
||||
return func(counts Counts) time.Duration {
|
||||
// Full Jitter from https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
||||
// sleep = random_between(0, min(cap, base * 2 ** attempt))
|
||||
backoff := math.Min(float64(maxBackoff), float64(baseTimeout)*math.Exp2(float64(counts.ConsecutiveFailures)))
|
||||
jittered := rand.Float64() * backoff
|
||||
return time.Duration(jittered)
|
||||
}
|
||||
}
|
||||
|
||||
// String implements stringer interface.
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
case StateHalfOpen:
|
||||
return "half-open"
|
||||
case StateOpen:
|
||||
return "open"
|
||||
default:
|
||||
return fmt.Sprintf("unknown state: %d", s)
|
||||
}
|
||||
}
|
||||
|
||||
// Counts holds the numbers of requests and their successes/failures.
|
||||
type Counts struct {
|
||||
CurrentRequests int
|
||||
ConsecutiveSuccesses int
|
||||
ConsecutiveFailures int
|
||||
}
|
||||
|
||||
func (c *Counts) onRequest() {
|
||||
c.CurrentRequests++
|
||||
}
|
||||
|
||||
func (c *Counts) afterRequest() {
|
||||
c.CurrentRequests--
|
||||
}
|
||||
|
||||
func (c *Counts) onSuccess() {
|
||||
c.ConsecutiveSuccesses++
|
||||
c.ConsecutiveFailures = 0
|
||||
}
|
||||
|
||||
func (c *Counts) onFailure() {
|
||||
c.ConsecutiveFailures++
|
||||
c.ConsecutiveSuccesses = 0
|
||||
}
|
||||
|
||||
func (c *Counts) clear() {
|
||||
c.ConsecutiveSuccesses = 0
|
||||
c.ConsecutiveFailures = 0
|
||||
}
|
||||
|
||||
// Options configures Breaker:
|
||||
//
|
||||
// HalfOpenConcurrentRequests specifies how many concurrent requests to allow while
|
||||
// the circuit is in the half-open state
|
||||
//
|
||||
// ShouldTripFunc specifies when the circuit should trip from the closed state to
|
||||
// the open state. It takes a Counts struct and returns a bool.
|
||||
//
|
||||
// ShouldResetFunc specifies when the circuit should be reset from the half-open state
|
||||
// to the closed state and allow all requests. It takes a Counts struct and returns a bool.
|
||||
//
|
||||
// BackoffDurationFunc specifies how long to set the backoff duration. It takes a
|
||||
// counts struct and returns a time.Duration
|
||||
//
|
||||
// OnStateChange is called whenever the state of the Breaker changes.
|
||||
//
|
||||
// OnBackoff is called whenever a backoff is set with the backoff duration and reset time
|
||||
//
|
||||
// TestClock is used to mock the clock during tests
|
||||
type Options struct {
|
||||
HalfOpenConcurrentRequests int
|
||||
|
||||
ShouldTripFunc ShouldTripFunc
|
||||
ShouldResetFunc ShouldResetFunc
|
||||
BackoffDurationFunc BackoffDurationFunc
|
||||
|
||||
// hooks
|
||||
OnStateChange StateChangeHook
|
||||
OnBackoff BackoffHook
|
||||
|
||||
// used in tests
|
||||
TestClock clock.Clock
|
||||
}
|
||||
|
||||
// Breaker is a state machine to prevent sending requests that are likely to fail.
|
||||
type Breaker struct {
|
||||
halfOpenRequests int
|
||||
|
||||
shouldTripFunc ShouldTripFunc
|
||||
shouldResetFunc ShouldResetFunc
|
||||
backoffDurationFunc BackoffDurationFunc
|
||||
|
||||
// hooks
|
||||
onStateChange StateChangeHook
|
||||
onBackoff BackoffHook
|
||||
|
||||
// used primarily for mocking tests
|
||||
clock clock.Clock
|
||||
|
||||
mutex sync.Mutex
|
||||
state State
|
||||
counts Counts
|
||||
backoffExpires time.Time
|
||||
generation int
|
||||
}
|
||||
|
||||
// NewBreaker returns a new Breaker configured with the given Settings.
|
||||
func NewBreaker(opts *Options) *Breaker {
|
||||
b := new(Breaker)
|
||||
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
// set hooks
|
||||
b.onStateChange = opts.OnStateChange
|
||||
b.onBackoff = opts.OnBackoff
|
||||
|
||||
b.halfOpenRequests = 1
|
||||
if opts.HalfOpenConcurrentRequests > 0 {
|
||||
b.halfOpenRequests = opts.HalfOpenConcurrentRequests
|
||||
}
|
||||
|
||||
b.backoffDurationFunc = DefaultBackoffDurationFunc
|
||||
if opts.BackoffDurationFunc != nil {
|
||||
b.backoffDurationFunc = opts.BackoffDurationFunc
|
||||
}
|
||||
|
||||
b.shouldTripFunc = DefaultShouldTripFunc
|
||||
if opts.ShouldTripFunc != nil {
|
||||
b.shouldTripFunc = opts.ShouldTripFunc
|
||||
}
|
||||
|
||||
b.shouldResetFunc = DefaultShouldResetFunc
|
||||
if opts.ShouldResetFunc != nil {
|
||||
b.shouldResetFunc = opts.ShouldResetFunc
|
||||
}
|
||||
|
||||
b.clock = clock.New()
|
||||
if opts.TestClock != nil {
|
||||
b.clock = opts.TestClock
|
||||
}
|
||||
|
||||
b.setState(StateClosed)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Call runs the given function if the Breaker allows the call.
|
||||
// Call returns an error instantly if the Breaker rejects the request.
|
||||
// Otherwise, Call returns the result of the request.
|
||||
func (b *Breaker) Call(f func() (interface{}, error)) (interface{}, error) {
|
||||
generation, err := b.beforeRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := f()
|
||||
b.afterRequest(err == nil, generation)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (b *Breaker) beforeRequest() (int, error) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
state, generation := b.currentState()
|
||||
|
||||
switch state {
|
||||
case StateOpen:
|
||||
return generation, &ErrOpenState{}
|
||||
case StateHalfOpen:
|
||||
if b.counts.CurrentRequests >= b.halfOpenRequests {
|
||||
return generation, &ErrOpenState{}
|
||||
}
|
||||
}
|
||||
|
||||
b.counts.onRequest()
|
||||
return generation, nil
|
||||
}
|
||||
|
||||
func (b *Breaker) afterRequest(success bool, prevGeneration int) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
b.counts.afterRequest()
|
||||
|
||||
state, generation := b.currentState()
|
||||
if prevGeneration != generation {
|
||||
return
|
||||
}
|
||||
|
||||
if success {
|
||||
b.onSuccess(state)
|
||||
return
|
||||
}
|
||||
|
||||
b.onFailure(state)
|
||||
}
|
||||
|
||||
func (b *Breaker) onSuccess(state State) {
|
||||
b.counts.onSuccess()
|
||||
switch state {
|
||||
case StateHalfOpen:
|
||||
if b.shouldResetFunc(b.counts) {
|
||||
b.setState(StateClosed)
|
||||
b.counts.clear()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Breaker) onFailure(state State) {
|
||||
b.counts.onFailure()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
if b.shouldTripFunc(b.counts) {
|
||||
b.setState(StateOpen)
|
||||
b.counts.clear()
|
||||
b.setBackoff()
|
||||
}
|
||||
case StateOpen:
|
||||
b.setBackoff()
|
||||
case StateHalfOpen:
|
||||
b.setState(StateOpen)
|
||||
b.setBackoff()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Breaker) setBackoff() {
|
||||
backoffDuration := b.backoffDurationFunc(b.counts)
|
||||
backoffExpires := b.clock.Now().Add(backoffDuration)
|
||||
b.backoffExpires = backoffExpires
|
||||
if b.onBackoff != nil {
|
||||
b.onBackoff(backoffDuration, backoffExpires)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Breaker) currentState() (State, int) {
|
||||
switch b.state {
|
||||
case StateOpen:
|
||||
if b.clock.Now().After(b.backoffExpires) {
|
||||
b.setState(StateHalfOpen)
|
||||
}
|
||||
}
|
||||
return b.state, b.generation
|
||||
}
|
||||
|
||||
func (b *Breaker) newGeneration() {
|
||||
b.generation++
|
||||
}
|
||||
|
||||
func (b *Breaker) setState(state State) {
|
||||
if b.state == state {
|
||||
return
|
||||
}
|
||||
|
||||
b.newGeneration()
|
||||
|
||||
prev := b.state
|
||||
b.state = state
|
||||
|
||||
if b.onStateChange != nil {
|
||||
b.onStateChange(prev, state)
|
||||
}
|
||||
}
|
|
@ -1,187 +0,0 @@
|
|||
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
var errFailed = errors.New("failed")
|
||||
|
||||
func fail() (interface{}, error) {
|
||||
return nil, errFailed
|
||||
}
|
||||
|
||||
func succeed() (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestCircuitBreaker(t *testing.T) {
|
||||
mock := clock.NewMock()
|
||||
threshold := 3
|
||||
timeout := time.Duration(2) * time.Second
|
||||
trip := func(c Counts) bool { return c.ConsecutiveFailures > threshold }
|
||||
reset := func(c Counts) bool { return c.ConsecutiveSuccesses > threshold }
|
||||
backoff := func(c Counts) time.Duration { return timeout }
|
||||
stateChange := func(p, c State) { t.Logf("state change from %s to %s\n", p, c) }
|
||||
cb := NewBreaker(&Options{
|
||||
TestClock: mock,
|
||||
ShouldTripFunc: trip,
|
||||
ShouldResetFunc: reset,
|
||||
BackoffDurationFunc: backoff,
|
||||
OnStateChange: stateChange,
|
||||
})
|
||||
state, _ := cb.currentState()
|
||||
if state != StateClosed {
|
||||
t.Fatalf("expected state to start %s, got %s", StateClosed, state)
|
||||
}
|
||||
|
||||
for i := 0; i <= threshold; i++ {
|
||||
_, err := cb.Call(fail)
|
||||
if err == nil {
|
||||
t.Fatalf("expected to error, got nil")
|
||||
}
|
||||
state, _ := cb.currentState()
|
||||
t.Logf("iteration %#v", i)
|
||||
if i == threshold {
|
||||
// we expect this to be the case to trip the circuit
|
||||
if state != StateOpen {
|
||||
t.Fatalf("expected state to be %s, got %s", StateOpen, state)
|
||||
}
|
||||
} else if state != StateClosed {
|
||||
// this is a normal failure case
|
||||
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := cb.Call(fail)
|
||||
switch err.(type) {
|
||||
case *ErrOpenState:
|
||||
// this is the expected case
|
||||
break
|
||||
default:
|
||||
t.Errorf("%#v", cb.counts)
|
||||
t.Fatalf("expected to get open state failure, got %s", err)
|
||||
}
|
||||
|
||||
// we advance time by the timeout and a hair
|
||||
mock.Add(timeout + time.Duration(1)*time.Millisecond)
|
||||
state, _ = cb.currentState()
|
||||
if state != StateHalfOpen {
|
||||
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
|
||||
}
|
||||
|
||||
for i := 0; i <= threshold; i++ {
|
||||
_, err := cb.Call(succeed)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to get no error, got %s", err)
|
||||
}
|
||||
state, _ := cb.currentState()
|
||||
t.Logf("iteration %#v", i)
|
||||
if i == threshold {
|
||||
// we expect this to be the case that ressets the circuit
|
||||
if state != StateClosed {
|
||||
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||
}
|
||||
} else if state != StateHalfOpen {
|
||||
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
|
||||
}
|
||||
}
|
||||
|
||||
state, _ = cb.currentState()
|
||||
if state != StateClosed {
|
||||
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExponentialBackOffFunc(t *testing.T) {
|
||||
baseTimeout := time.Duration(1) * time.Millisecond
|
||||
// Note Expected is an upper range case
|
||||
cases := []struct {
|
||||
FailureCount int
|
||||
Expected time.Duration
|
||||
}{
|
||||
{
|
||||
FailureCount: 0,
|
||||
Expected: time.Duration(1) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 1,
|
||||
Expected: time.Duration(2) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 2,
|
||||
Expected: time.Duration(4) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 3,
|
||||
Expected: time.Duration(8) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 4,
|
||||
Expected: time.Duration(16) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 5,
|
||||
Expected: time.Duration(32) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 6,
|
||||
Expected: time.Duration(64) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 7,
|
||||
Expected: time.Duration(128) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 8,
|
||||
Expected: time.Duration(256) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 9,
|
||||
Expected: time.Duration(512) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 10,
|
||||
Expected: time.Duration(1024) * time.Millisecond,
|
||||
},
|
||||
}
|
||||
|
||||
f := ExponentialBackoffDuration(time.Duration(1)*time.Hour, baseTimeout)
|
||||
for _, tc := range cases {
|
||||
got := f(Counts{ConsecutiveFailures: tc.FailureCount})
|
||||
t.Logf("got backoff %#v", got)
|
||||
if got > tc.Expected {
|
||||
t.Errorf("got %#v but expected less than %#v", got, tc.Expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerClosedParallel(t *testing.T) {
|
||||
cb := NewBreaker(nil)
|
||||
numReqs := 10000
|
||||
wg := &sync.WaitGroup{}
|
||||
routine := func(wg *sync.WaitGroup) {
|
||||
for i := 0; i < numReqs; i++ {
|
||||
cb.Call(succeed)
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
numRoutines := 10
|
||||
for i := 0; i < numRoutines; i++ {
|
||||
wg.Add(1)
|
||||
go routine(wg)
|
||||
}
|
||||
|
||||
total := numReqs * numRoutines
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if cb.counts.ConsecutiveSuccesses != total {
|
||||
t.Fatalf("expected to get total requests %d, got %d", total, cb.counts.ConsecutiveSuccesses)
|
||||
}
|
||||
}
|
|
@ -3,62 +3,54 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// Authenticate takes an encrypted code, and returns the authentication result.
|
||||
func (p *Authenticate) Authenticate(ctx context.Context, in *pb.AuthenticateRequest) (*pb.AuthenticateReply, error) {
|
||||
func (p *Authenticate) Authenticate(ctx context.Context, in *pb.AuthenticateRequest) (*pb.Session, error) {
|
||||
session, err := sessions.UnmarshalSession(in.Code, p.cipher)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/grpc: %v", err)
|
||||
return nil, fmt.Errorf("authenticate/grpc: authenticate %v", err)
|
||||
}
|
||||
expiryTimestamp, err := ptypes.TimestampProto(session.RefreshDeadline)
|
||||
newSessionProto, err := pb.ProtoFromSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pb.AuthenticateReply{
|
||||
AccessToken: session.AccessToken,
|
||||
RefreshToken: session.RefreshToken,
|
||||
IdToken: session.IDToken,
|
||||
User: session.User,
|
||||
Email: session.Email,
|
||||
Expiry: expiryTimestamp,
|
||||
}, nil
|
||||
return newSessionProto, nil
|
||||
}
|
||||
|
||||
// Validate locally validates a JWT id token; does NOT do nonce or revokation validation.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (p *Authenticate) Validate(ctx context.Context, in *pb.ValidateRequest) (*pb.ValidateReply, error) {
|
||||
isValid, err := p.provider.Validate(in.IdToken)
|
||||
isValid, err := p.provider.Validate(ctx, in.IdToken)
|
||||
if err != nil {
|
||||
return &pb.ValidateReply{IsValid: false}, err
|
||||
return &pb.ValidateReply{IsValid: false}, fmt.Errorf("authenticate/grpc: validate %v", err)
|
||||
}
|
||||
return &pb.ValidateReply{IsValid: isValid}, nil
|
||||
}
|
||||
|
||||
// Refresh renews a user's session checks if the session has been revoked using an access token
|
||||
// without reprompting the user.
|
||||
func (p *Authenticate) Refresh(ctx context.Context, in *pb.RefreshRequest) (*pb.RefreshReply, error) {
|
||||
newToken, err := p.provider.Refresh(in.RefreshToken)
|
||||
func (p *Authenticate) Refresh(ctx context.Context, in *pb.Session) (*pb.Session, error) {
|
||||
// todo(bdd): add request id from incoming context
|
||||
// md, _ := metadata.FromIncomingContext(ctx)
|
||||
// sublogger := log.With().Str("req_id", md.Get("req_id")[0]).WithContext(ctx)
|
||||
// sublogger.Info().Msg("tracing sucks!")
|
||||
if in == nil {
|
||||
return nil, fmt.Errorf("authenticate/grpc: session cannot be nil")
|
||||
}
|
||||
oldSession, err := pb.SessionFromProto(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiryTimestamp, err := ptypes.TimestampProto(newToken.Expiry)
|
||||
newSession, err := p.provider.Refresh(ctx, oldSession)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/grpc: refresh failed %v", err)
|
||||
}
|
||||
newSessionProto, err := pb.ProtoFromSession(newSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info().
|
||||
Str("session.AccessToken", newToken.AccessToken).
|
||||
Msg("authenticate: grpc: refresh: ok")
|
||||
|
||||
return &pb.RefreshReply{
|
||||
AccessToken: newToken.AccessToken,
|
||||
Expiry: expiryTimestamp,
|
||||
}, nil
|
||||
|
||||
return newSessionProto, nil
|
||||
}
|
||||
|
|
|
@ -7,58 +7,32 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||
|
||||
// TestProvider is a mock provider
|
||||
type testProvider struct{}
|
||||
|
||||
func (tp *testProvider) Authenticate(s string) (*sessions.SessionState, error) {
|
||||
return &sessions.SessionState{}, nil
|
||||
}
|
||||
|
||||
func (tp *testProvider) Revoke(s string) error { return nil }
|
||||
func (tp *testProvider) GetSignInURL(s string) string { return "/signin" }
|
||||
func (tp *testProvider) Refresh(s string) (*oauth2.Token, error) {
|
||||
if s == "error" {
|
||||
return nil, errors.New("failed refresh")
|
||||
}
|
||||
if s == "bad time" {
|
||||
return &oauth2.Token{AccessToken: "updated", Expiry: time.Time{}}, nil
|
||||
}
|
||||
return &oauth2.Token{AccessToken: "updated", Expiry: fixedDate}, nil
|
||||
}
|
||||
func (tp *testProvider) Validate(token string) (bool, error) {
|
||||
if token == "good" {
|
||||
return true, nil
|
||||
} else if token == "error" {
|
||||
return false, errors.New("error validating id token")
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func TestAuthenticate_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
mp *identity.MockProvider
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "example", false, false},
|
||||
{"error", "error", false, true},
|
||||
{"not error", "not error", false, false},
|
||||
{"good", "example", &identity.MockProvider{}, false, false},
|
||||
{"error", "error", &identity.MockProvider{ValidateError: errors.New("err")}, false, true},
|
||||
{"not error", "not error", &identity.MockProvider{ValidateError: nil}, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tp := &testProvider{}
|
||||
p := &Authenticate{provider: tp}
|
||||
p := &Authenticate{provider: tt.mp}
|
||||
got, err := p.Validate(context.Background(), &pb.ValidateRequest{IdToken: tt.idToken})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authenticate.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@ -78,24 +52,43 @@ func TestAuthenticate_Refresh(t *testing.T) {
|
|||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
want *pb.RefreshReply
|
||||
wantErr bool
|
||||
name string
|
||||
mock *identity.MockProvider
|
||||
originalSession *pb.Session
|
||||
want *pb.Session
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "refresh-token", &pb.RefreshReply{AccessToken: "updated", Expiry: fixedProtoTime}, false},
|
||||
{"test error", "error", nil, true},
|
||||
{"good",
|
||||
&identity.MockProvider{
|
||||
RefreshResponse: &sessions.SessionState{
|
||||
AccessToken: "updated",
|
||||
LifetimeDeadline: fixedDate,
|
||||
RefreshDeadline: fixedDate,
|
||||
}},
|
||||
&pb.Session{
|
||||
AccessToken: "original",
|
||||
LifetimeDeadline: fixedProtoTime,
|
||||
RefreshDeadline: fixedProtoTime,
|
||||
},
|
||||
&pb.Session{
|
||||
AccessToken: "updated",
|
||||
LifetimeDeadline: fixedProtoTime,
|
||||
RefreshDeadline: fixedProtoTime,
|
||||
},
|
||||
false},
|
||||
{"test error", &identity.MockProvider{RefreshError: errors.New("hi")}, &pb.Session{RefreshToken: "refresh token", RefreshDeadline: fixedProtoTime, LifetimeDeadline: fixedProtoTime}, nil, true},
|
||||
{"test catch nil", nil, nil, nil, true},
|
||||
|
||||
// {"test error", "error", nil, true},
|
||||
// {"test bad time", "bad time", nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tp := &testProvider{}
|
||||
p := &Authenticate{provider: tp}
|
||||
p := &Authenticate{provider: tt.mock}
|
||||
|
||||
got, err := p.Refresh(context.Background(), &pb.RefreshRequest{RefreshToken: tt.refreshToken})
|
||||
got, err := p.Refresh(context.Background(), tt.originalSession)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authenticate.Refresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authenticate.Refresh() = %v, want %v", got, tt.want)
|
||||
|
@ -132,12 +125,13 @@ func TestAuthenticate_Authenticate(t *testing.T) {
|
|||
User: "user",
|
||||
}
|
||||
|
||||
goodReply := &pb.AuthenticateReply{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
Expiry: vtProto,
|
||||
Email: "user@domain.com",
|
||||
User: "user"}
|
||||
goodReply := &pb.Session{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
LifetimeDeadline: vtProto,
|
||||
RefreshDeadline: vtProto,
|
||||
Email: "user@domain.com",
|
||||
User: "user"}
|
||||
ciphertext, err := sessions.MarshalSession(want, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be encode session: %v", err)
|
||||
|
@ -147,7 +141,7 @@ func TestAuthenticate_Authenticate(t *testing.T) {
|
|||
name string
|
||||
cipher cryptutil.Cipher
|
||||
code string
|
||||
want *pb.AuthenticateReply
|
||||
want *pb.Session
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", c, ciphertext, goodReply, false},
|
||||
|
@ -162,7 +156,7 @@ func TestAuthenticate_Authenticate(t *testing.T) {
|
|||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authenticate.Authenticate() = %v, want %v", got, tt.want)
|
||||
t.Errorf("Authenticate.Authenticate() = got: \n%vwant:\n%v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -83,34 +83,29 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
|
||||
// check if session refresh period is up
|
||||
if session.RefreshPeriodExpired() {
|
||||
newToken, err := a.provider.Refresh(session.RefreshToken)
|
||||
newSession, err := a.provider.Refresh(r.Context(), session)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
session.AccessToken = newToken.AccessToken
|
||||
session.RefreshDeadline = newToken.Expiry
|
||||
err = a.sessionStore.SaveSession(w, r, session)
|
||||
err = a.sessionStore.SaveSession(w, r, newSession)
|
||||
if err != nil {
|
||||
// We refreshed the session successfully, but failed to save it.
|
||||
// This could be from failing to encode the session properly.
|
||||
// But, we clear the session cookie and reject the request
|
||||
log.FromRequest(r).Error().Err(err).Msg("could not save refreshed session")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: could not save refreshed session")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// The session has not exceeded it's lifetime or requires refresh
|
||||
ok, err := a.provider.Validate(session.IDToken)
|
||||
ok, err := a.provider.Validate(r.Context(), session.IDToken)
|
||||
if !ok || err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("invalid session state")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: invalid session state")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
err = a.sessionStore.SaveSession(w, r, session)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("failed to save valid session")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to save valid session")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
|
@ -136,7 +131,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
log.FromRequest(r).Info().Msg("authenticate: user authenticated")
|
||||
a.ProxyCallback(w, r, session)
|
||||
|
||||
}
|
||||
|
||||
// ProxyCallback redirects the user back to proxy service along with an encrypted payload, as
|
||||
|
@ -310,7 +304,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
errorString := r.Form.Get("error")
|
||||
if errorString != "" {
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: provider returned error")
|
||||
log.FromRequest(r).Error().Str("Error", errorString).Msg("authenticate: provider returned error")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString}
|
||||
}
|
||||
code := r.Form.Get("code")
|
||||
|
|
|
@ -11,11 +11,10 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/providers"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// mocks for validator func
|
||||
|
@ -89,36 +88,36 @@ func TestAuthenticate_authenticate(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
session sessions.SessionStore
|
||||
provider providers.MockProvider
|
||||
provider identity.MockProvider
|
||||
validator func(string) bool
|
||||
want *sessions.SessionState
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false},
|
||||
{"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true},
|
||||
{"can't load session", &sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||
{"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true},
|
||||
{"good", goodSession, identity.MockProvider{ValidateResponse: true}, trueValidator, nil, false},
|
||||
{"good but fails validation", goodSession, identity.MockProvider{ValidateResponse: true}, falseValidator, nil, true},
|
||||
{"can't load session", &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||
{"validation fails", goodSession, identity.MockProvider{ValidateResponse: false}, trueValidator, nil, true},
|
||||
{"session fails after good validation", &sessions.MockSessionStore{
|
||||
SaveError: errors.New("error"),
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}}, providers.MockProvider{ValidateResponse: true},
|
||||
}}, identity.MockProvider{ValidateResponse: true},
|
||||
trueValidator, nil, true},
|
||||
{"refresh expired",
|
||||
expiredRefresPeriod,
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
ValidateResponse: true,
|
||||
RefreshResponse: &oauth2.Token{
|
||||
AccessToken: "new token",
|
||||
Expiry: time.Now(),
|
||||
RefreshResponse: &sessions.SessionState{
|
||||
AccessToken: "new token",
|
||||
LifetimeDeadline: time.Now(),
|
||||
},
|
||||
},
|
||||
trueValidator, nil, false},
|
||||
{"refresh expired refresh error",
|
||||
expiredRefresPeriod,
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
ValidateResponse: true,
|
||||
RefreshError: errors.New("error"),
|
||||
},
|
||||
|
@ -132,11 +131,11 @@ func TestAuthenticate_authenticate(t *testing.T) {
|
|||
|
||||
RefreshDeadline: time.Now().Add(10 * -time.Second),
|
||||
}},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
ValidateResponse: true,
|
||||
RefreshResponse: &oauth2.Token{
|
||||
AccessToken: "new token",
|
||||
Expiry: time.Now(),
|
||||
RefreshResponse: &sessions.SessionState{
|
||||
AccessToken: "new token",
|
||||
LifetimeDeadline: time.Now(),
|
||||
},
|
||||
},
|
||||
trueValidator, nil, true},
|
||||
|
@ -164,7 +163,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
session sessions.SessionStore
|
||||
provider providers.MockProvider
|
||||
provider identity.MockProvider
|
||||
validator func(string) bool
|
||||
wantCode int
|
||||
}{
|
||||
|
@ -175,7 +174,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
providers.MockProvider{ValidateResponse: true},
|
||||
identity.MockProvider{ValidateResponse: true},
|
||||
trueValidator,
|
||||
http.StatusForbidden},
|
||||
{"session fails after good validation", &sessions.MockSessionStore{
|
||||
|
@ -184,7 +183,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}}, providers.MockProvider{ValidateResponse: true},
|
||||
}}, identity.MockProvider{ValidateResponse: true},
|
||||
trueValidator, http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -359,7 +358,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
sig string
|
||||
ts string
|
||||
|
||||
provider providers.Provider
|
||||
provider identity.Authenticator
|
||||
sessionStore sessions.SessionStore
|
||||
wantCode int
|
||||
wantBody string
|
||||
|
@ -369,7 +368,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"https://corp.pomerium.io/",
|
||||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
|
@ -386,7 +385,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"https://corp.pomerium.io/",
|
||||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{RevokeError: errors.New("OH NO")},
|
||||
identity.MockProvider{RevokeError: errors.New("OH NO")},
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
|
@ -404,7 +403,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"https://corp.pomerium.io/",
|
||||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
|
@ -421,7 +420,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"https://corp.pomerium.io/",
|
||||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
&sessions.MockSessionStore{
|
||||
LoadError: errors.New("uh oh"),
|
||||
Session: &sessions.SessionState{
|
||||
|
@ -439,7 +438,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"https://pomerium.com%zzzzz",
|
||||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
|
@ -497,7 +496,7 @@ func TestAuthenticate_OAuthStart(t *testing.T) {
|
|||
ts string
|
||||
allowedDomains []string
|
||||
|
||||
provider providers.Provider
|
||||
provider identity.Authenticator
|
||||
csrfStore sessions.MockCSRFStore
|
||||
// sessionStore sessions.SessionStore
|
||||
wantCode int
|
||||
|
@ -508,7 +507,7 @@ func TestAuthenticate_OAuthStart(t *testing.T) {
|
|||
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||
fmt.Sprint(time.Now().Unix()),
|
||||
[]string{".pomerium.io"},
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
sessions.MockCSRFStore{},
|
||||
http.StatusFound,
|
||||
},
|
||||
|
@ -518,7 +517,7 @@ func TestAuthenticate_OAuthStart(t *testing.T) {
|
|||
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||
fmt.Sprint(time.Now().Add(10 * time.Hour).Unix()),
|
||||
[]string{".pomerium.io"},
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
sessions.MockCSRFStore{},
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
|
@ -528,7 +527,7 @@ func TestAuthenticate_OAuthStart(t *testing.T) {
|
|||
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||
fmt.Sprint(time.Now().Unix()),
|
||||
[]string{"not.pomerium.io"},
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
sessions.MockCSRFStore{},
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
|
@ -538,7 +537,7 @@ func TestAuthenticate_OAuthStart(t *testing.T) {
|
|||
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||
fmt.Sprint(time.Now().Unix()),
|
||||
[]string{".pomerium.io"},
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
sessions.MockCSRFStore{},
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
|
@ -548,7 +547,7 @@ func TestAuthenticate_OAuthStart(t *testing.T) {
|
|||
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||
fmt.Sprint(time.Now().Unix()),
|
||||
[]string{".pomerium.io"},
|
||||
providers.MockProvider{},
|
||||
identity.MockProvider{},
|
||||
sessions.MockCSRFStore{},
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
|
@ -596,7 +595,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
validator func(string) bool
|
||||
|
||||
session sessions.SessionStore
|
||||
provider providers.MockProvider
|
||||
provider identity.MockProvider
|
||||
csrfStore sessions.MockCSRFStore
|
||||
|
||||
want string
|
||||
|
@ -610,7 +609,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -632,7 +631,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -655,7 +654,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -677,7 +676,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateError: errors.New("error"),
|
||||
},
|
||||
sessions.MockCSRFStore{
|
||||
|
@ -694,7 +693,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{SaveError: errors.New("error")},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -716,7 +715,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
falseValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -739,7 +738,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -761,7 +760,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -783,7 +782,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -805,7 +804,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -827,7 +826,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
identity.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
// Package providers authentication for third party identity providers (IdP) using OpenID
|
||||
// Connect, an identity layer on top of the OAuth 2.0 RFC6749 protocol.
|
||||
//
|
||||
// see: https://openid.net/specs/openid-connect-core-1_0.html
|
||||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
|
@ -1,80 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/circuit"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
const defaultGitlabProviderURL = "https://gitlab.com"
|
||||
|
||||
// GitlabProvider is an implementation of the Provider interface.
|
||||
type GitlabProvider struct {
|
||||
*IdentityProvider
|
||||
cb *circuit.Breaker
|
||||
}
|
||||
|
||||
// NewGitlabProvider returns a new Gitlab identity provider; defaults to the hosted version.
|
||||
//
|
||||
// Unlike other providers, `email` is not returned from the initial OIDC token. To retrieve email,
|
||||
// a secondary call must be made to the user's info endpoint. Unfortunately, email is not guaranteed
|
||||
// or even likely to be returned even if the user has it set as their email must be set to public.
|
||||
// As pomerium is currently very email centric, I would caution using until Gitlab fixes the issue.
|
||||
//
|
||||
// See :
|
||||
// - https://gitlab.com/gitlab-org/gitlab-ce/issues/44435#note_88150387
|
||||
// - https://docs.gitlab.com/ee/integration/openid_connect_provider.html
|
||||
// - https://docs.gitlab.com/ee/integration/oauth_provider.html
|
||||
// - https://docs.gitlab.com/ee/api/oauth2.html
|
||||
// - https://gitlab.com/.well-known/openid-configuration
|
||||
func NewGitlabProvider(p *IdentityProvider) (*GitlabProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
p.ProviderURL = defaultGitlabProviderURL
|
||||
}
|
||||
var err error
|
||||
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "read_user"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
gitlabProvider := &GitlabProvider{
|
||||
IdentityProvider: p,
|
||||
}
|
||||
gitlabProvider.cb = circuit.NewBreaker(&circuit.Options{
|
||||
HalfOpenConcurrentRequests: 2,
|
||||
OnStateChange: gitlabProvider.cbStateChange,
|
||||
OnBackoff: gitlabProvider.cbBackoff,
|
||||
ShouldTripFunc: func(c circuit.Counts) bool { return c.ConsecutiveFailures >= 3 },
|
||||
ShouldResetFunc: func(c circuit.Counts) bool { return c.ConsecutiveSuccesses >= 6 },
|
||||
BackoffDurationFunc: circuit.ExponentialBackoffDuration(
|
||||
time.Duration(200)*time.Second,
|
||||
time.Duration(500)*time.Millisecond),
|
||||
})
|
||||
|
||||
return gitlabProvider, nil
|
||||
}
|
||||
|
||||
func (p *GitlabProvider) cbBackoff(duration time.Duration, reset time.Time) {
|
||||
log.Info().Dur("duration", duration).Msg("authenticate/providers/gitlab.cbBackoff")
|
||||
|
||||
}
|
||||
|
||||
func (p *GitlabProvider) cbStateChange(from, to circuit.State) {
|
||||
log.Info().Str("from", from.String()).Str("to", to.String()).Msg("authenticate/providers/gitlab.cbStateChange")
|
||||
}
|
|
@ -1,110 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/circuit"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
const defaultGoogleProviderURL = "https://accounts.google.com"
|
||||
|
||||
// GoogleProvider is an implementation of the Provider interface.
|
||||
type GoogleProvider struct {
|
||||
*IdentityProvider
|
||||
cb *circuit.Breaker
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
}
|
||||
|
||||
// NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints.
|
||||
func NewGoogleProvider(p *IdentityProvider) (*GoogleProvider, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
if p.ProviderURL == "" {
|
||||
p.ProviderURL = defaultGoogleProviderURL
|
||||
}
|
||||
var err error
|
||||
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
|
||||
googleProvider := &GoogleProvider{
|
||||
IdentityProvider: p,
|
||||
}
|
||||
// google supports a revocation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
googleProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
googleProvider.cb = circuit.NewBreaker(&circuit.Options{
|
||||
HalfOpenConcurrentRequests: 2,
|
||||
OnStateChange: googleProvider.cbStateChange,
|
||||
OnBackoff: googleProvider.cbBackoff,
|
||||
ShouldTripFunc: func(c circuit.Counts) bool { return c.ConsecutiveFailures >= 3 },
|
||||
ShouldResetFunc: func(c circuit.Counts) bool { return c.ConsecutiveSuccesses >= 6 },
|
||||
BackoffDurationFunc: circuit.ExponentialBackoffDuration(
|
||||
time.Duration(200)*time.Second,
|
||||
time.Duration(500)*time.Millisecond),
|
||||
})
|
||||
|
||||
return googleProvider, nil
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) cbBackoff(duration time.Duration, reset time.Time) {
|
||||
log.Info().Dur("duration", duration).Msg("authenticate/providers/google.cbBackoff")
|
||||
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) cbStateChange(from, to circuit.State) {
|
||||
log.Info().Str("from", from.String()).Str("to", to.String()).Msg("authenticate/providers/google.cbStateChange")
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
//
|
||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#tokenrevoke
|
||||
// https://github.com/googleapis/google-api-dotnet-client/issues/1285
|
||||
func (p *GoogleProvider) Revoke(accessToken string) error {
|
||||
params := url.Values{}
|
||||
params.Add("token", accessToken)
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
// Google requires access type offline
|
||||
func (p *GoogleProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
|
||||
}
|
|
@ -1,112 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/circuit"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
// defaultAzureProviderURL Users with both a personal Microsoft
|
||||
// account and a work or school account from Azure Active Directory (Azure AD)
|
||||
// an sign in to the application.
|
||||
const defaultAzureProviderURL = "https://login.microsoftonline.com/common"
|
||||
|
||||
// AzureProvider is an implementation of the Provider interface
|
||||
type AzureProvider struct {
|
||||
*IdentityProvider
|
||||
cb *circuit.Breaker
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
}
|
||||
|
||||
// NewAzureProvider returns a new AzureProvider and sets the provider url endpoints.
|
||||
// If non-"common" tenant is desired, ProviderURL must be set.
|
||||
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-protocols-oidc
|
||||
func NewAzureProvider(p *IdentityProvider) (*AzureProvider, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
if p.ProviderURL == "" {
|
||||
p.ProviderURL = defaultAzureProviderURL
|
||||
}
|
||||
log.Info().Msgf("provider url %s", p.ProviderURL)
|
||||
var err error
|
||||
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
|
||||
azureProvider := &AzureProvider{
|
||||
IdentityProvider: p,
|
||||
}
|
||||
// azure has a "end session endpoint"
|
||||
var claims struct {
|
||||
RevokeURL string `json:"end_session_endpoint"`
|
||||
}
|
||||
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
azureProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
azureProvider.cb = circuit.NewBreaker(&circuit.Options{
|
||||
HalfOpenConcurrentRequests: 2,
|
||||
OnStateChange: azureProvider.cbStateChange,
|
||||
OnBackoff: azureProvider.cbBackoff,
|
||||
ShouldTripFunc: func(c circuit.Counts) bool { return c.ConsecutiveFailures >= 3 },
|
||||
ShouldResetFunc: func(c circuit.Counts) bool { return c.ConsecutiveSuccesses >= 6 },
|
||||
BackoffDurationFunc: circuit.ExponentialBackoffDuration(
|
||||
time.Duration(200)*time.Second,
|
||||
time.Duration(500)*time.Millisecond),
|
||||
})
|
||||
|
||||
return azureProvider, nil
|
||||
}
|
||||
|
||||
func (p *AzureProvider) cbBackoff(duration time.Duration, reset time.Time) {
|
||||
log.Info().Dur("duration", duration).Msg("authenticate/providers/azure.cbBackoff")
|
||||
|
||||
}
|
||||
|
||||
func (p *AzureProvider) cbStateChange(from, to circuit.State) {
|
||||
log.Info().Str("from", from.String()).Str("to", to.String()).Msg("authenticate/providers/azure.cbStateChange")
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
//https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-protocols-oidc#send-a-sign-out-request
|
||||
func (p *AzureProvider) Revoke(token string) error {
|
||||
params := url.Values{}
|
||||
params.Add("token", token)
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
func (p *AzureProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"github.com/pomerium/pomerium/internal/sessions" // type Provider interface {
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// MockProvider provides a mocked implementation of the providers interface.
|
||||
type MockProvider struct {
|
||||
AuthenticateResponse sessions.SessionState
|
||||
AuthenticateError error
|
||||
ValidateResponse bool
|
||||
ValidateError error
|
||||
RefreshResponse *oauth2.Token
|
||||
RefreshError error
|
||||
RevokeError error
|
||||
GetSignInURLResponse string
|
||||
}
|
||||
|
||||
// Authenticate is a mocked providers function.
|
||||
func (mp MockProvider) Authenticate(code string) (*sessions.SessionState, error) {
|
||||
return &mp.AuthenticateResponse, mp.AuthenticateError
|
||||
}
|
||||
|
||||
// Validate is a mocked providers function.
|
||||
func (mp MockProvider) Validate(s string) (bool, error) {
|
||||
return mp.ValidateResponse, mp.ValidateError
|
||||
}
|
||||
|
||||
// Refresh is a mocked providers function.
|
||||
func (mp MockProvider) Refresh(s string) (*oauth2.Token, error) {
|
||||
return mp.RefreshResponse, mp.RefreshError
|
||||
}
|
||||
|
||||
// Revoke is a mocked providers function.
|
||||
func (mp MockProvider) Revoke(s string) error {
|
||||
return mp.RevokeError
|
||||
}
|
||||
|
||||
// GetSignInURL is a mocked providers function.
|
||||
func (mp MockProvider) GetSignInURL(s string) string { return mp.GetSignInURLResponse }
|
|
@ -1,40 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OIDCProvider provides a standard, OpenID Connect implementation
|
||||
// of an authorization identity provider.
|
||||
// see : https://openid.net/specs/openid-connect-core-1_0.html
|
||||
type OIDCProvider struct {
|
||||
*IdentityProvider
|
||||
}
|
||||
|
||||
// NewOIDCProvider creates a new instance of an OpenID Connect provider.
|
||||
func NewOIDCProvider(p *IdentityProvider) (*OIDCProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
return nil, ErrMissingProviderURL
|
||||
}
|
||||
var err error
|
||||
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
return &OIDCProvider{IdentityProvider: p}, nil
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
// OktaProvider provides a standard, OpenID Connect implementation
|
||||
// of an authorization identity provider.
|
||||
type OktaProvider struct {
|
||||
*IdentityProvider
|
||||
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
}
|
||||
|
||||
// NewOktaProvider creates a new instance of an OpenID Connect provider.
|
||||
func NewOktaProvider(p *IdentityProvider) (*OktaProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
return nil, ErrMissingProviderURL
|
||||
}
|
||||
var err error
|
||||
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
|
||||
// okta supports a revocation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oktaProvider := OktaProvider{IdentityProvider: p}
|
||||
|
||||
oktaProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oktaProvider, nil
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
// https://developer.okta.com/docs/api/resources/oidc#revoke
|
||||
func (p *OktaProvider) Revoke(token string) error {
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("token", token)
|
||||
params.Add("token_type_hint", "refresh_token")
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
// Google requires access type offline
|
||||
func (p *OktaProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
// OneLoginProvider provides a standard, OpenID Connect implementation
|
||||
// of an authorization identity provider.
|
||||
type OneLoginProvider struct {
|
||||
*IdentityProvider
|
||||
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
}
|
||||
|
||||
const defaultOneLoginProviderURL = "https://openid-connect.onelogin.com/oidc"
|
||||
|
||||
// NewOneLoginProvider creates a new instance of an OpenID Connect provider.
|
||||
func NewOneLoginProvider(p *IdentityProvider) (*OneLoginProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
p.ProviderURL = defaultOneLoginProviderURL
|
||||
}
|
||||
var err error
|
||||
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
|
||||
// okta supports a revocation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
OneLoginProvider := OneLoginProvider{IdentityProvider: p}
|
||||
|
||||
OneLoginProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OneLoginProvider, nil
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
// https://developers.onelogin.com/openid-connect/api/revoke-session
|
||||
func (p *OneLoginProvider) Revoke(token string) error {
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("token", token)
|
||||
params.Add("token_type_hint", "access_token")
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
log.Error().Err(err).Msg("authenticate/providers: failed to revoke session")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
func (p *OneLoginProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
|
@ -1,192 +0,0 @@
|
|||
//go:generate protoc -I ../../proto/authenticate --go_out=plugins=grpc:../../proto/authenticate ../../proto/authenticate/authenticate.proto
|
||||
|
||||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
const (
|
||||
// AzureProviderName identifies the Azure identity provider
|
||||
AzureProviderName = "azure"
|
||||
// GitlabProviderName identifies the GitLab identity provider
|
||||
GitlabProviderName = "gitlab"
|
||||
// GoogleProviderName identifies the Google identity provider
|
||||
GoogleProviderName = "google"
|
||||
// OIDCProviderName identifies a generic OpenID connect provider
|
||||
OIDCProviderName = "oidc"
|
||||
// OktaProviderName identifies the Okta identity provider
|
||||
OktaProviderName = "okta"
|
||||
// OneLoginProviderName identifies the OneLogin identity provider
|
||||
OneLoginProviderName = "onelogin"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMissingProviderURL is returned when the CB state is half open and the requests count is over the cb maxRequests
|
||||
ErrMissingProviderURL = errors.New("proxy/providers: missing provider url")
|
||||
)
|
||||
|
||||
// Provider is an interface exposing functions necessary to interact with a given provider.
|
||||
type Provider interface {
|
||||
Authenticate(string) (*sessions.SessionState, error)
|
||||
Validate(string) (bool, error)
|
||||
Refresh(string) (*oauth2.Token, error)
|
||||
Revoke(string) error
|
||||
GetSignInURL(state string) string
|
||||
}
|
||||
|
||||
// New returns a new identity provider based given its name.
|
||||
// Returns an error if selected provided not found or if the identity provider is not known.
|
||||
func New(providerName string, pd *IdentityProvider) (p Provider, err error) {
|
||||
switch providerName {
|
||||
case AzureProviderName:
|
||||
p, err = NewAzureProvider(pd)
|
||||
case GitlabProviderName:
|
||||
p, err = NewGitlabProvider(pd)
|
||||
case GoogleProviderName:
|
||||
p, err = NewGoogleProvider(pd)
|
||||
case OIDCProviderName:
|
||||
p, err = NewOIDCProvider(pd)
|
||||
case OktaProviderName:
|
||||
p, err = NewOktaProvider(pd)
|
||||
case OneLoginProviderName:
|
||||
p, err = NewOneLoginProvider(pd)
|
||||
default:
|
||||
return nil, fmt.Errorf("authenticate: %q name not found", providerName)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// IdentityProvider contains the fields required for an OAuth 2.0 Authorization Request that
|
||||
// requests that the End-User be authenticated by the Authorization Server.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
type IdentityProvider struct {
|
||||
ProviderName string
|
||||
|
||||
RedirectURL *url.URL
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ProviderURL string
|
||||
Scopes []string
|
||||
SessionLifetimeTTL time.Duration
|
||||
|
||||
provider *oidc.Provider
|
||||
verifier *oidc.IDTokenVerifier
|
||||
oauth *oauth2.Config
|
||||
}
|
||||
|
||||
// GetSignInURL returns a URL to OAuth 2.0 provider's consent page
|
||||
// that asks for permissions for the required scopes explicitly.
|
||||
//
|
||||
// State is a token to protect the user from CSRF attacks. You must
|
||||
// always provide a non-empty string and validate that it matches the
|
||||
// the state query parameter on your redirect callback.
|
||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||
func (p *IdentityProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
// Validate validates a given session's from it's JWT token
|
||||
// The function verifies it's been signed by the provider, preforms
|
||||
// any additional checks depending on the Config, and returns the payload.
|
||||
//
|
||||
// Validate does NOT do nonce validation.
|
||||
// Validate does NOT check if revoked.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (p *IdentityProvider) Validate(idToken string) (bool, error) {
|
||||
ctx := context.Background()
|
||||
_, err := p.verifier.Verify(ctx, idToken)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers: failed to verify session state")
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Authenticate creates a session with an identity provider from a authorization code
|
||||
func (p *IdentityProvider) Authenticate(code string) (*sessions.SessionState, error) {
|
||||
ctx := context.Background()
|
||||
// convert authorization code into a token
|
||||
oauth2Token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: failed token exchange: %v", err)
|
||||
}
|
||||
log.Info().
|
||||
Str("RefreshToken", oauth2Token.RefreshToken).
|
||||
Str("TokenType", oauth2Token.TokenType).
|
||||
Str("AccessToken", oauth2Token.AccessToken).
|
||||
Msg("Authenticate - oauth.Exchange")
|
||||
|
||||
//id_token contains claims about the authenticated user
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
}
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: could not verify id_token: %v", err)
|
||||
}
|
||||
|
||||
// Extract id_token which contains claims about the authenticated user
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
// parse claims from the raw, encoded jwt token
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: failed to parse id_token claims: %v", err)
|
||||
}
|
||||
|
||||
return &sessions.SessionState{
|
||||
IDToken: rawIDToken,
|
||||
AccessToken: oauth2Token.AccessToken,
|
||||
RefreshToken: oauth2Token.RefreshToken,
|
||||
RefreshDeadline: oauth2Token.Expiry,
|
||||
LifetimeDeadline: sessions.ExtendDeadline(p.SessionLifetimeTTL),
|
||||
Email: claims.Email,
|
||||
User: idToken.Subject,
|
||||
Groups: claims.Groups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh renews a user's session using an access token without reprompting the user.
|
||||
func (p *IdentityProvider) Refresh(refreshToken string) (*oauth2.Token, error) {
|
||||
if refreshToken == "" {
|
||||
return nil, errors.New("authenticate/providers: missing refresh token")
|
||||
}
|
||||
t := oauth2.Token{RefreshToken: refreshToken}
|
||||
newToken, err := p.oauth.TokenSource(context.Background(), &t).Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers.Refresh")
|
||||
return nil, err
|
||||
}
|
||||
log.Info().
|
||||
Str("RefreshToken", refreshToken).
|
||||
Str("newToken.AccessToken", newToken.AccessToken).
|
||||
Str("time.Until(newToken.Expiry)", time.Until(newToken.Expiry).String()).
|
||||
Msg("authenticate/providers.Refresh")
|
||||
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
// Revoke enables a user to revoke her token. If the identity provider supports revocation
|
||||
// the endpoint is available, otherwise an error is thrown.
|
||||
func (p *IdentityProvider) Revoke(token string) error {
|
||||
return errors.New("authenticate/providers: revoke not implemented")
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue