mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-30 09:27:19 +02:00
authenticate: use gRPC for service endpoints (#39)
* authenticate: set cookie secure as default. * authenticate: remove single flight provider. * authenticate/providers: Rename “ProviderData” to “IdentityProvider” * authenticate/providers: Fixed an issue where scopes were not being overwritten * proxy/authenticate : http client code removed. * proxy: standardized session variable names between services. * docs: change basic docker-config to be an “all-in-one” example with no nginx load. * docs: nginx balanced docker compose example with intra-ingress settings. * license: attribution for adaptation of goji’s middleware pattern.
This commit is contained in:
parent
9ca3ff4fa2
commit
c886b924e7
54 changed files with 2184 additions and 1463 deletions
|
@ -18,18 +18,19 @@ import (
|
|||
)
|
||||
|
||||
var defaultOptions = &Options{
|
||||
CookieName: "_pomerium_authenticate",
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(1) * time.Hour,
|
||||
SessionLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
CookieName: "_pomerium_authenticate",
|
||||
CookieHTTPOnly: true,
|
||||
CookieSecure: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(30) * time.Minute,
|
||||
CookieLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
}
|
||||
|
||||
// Options permits the configuration of the authentication service
|
||||
// Options details the available configuration settings for the authenticate service
|
||||
type Options struct {
|
||||
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
|
||||
|
||||
// SharedKey is used to authenticate requests between services
|
||||
SharedKey string `envconfig:"SHARED_SECRET"`
|
||||
|
||||
// Coarse authorization based on user email domain
|
||||
|
@ -37,27 +38,25 @@ type Options struct {
|
|||
ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"`
|
||||
|
||||
// Session/Cookie management
|
||||
CookieName string
|
||||
CookieSecret string `envconfig:"COOKIE_SECRET"`
|
||||
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
|
||||
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
|
||||
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"`
|
||||
CookieSecure bool `envconfig:"COOKIE_SECURE"`
|
||||
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
|
||||
CookieName string
|
||||
CookieSecret string `envconfig:"COOKIE_SECRET"`
|
||||
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
|
||||
CookieSecure bool `envconfig:"COOKIE_SECURE"`
|
||||
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
|
||||
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
|
||||
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"`
|
||||
CookieLifetimeTTL time.Duration `envconfig:"COOKIE_LIFETIME"`
|
||||
|
||||
SessionLifetimeTTL time.Duration `envconfig:"SESSION_LIFETIME_TTL"`
|
||||
|
||||
// Authentication provider configuration variables as specified by RFC6749
|
||||
// IdentityProvider provider configuration variables as specified by RFC6749
|
||||
// See: https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
|
||||
ClientID string `envconfig:"IDP_CLIENT_ID"`
|
||||
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"`
|
||||
Provider string `envconfig:"IDP_PROVIDER"`
|
||||
ProviderURL string `envconfig:"IDP_PROVIDER_URL"`
|
||||
Scopes []string `envconfig:"IDP_SCOPE"`
|
||||
Scopes []string `envconfig:"IDP_SCOPES"`
|
||||
}
|
||||
|
||||
// OptionsFromEnvConfig builds the authentication service's configuration
|
||||
// options from provided environmental variables
|
||||
// OptionsFromEnvConfig builds the authenticate service's configuration environmental variables
|
||||
func OptionsFromEnvConfig() (*Options, error) {
|
||||
o := defaultOptions
|
||||
if err := envconfig.Process("", o); err != nil {
|
||||
|
@ -66,7 +65,7 @@ func OptionsFromEnvConfig() (*Options, error) {
|
|||
return o, nil
|
||||
}
|
||||
|
||||
// Validate checks to see if configuration values are valid for the authentication service.
|
||||
// Validate checks to see if configuration values are valid for the authenticate service.
|
||||
// The checks do not modify the internal state of the Option structure. Returns
|
||||
// on first error found.
|
||||
func (o *Options) Validate() error {
|
||||
|
@ -102,8 +101,7 @@ func (o *Options) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Authenticate is service for validating user authentication for proxied-requests
|
||||
// against third-party identity provider (IdP) services.
|
||||
// Authenticate validates a user's identity
|
||||
type Authenticate struct {
|
||||
RedirectURL *url.URL
|
||||
|
||||
|
@ -115,7 +113,7 @@ type Authenticate struct {
|
|||
|
||||
SharedKey string
|
||||
|
||||
SessionLifetimeTTL time.Duration
|
||||
CookieLifetimeTTL time.Duration
|
||||
|
||||
templates *template.Template
|
||||
csrfStore sessions.CSRFStore
|
||||
|
@ -125,7 +123,7 @@ type Authenticate struct {
|
|||
provider providers.Provider
|
||||
}
|
||||
|
||||
// New validates and creates a new authentication service from a configuration options.
|
||||
// New validates and creates a new authenticate service from a set of Options
|
||||
func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate, error) {
|
||||
if opts == nil {
|
||||
return nil, errors.New("options cannot be nil")
|
||||
|
@ -133,15 +131,9 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
|
|||
if err := opts.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedAuthCodeSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cipher, err := cryptutil.NewCipher([]byte(decodedAuthCodeSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
// checked by validate
|
||||
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
cipher, err := cryptutil.NewCipher([]byte(decodedCookieSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -183,25 +175,22 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func newProvider(opts *Options) (providers.Provider, error) {
|
||||
pd := &providers.ProviderData{
|
||||
pd := &providers.IdentityProvider{
|
||||
RedirectURL: opts.RedirectURL,
|
||||
ProviderName: opts.Provider,
|
||||
ProviderURL: opts.ProviderURL,
|
||||
ClientID: opts.ClientID,
|
||||
ClientSecret: opts.ClientSecret,
|
||||
SessionLifetimeTTL: opts.SessionLifetimeTTL,
|
||||
SessionLifetimeTTL: opts.CookieLifetimeTTL,
|
||||
Scopes: opts.Scopes,
|
||||
}
|
||||
np, err := providers.New(opts.Provider, pd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return providers.NewSingleFlightProvider(np), nil
|
||||
|
||||
return np, err
|
||||
}
|
||||
|
||||
func dotPrependDomains(d []string) []string {
|
||||
|
|
|
@ -8,23 +8,19 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
os.Clearenv()
|
||||
}
|
||||
|
||||
func testOptions() *Options {
|
||||
redirectURL, _ := url.Parse("https://example.com/oauth2/callback")
|
||||
return &Options{
|
||||
ProxyRootDomains: []string{"example.com"},
|
||||
AllowedDomains: []string{"example.com"},
|
||||
RedirectURL: redirectURL,
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieRefresh: time.Duration(1) * time.Hour,
|
||||
SessionLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
ProxyRootDomains: []string{"example.com"},
|
||||
AllowedDomains: []string{"example.com"},
|
||||
RedirectURL: redirectURL,
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieRefresh: time.Duration(1) * time.Hour,
|
||||
CookieLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,6 +77,8 @@ func TestOptions_Validate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestOptionsFromEnvConfig(t *testing.T) {
|
||||
os.Clearenv()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want *Options
|
||||
|
@ -91,7 +89,7 @@ func TestOptionsFromEnvConfig(t *testing.T) {
|
|||
{"good default, no env settings", defaultOptions, "", "", false},
|
||||
{"bad url", nil, "REDIRECT_URL", "%.rjlw", true},
|
||||
{"good duration", defaultOptions, "COOKIE_EXPIRE", "1m", false},
|
||||
{"bad duration", nil, "COOKIE_EXPIRE", "1sm", true},
|
||||
{"bad duration", nil, "COOKIE_REFRESH", "1sm", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -131,3 +129,65 @@ func Test_dotPrependDomains(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newProvider(t *testing.T) {
|
||||
redirectURL, _ := url.Parse("https://example.com/oauth3/callback")
|
||||
|
||||
goodOpts := &Options{
|
||||
RedirectURL: redirectURL,
|
||||
Provider: "google",
|
||||
ProviderURL: "",
|
||||
ClientID: "cllient-id",
|
||||
ClientSecret: "client-secret",
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", goodOpts, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := newProvider(tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("newProvider() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
// if !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("newProvider() = %v, want %v", got, tt.want)
|
||||
// }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
good := testOptions()
|
||||
good.Provider = "google"
|
||||
|
||||
badRedirectURL := testOptions()
|
||||
badRedirectURL.RedirectURL = nil
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *Options
|
||||
// want *Authenticate
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", good, false},
|
||||
{"empty opts", nil, true},
|
||||
{"fails to validate", badRedirectURL, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := New(tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
// if !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("New() = %v, want %v", got, tt.want)
|
||||
// }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
64
authenticate/grpc.go
Normal file
64
authenticate/grpc.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// Authenticate takes an encrypted code, and returns the authentication result.
|
||||
func (p *Authenticate) Authenticate(ctx context.Context, in *pb.AuthenticateRequest) (*pb.AuthenticateReply, error) {
|
||||
session, err := sessions.UnmarshalSession(in.Code, p.cipher)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/grpc: %v", err)
|
||||
}
|
||||
expiryTimestamp, err := ptypes.TimestampProto(session.RefreshDeadline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pb.AuthenticateReply{
|
||||
AccessToken: session.AccessToken,
|
||||
RefreshToken: session.RefreshToken,
|
||||
IdToken: session.IDToken,
|
||||
User: session.User,
|
||||
Email: session.Email,
|
||||
Expiry: expiryTimestamp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate locally validates a JWT id token; does NOT do nonce or revokation validation.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (p *Authenticate) Validate(ctx context.Context, in *pb.ValidateRequest) (*pb.ValidateReply, error) {
|
||||
isValid, err := p.provider.Validate(in.IdToken)
|
||||
if err != nil {
|
||||
return &pb.ValidateReply{IsValid: false}, err
|
||||
}
|
||||
return &pb.ValidateReply{IsValid: isValid}, nil
|
||||
}
|
||||
|
||||
// Refresh renews a user's session checks if the session has been revoked using an access token
|
||||
// without reprompting the user.
|
||||
func (p *Authenticate) Refresh(ctx context.Context, in *pb.RefreshRequest) (*pb.RefreshReply, error) {
|
||||
newToken, err := p.provider.Refresh(in.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiryTimestamp, err := ptypes.TimestampProto(newToken.Expiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info().
|
||||
Str("session.AccessToken", newToken.AccessToken).
|
||||
Msg("authenticate: grpc: refresh: ok")
|
||||
|
||||
return &pb.RefreshReply{
|
||||
AccessToken: newToken.AccessToken,
|
||||
Expiry: expiryTimestamp,
|
||||
}, nil
|
||||
|
||||
}
|
170
authenticate/grpc_test.go
Normal file
170
authenticate/grpc_test.go
Normal file
|
@ -0,0 +1,170 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||
|
||||
// TestProvider is a mock provider
|
||||
type testProvider struct{}
|
||||
|
||||
func (tp *testProvider) Authenticate(s string) (*sessions.SessionState, error) {
|
||||
return &sessions.SessionState{}, nil
|
||||
}
|
||||
|
||||
func (tp *testProvider) Revoke(s string) error { return nil }
|
||||
func (tp *testProvider) GetSignInURL(s string) string { return "/signin" }
|
||||
func (tp *testProvider) Refresh(s string) (*oauth2.Token, error) {
|
||||
if s == "error" {
|
||||
return nil, errors.New("failed refresh")
|
||||
}
|
||||
if s == "bad time" {
|
||||
return &oauth2.Token{AccessToken: "updated", Expiry: time.Time{}}, nil
|
||||
}
|
||||
return &oauth2.Token{AccessToken: "updated", Expiry: fixedDate}, nil
|
||||
}
|
||||
func (tp *testProvider) Validate(token string) (bool, error) {
|
||||
if token == "good" {
|
||||
return true, nil
|
||||
} else if token == "error" {
|
||||
return false, errors.New("error validating id token")
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func TestAuthenticate_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "example", false, false},
|
||||
{"error", "error", false, true},
|
||||
{"not error", "not error", false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tp := &testProvider{}
|
||||
p := &Authenticate{provider: tp}
|
||||
got, err := p.Validate(context.Background(), &pb.ValidateRequest{IdToken: tt.idToken})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authenticate.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got.IsValid, tt.want) {
|
||||
t.Errorf("Authenticate.Validate() = %v, want %v", got.IsValid, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_Refresh(t *testing.T) {
|
||||
fixedProtoTime, err := ptypes.TimestampProto(fixedDate)
|
||||
if err != nil {
|
||||
t.Fatal("failed to parse timestamp")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
want *pb.RefreshReply
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "refresh-token", &pb.RefreshReply{AccessToken: "updated", Expiry: fixedProtoTime}, false},
|
||||
{"test error", "error", nil, true},
|
||||
// {"test bad time", "bad time", nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tp := &testProvider{}
|
||||
p := &Authenticate{provider: tp}
|
||||
|
||||
got, err := p.Refresh(context.Background(), &pb.RefreshRequest{RefreshToken: tt.refreshToken})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authenticate.Refresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authenticate.Refresh() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_Authenticate(t *testing.T) {
|
||||
secret := cryptutil.GenerateKey()
|
||||
c, err := cryptutil.NewCipher([]byte(secret))
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
newSecret := cryptutil.GenerateKey()
|
||||
c2, err := cryptutil.NewCipher([]byte(newSecret))
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
lt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
|
||||
rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
|
||||
vt := time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC()
|
||||
vtProto, err := ptypes.TimestampProto(rt)
|
||||
if err != nil {
|
||||
t.Fatal("failed to parse timestamp")
|
||||
}
|
||||
|
||||
want := &sessions.SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
LifetimeDeadline: lt,
|
||||
RefreshDeadline: rt,
|
||||
ValidDeadline: vt,
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
goodReply := &pb.AuthenticateReply{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
Expiry: vtProto,
|
||||
Email: "user@domain.com",
|
||||
User: "user"}
|
||||
ciphertext, err := sessions.MarshalSession(want, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be encode session: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cipher cryptutil.Cipher
|
||||
code string
|
||||
want *pb.AuthenticateReply
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", c, ciphertext, goodReply, false},
|
||||
{"bad cipher", c2, ciphertext, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Authenticate{cipher: tt.cipher}
|
||||
got, err := p.Authenticate(context.Background(), &pb.AuthenticateRequest{Code: tt.code})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authenticate.Authenticate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authenticate.Authenticate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,7 +2,6 @@ package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
|||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -17,16 +16,19 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
// securityHeaders corresponds to HTTP response headers related to security.
|
||||
// https://www.owasp.org/index.php/OWASP_Secure_Headers_Project#tab=Headers
|
||||
var securityHeaders = map[string]string{
|
||||
"Strict-Transport-Security": "max-age=31536000",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
|
||||
"Referrer-Policy": "Same-origin",
|
||||
"Content-Security-Policy": "default-src 'none'; style-src 'self' " +
|
||||
"'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
|
||||
"Referrer-Policy": "Same-origin",
|
||||
}
|
||||
|
||||
// Handler returns the Http.Handlers for authentication, callback, and refresh
|
||||
// Handler returns the Http.Handlers for authenticate, callback, and refresh
|
||||
func (p *Authenticate) Handler() http.Handler {
|
||||
// set up our standard middlewares
|
||||
stdMiddleware := middleware.NewChain()
|
||||
|
@ -52,8 +54,6 @@ func (p *Authenticate) Handler() http.Handler {
|
|||
middleware.ValidateSignature(p.SharedKey),
|
||||
middleware.ValidateRedirectURI(p.ProxyRootDomains))
|
||||
|
||||
validateClientSecretMiddleware := stdMiddleware.Append(middleware.ValidateClientSecret(p.SharedKey))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/robots.txt", stdMiddleware.ThenFunc(p.RobotsTxt))
|
||||
// Identity Provider (IdP) callback endpoints and callbacks
|
||||
|
@ -61,11 +61,7 @@ func (p *Authenticate) Handler() http.Handler {
|
|||
mux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(p.OAuthCallback))
|
||||
// authenticate-server endpoints
|
||||
mux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(p.SignIn))
|
||||
mux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(p.SignOut)) // "GET", "POST"
|
||||
mux.Handle("/profile", validateClientSecretMiddleware.ThenFunc(p.GetProfile)) // GET
|
||||
mux.Handle("/validate", validateClientSecretMiddleware.ThenFunc(p.ValidateToken)) // GET
|
||||
mux.Handle("/redeem", validateClientSecretMiddleware.ThenFunc(p.Redeem)) // POST
|
||||
mux.Handle("/refresh", validateClientSecretMiddleware.ThenFunc(p.Refresh)) //POST
|
||||
mux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(p.SignOut)) // GET POST
|
||||
|
||||
return mux
|
||||
}
|
||||
|
@ -76,43 +72,15 @@ func (p *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
|||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||
}
|
||||
|
||||
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
|
||||
func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
|
||||
redirectURL := p.RedirectURL.ResolveReference(r.URL)
|
||||
|
||||
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
||||
t := struct {
|
||||
ProviderName string
|
||||
AllowedDomains []string
|
||||
Redirect string
|
||||
Destination string
|
||||
Version string
|
||||
}{
|
||||
ProviderName: p.provider.Data().ProviderName,
|
||||
AllowedDomains: p.AllowedDomains,
|
||||
Redirect: redirectURL.String(),
|
||||
Destination: destinationURL.Host,
|
||||
Version: version.FullVersion(),
|
||||
}
|
||||
log.FromRequest(r).Debug().
|
||||
Str("ProviderName", p.provider.Data().ProviderName).
|
||||
Str("Redirect", redirectURL.String()).
|
||||
Str("Destination", destinationURL.Host).
|
||||
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
||||
Msg("authenticate: SignInPage")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
p.templates.ExecuteTemplate(w, "sign_in.html", t)
|
||||
}
|
||||
|
||||
func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
|
||||
session, err := p.sessionStore.LoadSession(r)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate: failed to load session")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to load session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ensure sessions lifetime has not expired
|
||||
// if long-lived lifetime has expired, clear session
|
||||
if session.LifetimePeriodExpired() {
|
||||
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
|
@ -120,18 +88,14 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
}
|
||||
// check if session refresh period is up
|
||||
if session.RefreshPeriodExpired() {
|
||||
ok, err := p.provider.RefreshSessionIfNeeded(session)
|
||||
newToken, err := p.provider.Refresh(session.RefreshToken)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
log.FromRequest(r).Error().Msg("user unauthorized after refresh")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
// update refresh'd session in cookie
|
||||
session.AccessToken = newToken.AccessToken
|
||||
session.RefreshDeadline = newToken.Expiry
|
||||
err = p.sessionStore.SaveSession(w, r, session)
|
||||
if err != nil {
|
||||
// We refreshed the session successfully, but failed to save it.
|
||||
|
@ -143,9 +107,9 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
}
|
||||
} else {
|
||||
// The session has not exceeded it's lifetime or requires refresh
|
||||
ok := p.provider.ValidateSessionState(session)
|
||||
if !ok {
|
||||
log.FromRequest(r).Error().Msg("invalid session state")
|
||||
ok, err := p.provider.Validate(session.IDToken)
|
||||
if !ok || err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("invalid session state")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
|
@ -157,68 +121,53 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
}
|
||||
}
|
||||
|
||||
// authenticate really should not be in the business of authorization
|
||||
// todo(bdd) : remove when authorization module added
|
||||
if !p.Validator(session.Email) {
|
||||
log.FromRequest(r).Error().Msg("invalid email user")
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
log.Info().Msg("authenticate")
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
|
||||
// and if the user is not authenticated, it renders a sign in page.
|
||||
func (p *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||
// We attempt to authenticate the user. If they cannot be authenticated, we render a sign-in
|
||||
// page.
|
||||
//
|
||||
// If the user is authenticated, we redirect back to the proxy application
|
||||
// at the `redirect_uri`, with a temporary token.
|
||||
//
|
||||
// TODO: It is possible for a user to visit this page without a redirect destination.
|
||||
// Should we allow the user to authenticate? If not, what should be the proposed workflow?
|
||||
|
||||
session, err := p.authenticate(w, r)
|
||||
switch err {
|
||||
case nil:
|
||||
// User is authenticated, redirect back to proxy
|
||||
p.ProxyOAuthRedirect(w, r, session)
|
||||
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||
log.Debug().Err(err).Msg("authenticate.SignIn")
|
||||
log.Info().Err(err).Msg("authenticate.SignIn : expected failure")
|
||||
if err != http.ErrNoCookie {
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
}
|
||||
p.OAuthStart(w, r)
|
||||
|
||||
default:
|
||||
log.Error().Err(err).Msg("authenticate.SignIn")
|
||||
log.Error().Err(err).Msg("authenticate: unexpected sign in error")
|
||||
httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyOAuthRedirect redirects the user back to sso proxy's redirection endpoint.
|
||||
// ProxyOAuthRedirect redirects the user back to proxy's redirection endpoint.
|
||||
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
|
||||
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
|
||||
func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
|
||||
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
|
||||
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
|
||||
//
|
||||
// We redirect the user back to the proxy application's redirection endpoint; in the
|
||||
// sso proxy, this is the `/oauth/callback` endpoint.
|
||||
//
|
||||
// We must provide the proxy with a temporary authorization code via the `code` parameter,
|
||||
// which they can use to redeem an access token for subsequent API calls.
|
||||
//
|
||||
// We must also include the original `state` parameter received from the proxy application.
|
||||
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// original `state` parameter received from the proxy application.
|
||||
state := r.Form.Get("state")
|
||||
if state == "" {
|
||||
httputil.ErrorResponse(w, r, "no state parameter supplied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// redirect url of proxy-service
|
||||
redirectURI := r.Form.Get("redirect_uri")
|
||||
if redirectURI == "" {
|
||||
httputil.ErrorResponse(w, r, "no redirect_uri parameter supplied", http.StatusForbidden)
|
||||
|
@ -230,7 +179,7 @@ func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request
|
|||
httputil.ErrorResponse(w, r, "malformed redirect_uri parameter passed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// encrypt session state as json blob
|
||||
encrypted, err := sessions.MarshalSession(session, p.cipher)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||
|
@ -267,6 +216,7 @@ func (p *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
case nil:
|
||||
break
|
||||
case http.ErrNoCookie: // if there's no cookie in the session we can just redirect
|
||||
log.Error().Err(err).Msg("authenticate.SignOut : no cookie")
|
||||
http.Redirect(w, r, redirectURI, http.StatusFound)
|
||||
return
|
||||
default:
|
||||
|
@ -277,7 +227,7 @@ func (p *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
err = p.provider.Revoke(session)
|
||||
err = p.provider.Revoke(session.AccessToken)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate.SignOut : error revoking session")
|
||||
p.SignOutPage(w, r, "An error occurred during sign out. Please try again.")
|
||||
|
@ -299,10 +249,11 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
|
|||
|
||||
signature := r.Form.Get("sig")
|
||||
timestamp := r.Form.Get("ts")
|
||||
destinationURL, _ := url.Parse(redirectURI) //checked by middleware
|
||||
destinationURL, err := url.Parse(redirectURI)
|
||||
|
||||
// An error message indicates that an internal server error occurred
|
||||
if message != "" {
|
||||
if message != "" || err != nil {
|
||||
log.Error().Err(err).Msg("authenticate.SignOutPage")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
|
@ -326,8 +277,8 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
|
|||
p.templates.ExecuteTemplate(w, "sign_out.html", t)
|
||||
}
|
||||
|
||||
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
|
||||
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
|
||||
// OAuthStart starts the authenticate process by redirecting to the provider. It provides a
|
||||
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate.
|
||||
func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
|
||||
if err != nil {
|
||||
|
@ -339,12 +290,12 @@ func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
|||
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
||||
p.csrfStore.SetCSRF(w, r, nonce)
|
||||
|
||||
// confirm the redirect uri is from the root domain
|
||||
// verify redirect uri is from the root domain
|
||||
if !middleware.ValidRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// confirm proxy url is from the root domain
|
||||
// verify proxy url is from the root domain
|
||||
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
||||
if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
|
@ -359,51 +310,41 @@ func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// embed authenticate service's state as the base64'd nonce and authenticate callback url
|
||||
// concat base64'd nonce and authenticate url to make state
|
||||
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
|
||||
// build the provider sign in url
|
||||
signInURL := p.provider.GetSignInURL(state)
|
||||
|
||||
http.Redirect(w, r, signInURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func (p *Authenticate) redeemCode(host, code string) (*sessions.SessionState, error) {
|
||||
session, err := p.provider.Redeem(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// if session.Email == "" {
|
||||
// return nil, fmt.Errorf("no email included in session")
|
||||
// }
|
||||
|
||||
return session, nil
|
||||
|
||||
}
|
||||
|
||||
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
||||
func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
// finish the oauth cycle
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: bad form on oauth callback")
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||
}
|
||||
errorString := r.Form.Get("error")
|
||||
if errorString != "" {
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: provider returned error")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString}
|
||||
}
|
||||
code := r.Form.Get("code")
|
||||
if code == "" {
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: provider missing code")
|
||||
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
|
||||
}
|
||||
|
||||
session, err := p.redeemCode(r.Host, code)
|
||||
session, err := p.provider.Authenticate(code)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("error redeeming authentication code")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: error redeeming authenticate code")
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||
}
|
||||
|
||||
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("failed decoding state")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed decoding state")
|
||||
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Couldn't decode state"}
|
||||
}
|
||||
s := strings.SplitN(string(bytes), ":", 2)
|
||||
|
@ -414,11 +355,12 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
redirect := s[1]
|
||||
c, err := p.csrfStore.GetCSRF(r)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: bad csrf")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
|
||||
}
|
||||
p.csrfStore.ClearCSRF(w, r)
|
||||
if c.Value != nonce {
|
||||
log.FromRequest(r).Error().Err(err).Msg("CSRF token mismatch")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: csrf mismatch")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"}
|
||||
}
|
||||
|
||||
|
@ -427,13 +369,10 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
// Set cookie, or deny: validates the session email and group
|
||||
// - for p.Validator see validator.go#newValidatorImpl for more info
|
||||
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
|
||||
if !p.Validator(session.Email) {
|
||||
log.FromRequest(r).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "You don't have access"}
|
||||
}
|
||||
log.FromRequest(r).Info().Str("email", session.Email).Msg("authentication complete")
|
||||
err = p.sessionStore.SaveSession(w, r, session)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("internal error")
|
||||
|
@ -442,182 +381,22 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
return redirect, nil
|
||||
}
|
||||
|
||||
// OAuthCallback handles the callback from the provider, and returns an error response if there is an error.
|
||||
// If there is no error it will redirect to the redirect url.
|
||||
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
|
||||
// was an error. If successful, redirects back to the proxy-service via the redirect-url.
|
||||
func (p *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
redirect, err := p.getOAuthCallback(w, r)
|
||||
switch h := err.(type) {
|
||||
case nil:
|
||||
break
|
||||
case httputil.HTTPError:
|
||||
log.Error().Err(err).Msg("authenticate: oauth callback error")
|
||||
httputil.ErrorResponse(w, r, h.Message, h.Code)
|
||||
return
|
||||
default:
|
||||
log.Error().Err(err).Msg("authenticate.OAuthCallback")
|
||||
log.Error().Err(err).Msg("authenticate: unexpected oauth callback error")
|
||||
httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// redirect back to the proxy-service
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
}
|
||||
|
||||
// Redeem has a signed access token, and provides the user information associated with the access token.
|
||||
func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
||||
// The auth code is redeemed by the sso proxy for an access token, refresh token,
|
||||
// expiration, and email.
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := sessions.UnmarshalSession(r.Form.Get("code"), p.cipher)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to unmarshal session")
|
||||
http.Error(w, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("empty session")
|
||||
http.Error(w, fmt.Sprintf("empty session: %s", err.Error()), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
|
||||
log.FromRequest(r).Error().Msg("expired session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
http.Error(w, fmt.Sprintf("expired session"), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
response := struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Email string `json:"email"`
|
||||
}{
|
||||
AccessToken: session.AccessToken,
|
||||
RefreshToken: session.RefreshToken,
|
||||
IDToken: session.IDToken,
|
||||
ExpiresIn: int64(time.Until(session.RefreshDeadline).Seconds()),
|
||||
Email: session.Email,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("GAP-Auth", session.Email)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(jsonBytes)
|
||||
|
||||
}
|
||||
|
||||
// Refresh takes a refresh token and returns a new access token
|
||||
func (p *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := r.Form.Get("refresh_token")
|
||||
if refreshToken == "" {
|
||||
http.Error(w, "Bad Request: No Refresh Token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
|
||||
return
|
||||
}
|
||||
|
||||
response := struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}{
|
||||
AccessToken: accessToken,
|
||||
ExpiresIn: int64(expiresIn.Seconds()),
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(bytes)
|
||||
}
|
||||
|
||||
// GetProfile gets a list of groups of which a user is a member.
|
||||
func (p *Authenticate) GetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
// The sso proxy sends the user's email to this endpoint to get a list of Google groups that
|
||||
// the email is a member of. The proxy will compare these groups to the list of allowed
|
||||
// groups for the upstream service the user is trying to access.
|
||||
|
||||
email := r.FormValue("email")
|
||||
if email == "" {
|
||||
http.Error(w, "no email address included", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// groupsFormValue := r.FormValue("groups")
|
||||
// allowedGroups := []string{}
|
||||
// if groupsFormValue != "" {
|
||||
// allowedGroups = strings.Split(groupsFormValue, ",")
|
||||
// }
|
||||
|
||||
// groups, err := p.provider.ValidateGroupMembership(email, allowedGroups)
|
||||
// if err != nil {
|
||||
// log.Error().Err(err).Msg("authenticate.GetProfile : error retrieving groups")
|
||||
// httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
|
||||
// return
|
||||
// }
|
||||
|
||||
response := struct {
|
||||
Email string `json:"email"`
|
||||
}{
|
||||
Email: email,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("error marshaling response: %s", err.Error()), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("GAP-Auth", email)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(jsonBytes)
|
||||
}
|
||||
|
||||
// ValidateToken validates the X-Access-Token from the header and returns an error response
|
||||
// if it's invalid
|
||||
func (p *Authenticate) ValidateToken(w http.ResponseWriter, r *http.Request) {
|
||||
accessToken := r.Header.Get("X-Access-Token")
|
||||
idToken := r.Header.Get("X-Id-Token")
|
||||
|
||||
if accessToken == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if idToken == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ok := p.provider.ValidateSessionState(&sessions.SessionState{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
})
|
||||
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/providers"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
)
|
||||
|
||||
|
@ -19,7 +17,6 @@ func testAuthenticate() *Authenticate {
|
|||
auth.AllowedDomains = []string{"*"}
|
||||
auth.ProxyRootDomains = []string{"example.com"}
|
||||
auth.templates = templates.New()
|
||||
auth.provider = providers.NewTestProvider(auth.RedirectURL)
|
||||
return &auth
|
||||
}
|
||||
|
||||
|
@ -38,43 +35,5 @@ func TestAuthenticate_RobotsTxt(t *testing.T) {
|
|||
expected := fmt.Sprintf("User-agent: *\nDisallow: /")
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_SignInPage(t *testing.T) {
|
||||
auth := testAuthenticate()
|
||||
v := url.Values{}
|
||||
v.Set("request_uri", "this-is-a-test-uri")
|
||||
url := fmt.Sprintf("/signin?%s", v.Encode())
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
handler := http.HandlerFunc(auth.SignInPage)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||
}
|
||||
body := rr.Body.Bytes()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want bool
|
||||
}{
|
||||
{"provider name", auth.provider.Data().ProviderName, true},
|
||||
{"destination url", v.Encode(), true},
|
||||
{"shouldn't be found", "this string should not be in the body", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := bytes.Contains(body, []byte(tt.value)); got != tt.want {
|
||||
t.Errorf("handler body missing expected value %v", tt.value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
5
authenticate/providers/doc.go
Normal file
5
authenticate/providers/doc.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
// Package providers implements OpenID Connect client logic for the set of supported identity
|
||||
// providers.
|
||||
// OpenID Connect 1.0 is a simple identity layer on top of the OAuth 2.0 RFC6749 protocol.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html
|
||||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
|
@ -15,7 +15,7 @@ const defaultGitlabProviderURL = "https://gitlab.com"
|
|||
|
||||
// GitlabProvider is an implementation of the Provider interface.
|
||||
type GitlabProvider struct {
|
||||
*ProviderData
|
||||
*IdentityProvider
|
||||
cb *circuit.Breaker
|
||||
}
|
||||
|
||||
|
@ -32,7 +32,7 @@ type GitlabProvider struct {
|
|||
// - https://docs.gitlab.com/ee/integration/oauth_provider.html
|
||||
// - https://docs.gitlab.com/ee/api/oauth2.html
|
||||
// - https://gitlab.com/.well-known/openid-configuration
|
||||
func NewGitlabProvider(p *ProviderData) (*GitlabProvider, error) {
|
||||
func NewGitlabProvider(p *IdentityProvider) (*GitlabProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
p.ProviderURL = defaultGitlabProviderURL
|
||||
|
@ -42,8 +42,9 @@ func NewGitlabProvider(p *ProviderData) (*GitlabProvider, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "read_user"}
|
||||
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "read_user"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
|
@ -53,7 +54,7 @@ func NewGitlabProvider(p *ProviderData) (*GitlabProvider, error) {
|
|||
Scopes: p.Scopes,
|
||||
}
|
||||
gitlabProvider := &GitlabProvider{
|
||||
ProviderData: p,
|
||||
IdentityProvider: p,
|
||||
}
|
||||
gitlabProvider.cb = circuit.NewBreaker(&circuit.Options{
|
||||
HalfOpenConcurrentRequests: 2,
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"github.com/pomerium/pomerium/authenticate/circuit"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
|
@ -19,14 +18,14 @@ const defaultGoogleProviderURL = "https://accounts.google.com"
|
|||
|
||||
// GoogleProvider is an implementation of the Provider interface.
|
||||
type GoogleProvider struct {
|
||||
*ProviderData
|
||||
*IdentityProvider
|
||||
cb *circuit.Breaker
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
}
|
||||
|
||||
// NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints.
|
||||
func NewGoogleProvider(p *ProviderData) (*GoogleProvider, error) {
|
||||
func NewGoogleProvider(p *IdentityProvider) (*GoogleProvider, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
if p.ProviderURL == "" {
|
||||
|
@ -37,18 +36,20 @@ func NewGoogleProvider(p *ProviderData) (*GoogleProvider, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
|
||||
googleProvider := &GoogleProvider{
|
||||
ProviderData: p,
|
||||
IdentityProvider: p,
|
||||
}
|
||||
// google supports a revocation endpoint
|
||||
var claims struct {
|
||||
|
@ -91,9 +92,9 @@ func (p *GoogleProvider) cbStateChange(from, to circuit.State) {
|
|||
//
|
||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#tokenrevoke
|
||||
// https://github.com/googleapis/google-api-dotnet-client/issues/1285
|
||||
func (p *GoogleProvider) Revoke(s *sessions.SessionState) error {
|
||||
func (p *GoogleProvider) Revoke(accessToken string) error {
|
||||
params := url.Values{}
|
||||
params.Add("token", s.AccessToken)
|
||||
params.Add("token", accessToken)
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
|
@ -105,4 +106,5 @@ func (p *GoogleProvider) Revoke(s *sessions.SessionState) error {
|
|||
// Google requires access type offline
|
||||
func (p *GoogleProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
|
||||
}
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"github.com/pomerium/pomerium/authenticate/circuit"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
|
@ -22,7 +21,7 @@ const defaultAzureProviderURL = "https://login.microsoftonline.com/common"
|
|||
|
||||
// AzureProvider is an implementation of the Provider interface
|
||||
type AzureProvider struct {
|
||||
*ProviderData
|
||||
*IdentityProvider
|
||||
cb *circuit.Breaker
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
|
@ -31,7 +30,7 @@ type AzureProvider struct {
|
|||
// NewAzureProvider returns a new AzureProvider and sets the provider url endpoints.
|
||||
// If non-"common" tenant is desired, ProviderURL must be set.
|
||||
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-protocols-oidc
|
||||
func NewAzureProvider(p *ProviderData) (*AzureProvider, error) {
|
||||
func NewAzureProvider(p *IdentityProvider) (*AzureProvider, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
if p.ProviderURL == "" {
|
||||
|
@ -43,18 +42,20 @@ func NewAzureProvider(p *ProviderData) (*AzureProvider, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
|
||||
azureProvider := &AzureProvider{
|
||||
ProviderData: p,
|
||||
IdentityProvider: p,
|
||||
}
|
||||
// azure has a "end session endpoint"
|
||||
var claims struct {
|
||||
|
@ -95,9 +96,9 @@ func (p *AzureProvider) cbStateChange(from, to circuit.State) {
|
|||
|
||||
// Revoke revokes the access token a given session state.
|
||||
//https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-protocols-oidc#send-a-sign-out-request
|
||||
func (p *AzureProvider) Revoke(s *sessions.SessionState) error {
|
||||
func (p *AzureProvider) Revoke(token string) error {
|
||||
params := url.Values{}
|
||||
params.Add("token", s.AccessToken)
|
||||
params.Add("token", token)
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
|
|
|
@ -12,11 +12,11 @@ import (
|
|||
// of an authorization identity provider.
|
||||
// see : https://openid.net/specs/openid-connect-core-1_0.html
|
||||
type OIDCProvider struct {
|
||||
*ProviderData
|
||||
*IdentityProvider
|
||||
}
|
||||
|
||||
// NewOIDCProvider creates a new instance of an OpenID Connect provider.
|
||||
func NewOIDCProvider(p *ProviderData) (*OIDCProvider, error) {
|
||||
func NewOIDCProvider(p *IdentityProvider) (*OIDCProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
return nil, errors.New("missing required provider url")
|
||||
|
@ -26,13 +26,16 @@ func NewOIDCProvider(p *ProviderData) (*OIDCProvider, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
return &OIDCProvider{ProviderData: p}, nil
|
||||
return &OIDCProvider{IdentityProvider: p}, nil
|
||||
}
|
||||
|
|
|
@ -9,21 +9,20 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
// OktaProvider provides a standard, OpenID Connect implementation
|
||||
// of an authorization identity provider.
|
||||
type OktaProvider struct {
|
||||
*ProviderData
|
||||
*IdentityProvider
|
||||
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
}
|
||||
|
||||
// NewOktaProvider creates a new instance of an OpenID Connect provider.
|
||||
func NewOktaProvider(p *ProviderData) (*OktaProvider, error) {
|
||||
func NewOktaProvider(p *IdentityProvider) (*OktaProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
return nil, errors.New("missing required provider url")
|
||||
|
@ -33,24 +32,26 @@ func NewOktaProvider(p *ProviderData) (*OktaProvider, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Scopes) == 0 {
|
||||
p.Scopes = []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||
}
|
||||
p.verifier = p.provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
Scopes: p.Scopes,
|
||||
}
|
||||
oktaProvider := OktaProvider{ProviderData: p}
|
||||
|
||||
// okta supports a revocation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oktaProvider := OktaProvider{IdentityProvider: p}
|
||||
|
||||
oktaProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
|
@ -61,11 +62,11 @@ func NewOktaProvider(p *ProviderData) (*OktaProvider, error) {
|
|||
|
||||
// Revoke revokes the access token a given session state.
|
||||
// https://developer.okta.com/docs/api/resources/oidc#revoke
|
||||
func (p *OktaProvider) Revoke(s *sessions.SessionState) error {
|
||||
func (p *OktaProvider) Revoke(token string) error {
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("token", s.IDToken)
|
||||
params.Add("token", token)
|
||||
params.Add("token_type_hint", "refresh_token")
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
|
@ -73,3 +74,9 @@ func (p *OktaProvider) Revoke(s *sessions.SessionState) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
// Google requires access type offline
|
||||
func (p *OktaProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
//go:generate protoc -I ../../proto/authenticate --go_out=plugins=grpc:../../proto/authenticate ../../proto/authenticate/authenticate.proto
|
||||
|
||||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
|
@ -32,21 +31,17 @@ const (
|
|||
|
||||
// Provider is an interface exposing functions necessary to interact with a given provider.
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
Redeem(string) (*sessions.SessionState, error)
|
||||
ValidateSessionState(*sessions.SessionState) bool
|
||||
Authenticate(string) (*sessions.SessionState, error)
|
||||
Validate(string) (bool, error)
|
||||
Refresh(string) (*oauth2.Token, error)
|
||||
Revoke(string) error
|
||||
GetSignInURL(state string) string
|
||||
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
|
||||
Revoke(*sessions.SessionState) error
|
||||
RefreshAccessToken(string) (string, time.Duration, error)
|
||||
}
|
||||
|
||||
// New returns a new identity provider based given its name.
|
||||
// Returns an error if selected provided not found or if the provider fails to instantiate.
|
||||
func New(provider string, pd *ProviderData) (Provider, error) {
|
||||
var err error
|
||||
var p Provider
|
||||
switch provider {
|
||||
// Returns an error if selected provided not found or if the identity provider is not known.
|
||||
func New(providerName string, pd *IdentityProvider) (p Provider, err error) {
|
||||
switch providerName {
|
||||
case AzureProviderName:
|
||||
p, err = NewAzureProvider(pd)
|
||||
case GitlabProviderName:
|
||||
|
@ -58,7 +53,7 @@ func New(provider string, pd *ProviderData) (Provider, error) {
|
|||
case OktaProviderName:
|
||||
p, err = NewOktaProvider(pd)
|
||||
default:
|
||||
return nil, fmt.Errorf("authenticate: provider %q not found", provider)
|
||||
return nil, fmt.Errorf("authenticate: %q name not found", providerName)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -66,11 +61,13 @@ func New(provider string, pd *ProviderData) (Provider, error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
// ProviderData holds the fields associated with providers
|
||||
// necessary to implement the Provider interface.
|
||||
type ProviderData struct {
|
||||
// IdentityProvider contains the fields required for an OAuth 2.0 Authorization Request that
|
||||
// requests that the End-User be authenticated by the Authorization Server.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
type IdentityProvider struct {
|
||||
ProviderName string
|
||||
|
||||
RedirectURL *url.URL
|
||||
ProviderName string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ProviderURL string
|
||||
|
@ -82,118 +79,50 @@ type ProviderData struct {
|
|||
oauth *oauth2.Config
|
||||
}
|
||||
|
||||
// Data returns a ProviderData.
|
||||
func (p *ProviderData) Data() *ProviderData { return p }
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
func (p *ProviderData) GetSignInURL(state string) string {
|
||||
// GetSignInURL returns a URL to OAuth 2.0 provider's consent page
|
||||
// that asks for permissions for the required scopes explicitly.
|
||||
//
|
||||
// State is a token to protect the user from CSRF attacks. You must
|
||||
// always provide a non-empty string and validate that it matches the
|
||||
// the state query parameter on your redirect callback.
|
||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||
func (p *IdentityProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
// ValidateSessionState validates a given session's from it's JWT token
|
||||
// Validate validates a given session's from it's JWT token
|
||||
// The function verifies it's been signed by the provider, preforms
|
||||
// any additional checks depending on the Config, and returns the payload.
|
||||
//
|
||||
// ValidateSessionState does NOT do nonce validation.
|
||||
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
// Validate does NOT do nonce validation.
|
||||
// Validate does NOT check if revoked.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (p *IdentityProvider) Validate(idToken string) (bool, error) {
|
||||
ctx := context.Background()
|
||||
_, err := p.verifier.Verify(ctx, s.IDToken)
|
||||
_, err := p.verifier.Verify(ctx, idToken)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers: failed to verify session state")
|
||||
return false
|
||||
return false, err
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Redeem creates a session with an identity provider from a authorization code
|
||||
func (p *ProviderData) Redeem(code string) (*sessions.SessionState, error) {
|
||||
ctx := context.Background()
|
||||
// convert authorization code into a token
|
||||
token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: failed token exchange: %v", err)
|
||||
}
|
||||
s, err := p.createSessionState(ctx, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: unable to update session: %v", err)
|
||||
}
|
||||
|
||||
// check if provider has info endpoint, try to hit that and gather more info
|
||||
// especially useful if initial request did not contain email
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
var claims struct {
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
}
|
||||
|
||||
if err := p.provider.Claims(&claims); err != nil || claims.UserInfoURL == "" {
|
||||
log.Error().Err(err).Msg("authenticate/providers: failed retrieving userinfo_endpoint")
|
||||
} else {
|
||||
// userinfo endpoint found and valid
|
||||
userInfo, err := p.UserInfo(ctx, claims.UserInfoURL, oauth2.StaticTokenSource(token))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: can't parse userinfo_endpoint: %v", err)
|
||||
}
|
||||
s.Email = userInfo.Email
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded will refresh the session state if it's deadline is expired
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
if !sessionRefreshRequired(s) {
|
||||
log.Debug().Msg("authenticate/providers: session refresh not needed")
|
||||
return false, nil
|
||||
}
|
||||
origExpiration := s.RefreshDeadline
|
||||
err := p.redeemRefreshToken(s)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("authenticate/providers: couldn't refresh token: %v", err)
|
||||
}
|
||||
|
||||
log.Debug().Time("NewDeadline", s.RefreshDeadline).Time("OldDeadline", origExpiration).Msgf("authenticate/providers refreshed")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *ProviderData) redeemRefreshToken(s *sessions.SessionState) error {
|
||||
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 1")
|
||||
// Authenticate creates a session with an identity provider from a authorization code
|
||||
func (p *IdentityProvider) Authenticate(code string) (*sessions.SessionState, error) {
|
||||
ctx := context.Background()
|
||||
t := &oauth2.Token{
|
||||
RefreshToken: s.RefreshToken,
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 3")
|
||||
|
||||
// returns a TokenSource automatically refreshing it as necessary using the provided context
|
||||
token, err := p.oauth.TokenSource(ctx, t).Token()
|
||||
// convert authorization code into a token
|
||||
oauth2Token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate/providers: failed to get token: %v", err)
|
||||
return nil, fmt.Errorf("authenticate/providers: failed token exchange: %v", err)
|
||||
}
|
||||
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 4")
|
||||
|
||||
newSession, err := p.createSessionState(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate/providers: unable to update session: %v", err)
|
||||
}
|
||||
s.AccessToken = newSession.AccessToken
|
||||
s.IDToken = newSession.IDToken
|
||||
s.RefreshToken = newSession.RefreshToken
|
||||
s.RefreshDeadline = newSession.RefreshDeadline
|
||||
s.Email = newSession.Email
|
||||
|
||||
log.Info().
|
||||
Str("AccessToken", s.AccessToken).
|
||||
Str("IdToken", s.IDToken).
|
||||
Time("RefreshDeadline", s.RefreshDeadline).
|
||||
Str("RefreshToken", s.RefreshToken).
|
||||
Str("Email", s.Email).
|
||||
Msg("authenticate/providers.redeemRefreshToken")
|
||||
Str("RefreshToken", oauth2Token.RefreshToken).
|
||||
Str("TokenType", oauth2Token.TokenType).
|
||||
Str("AccessToken", oauth2Token.AccessToken).
|
||||
Msg("Authenticate - oauth.Exchange")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProviderData) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
//id_token contains claims about the authenticated user
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
}
|
||||
|
@ -204,142 +133,51 @@ func (p *ProviderData) createSessionState(ctx context.Context, token *oauth2.Tok
|
|||
return nil, fmt.Errorf("authenticate/providers: could not verify id_token: %v", err)
|
||||
}
|
||||
|
||||
// Extract custom claims.
|
||||
// Extract id_token which contains claims about the authenticated user
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Verified *bool `json:"email_verified"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
// parse claims from the raw, encoded jwt token
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: failed to parse id_token claims: %v", err)
|
||||
}
|
||||
log.Debug().
|
||||
Str("AccessToken", token.AccessToken).
|
||||
Str("IDToken", rawIDToken).
|
||||
Str("claims.Email", claims.Email).
|
||||
Str("RefreshToken", token.RefreshToken).
|
||||
Str("idToken.Subject", idToken.Subject).
|
||||
Str("idToken.Nonce", idToken.Nonce).
|
||||
Str("RefreshDeadline", idToken.Expiry.String()).
|
||||
Str("LifetimeDeadline", idToken.Expiry.String()).
|
||||
Msg("authenticate/providers.createSessionState")
|
||||
|
||||
return &sessions.SessionState{
|
||||
AccessToken: token.AccessToken,
|
||||
IDToken: rawIDToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
RefreshDeadline: idToken.Expiry,
|
||||
LifetimeDeadline: idToken.Expiry,
|
||||
AccessToken: oauth2Token.AccessToken,
|
||||
RefreshToken: oauth2Token.RefreshToken,
|
||||
RefreshDeadline: oauth2Token.Expiry,
|
||||
LifetimeDeadline: sessions.ExtendDeadline(p.SessionLifetimeTTL),
|
||||
Email: claims.Email,
|
||||
User: idToken.Subject,
|
||||
Groups: claims.Groups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshAccessToken allows the service to refresh an access token without
|
||||
// prompting the user for permission.
|
||||
func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
|
||||
// Refresh renews a user's session using an access token without reprompting the user.
|
||||
func (p *IdentityProvider) Refresh(refreshToken string) (*oauth2.Token, error) {
|
||||
if refreshToken == "" {
|
||||
return "", 0, errors.New("authenticate/providers: missing refresh token")
|
||||
}
|
||||
ctx := context.Background()
|
||||
c := oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{TokenURL: p.ProviderURL},
|
||||
return nil, errors.New("authenticate/providers: missing refresh token")
|
||||
}
|
||||
t := oauth2.Token{RefreshToken: refreshToken}
|
||||
ts := c.TokenSource(ctx, &t)
|
||||
newToken, err := p.oauth.TokenSource(context.Background(), &t).Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers.Refresh")
|
||||
return nil, err
|
||||
}
|
||||
log.Info().
|
||||
Str("RefreshToken", refreshToken).
|
||||
Msg("authenticate/providers.RefreshAccessToken")
|
||||
Str("newToken.AccessToken", newToken.AccessToken).
|
||||
Str("time.Until(newToken.Expiry)", time.Until(newToken.Expiry).String()).
|
||||
Msg("authenticate/providers.Refresh")
|
||||
|
||||
newToken, err := ts.Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken")
|
||||
return "", 0, err
|
||||
}
|
||||
return newToken.AccessToken, time.Until(newToken.Expiry), nil
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
// Revoke enables a user to revoke her token. If the identity provider supports revocation
|
||||
// the endpoint is available, otherwise an error is thrown.
|
||||
func (p *ProviderData) Revoke(s *sessions.SessionState) error {
|
||||
func (p *IdentityProvider) Revoke(token string) error {
|
||||
return errors.New("authenticate/providers: revoke not implemented")
|
||||
}
|
||||
|
||||
func sessionRefreshRequired(s *sessions.SessionState) bool {
|
||||
return s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == ""
|
||||
}
|
||||
|
||||
// UserInfo represents the OpenID Connect userinfo claims.
|
||||
// see: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
type UserInfo struct {
|
||||
// Stanard OIDC User fields
|
||||
Subject string `json:"sub"`
|
||||
Profile string `json:"profile"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
// custom claims
|
||||
Name string `json:"name"` // google, gitlab
|
||||
GivenName string `json:"given_name"` // google
|
||||
FamilyName string `json:"family_name"` // google
|
||||
Picture string `json:"picture"` // google,gitlab
|
||||
Locale string `json:"locale"` // google
|
||||
Groups []string `json:"groups"` // gitlab
|
||||
|
||||
claims []byte
|
||||
}
|
||||
|
||||
// Claims unmarshals the raw JSON object claims into the provided object.
|
||||
func (u *UserInfo) Claims(v interface{}) error {
|
||||
if u.claims == nil {
|
||||
return errors.New("authenticate/providers: claims not set")
|
||||
}
|
||||
return json.Unmarshal(u.claims, v)
|
||||
}
|
||||
|
||||
// UserInfo uses the token source to query the provider's user info endpoint.
|
||||
func (p *ProviderData) UserInfo(ctx context.Context, uri string, tokenSource oauth2.TokenSource) (*UserInfo, error) {
|
||||
if uri == "" {
|
||||
return nil, errors.New("authenticate/providers: user info endpoint is not supported by this provider")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, uri, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: create GET request: %v", err)
|
||||
}
|
||||
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers: get access token: %v", err)
|
||||
}
|
||||
token.SetAuthHeader(req)
|
||||
|
||||
resp, err := doRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Status, body)
|
||||
}
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("authenticate/providers failed to decode userinfo: %v", err)
|
||||
}
|
||||
userInfo.claims = body
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
client := http.DefaultClient
|
||||
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
||||
client = c
|
||||
}
|
||||
return client.Do(req.WithContext(ctx))
|
||||
}
|
||||
|
|
|
@ -1,142 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Provider = &SingleFlightProvider{}
|
||||
)
|
||||
|
||||
// ErrUnexpectedReturnType is an error for an unexpected return type
|
||||
var (
|
||||
ErrUnexpectedReturnType = errors.New("received unexpected return type from single flight func call")
|
||||
)
|
||||
|
||||
// SingleFlightProvider middleware provider that multiple requests for the same object
|
||||
// to be processed as a single request. This is often called request collapsing or coalesce.
|
||||
// This middleware leverages the golang singlelflight provider, with modifications for metrics.
|
||||
//
|
||||
// It's common among HTTP reverse proxy cache servers such as nginx, Squid or Varnish - they all call it something else but works similarly.
|
||||
//
|
||||
// * https://www.varnish-cache.org/docs/3.0/tutorial/handling_misbehaving_servers.html
|
||||
// * http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_cache_lock
|
||||
// * http://wiki.squid-cache.org/Features/CollapsedForwarding
|
||||
type SingleFlightProvider struct {
|
||||
provider Provider
|
||||
|
||||
single *singleflight.Group
|
||||
}
|
||||
|
||||
// NewSingleFlightProvider returns a new SingleFlightProvider
|
||||
func NewSingleFlightProvider(provider Provider) *SingleFlightProvider {
|
||||
return &SingleFlightProvider{
|
||||
provider: provider,
|
||||
single: &singleflight.Group{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SingleFlightProvider) do(endpoint, key string, fn func() (interface{}, error)) (interface{}, error) {
|
||||
compositeKey := fmt.Sprintf("%s/%s", endpoint, key)
|
||||
resp, _, err := p.single.Do(compositeKey, fn)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Data returns the provider data
|
||||
func (p *SingleFlightProvider) Data() *ProviderData {
|
||||
return p.provider.Data()
|
||||
}
|
||||
|
||||
// Redeem wraps the provider's Redeem function.
|
||||
func (p *SingleFlightProvider) Redeem(code string) (*sessions.SessionState, error) {
|
||||
return p.provider.Redeem(code)
|
||||
}
|
||||
|
||||
// ValidateSessionState wraps the provider's ValidateSessionState in a single flight call.
|
||||
func (p *SingleFlightProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
response, err := p.do("ValidateSessionState", s.AccessToken, func() (interface{}, error) {
|
||||
valid := p.provider.ValidateSessionState(s)
|
||||
return valid, nil
|
||||
})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
valid, ok := response.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return valid
|
||||
}
|
||||
|
||||
// GetSignInURL calls the provider's GetSignInURL function.
|
||||
func (p *SingleFlightProvider) GetSignInURL(finalRedirect string) string {
|
||||
return p.provider.GetSignInURL(finalRedirect)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded wraps the provider's RefreshSessionIfNeeded function in a single flight
|
||||
// call.
|
||||
func (p *SingleFlightProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
response, err := p.do("RefreshSessionIfNeeded", s.RefreshToken, func() (interface{}, error) {
|
||||
return p.provider.RefreshSessionIfNeeded(s)
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
r, ok := response.(bool)
|
||||
if !ok {
|
||||
return false, ErrUnexpectedReturnType
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Revoke wraps the provider's Revoke function in a single flight call.
|
||||
func (p *SingleFlightProvider) Revoke(s *sessions.SessionState) error {
|
||||
_, err := p.do("Revoke", s.AccessToken, func() (interface{}, error) {
|
||||
err := p.provider.Revoke(s)
|
||||
return nil, err
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// RefreshAccessToken wraps the provider's RefreshAccessToken function in a single flight call.
|
||||
func (p *SingleFlightProvider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
|
||||
type Response struct {
|
||||
AccessToken string
|
||||
ExpiresIn time.Duration
|
||||
}
|
||||
response, err := p.do("RefreshAccessToken", refreshToken, func() (interface{}, error) {
|
||||
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Response{
|
||||
AccessToken: accessToken,
|
||||
ExpiresIn: expiresIn,
|
||||
}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
r, ok := response.(*Response)
|
||||
if !ok {
|
||||
return "", 0, ErrUnexpectedReturnType
|
||||
}
|
||||
|
||||
return r.AccessToken, r.ExpiresIn, nil
|
||||
}
|
||||
|
||||
// // Stop calls the provider's stop function
|
||||
// func (p *SingleFlightProvider) Stop() {
|
||||
// p.provider.Stop()
|
||||
// }
|
|
@ -1,77 +0,0 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
// TestProvider is a test implementation of the Provider interface.
|
||||
type TestProvider struct {
|
||||
*ProviderData
|
||||
|
||||
ValidToken bool
|
||||
ValidGroup bool
|
||||
SignInURL string
|
||||
Refresh bool
|
||||
RefreshFunc func(string) (string, time.Duration, error)
|
||||
RefreshError error
|
||||
Session *sessions.SessionState
|
||||
RedeemError error
|
||||
RevokeError error
|
||||
Groups []string
|
||||
GroupsError error
|
||||
GroupsCall int
|
||||
}
|
||||
|
||||
// NewTestProvider creates a new mock test provider.
|
||||
func NewTestProvider(providerURL *url.URL) *TestProvider {
|
||||
host := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: providerURL.Host,
|
||||
Path: "/authorize",
|
||||
}
|
||||
return &TestProvider{
|
||||
ProviderData: &ProviderData{
|
||||
ProviderName: "Test Provider",
|
||||
ProviderURL: host.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateSessionState returns the mock provider's ValidToken field value.
|
||||
func (tp *TestProvider) ValidateSessionState(*sessions.SessionState) bool {
|
||||
return tp.ValidToken
|
||||
}
|
||||
|
||||
// GetSignInURL returns the mock provider's SignInURL field value.
|
||||
func (tp *TestProvider) GetSignInURL(finalRedirect string) string {
|
||||
return tp.SignInURL
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded returns the mock provider's Refresh value, or an error.
|
||||
func (tp *TestProvider) RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) {
|
||||
return tp.Refresh, tp.RefreshError
|
||||
}
|
||||
|
||||
// RefreshAccessToken returns the mock provider's refresh access token information
|
||||
func (tp *TestProvider) RefreshAccessToken(s string) (string, time.Duration, error) {
|
||||
return tp.RefreshFunc(s)
|
||||
}
|
||||
|
||||
// Revoke returns nil
|
||||
func (tp *TestProvider) Revoke(*sessions.SessionState) error {
|
||||
return tp.RevokeError
|
||||
}
|
||||
|
||||
// ValidateGroupMembership returns the mock provider's GroupsError if not nil, or the Groups field value.
|
||||
func (tp *TestProvider) ValidateGroupMembership(string, []string) ([]string, error) {
|
||||
return tp.Groups, tp.GroupsError
|
||||
}
|
||||
|
||||
// Redeem returns the mock provider's Session and RedeemError field value.
|
||||
func (tp *TestProvider) Redeem(code string) (*sessions.SessionState, error) {
|
||||
return tp.Session, tp.RedeemError
|
||||
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue