mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
* chore(deps): bump github.com/golangci/golangci-lint Bumps [github.com/golangci/golangci-lint](https://github.com/golangci/golangci-lint) from 1.48.0 to 1.50.0. - [Release notes](https://github.com/golangci/golangci-lint/releases) - [Changelog](https://github.com/golangci/golangci-lint/blob/master/CHANGELOG.md) - [Commits](https://github.com/golangci/golangci-lint/compare/v1.48.0...v1.50.0) --- updated-dependencies: - dependency-name: github.com/golangci/golangci-lint dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * lint Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
356 lines
9.5 KiB
Go
356 lines
9.5 KiB
Go
package redisutil
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
|
|
"github.com/pomerium/pomerium/internal/sets"
|
|
)
|
|
|
|
var (
|
|
standardSchemes = sets.NewHash("redis", "rediss", "unix")
|
|
clusterSchemes = sets.NewHash(
|
|
"redis+cluster", "redis-cluster",
|
|
"rediss+cluster", "rediss-cluster",
|
|
"redis+clusters", "redis-clusters",
|
|
)
|
|
sentinelSchemes = sets.NewHash(
|
|
"redis+sentinel", "redis-sentinel",
|
|
"rediss+sentinel", "rediss-sentinel",
|
|
"redis+sentinels", "redis-sentinels",
|
|
)
|
|
sentinelClusterSchemes = sets.NewHash(
|
|
"redis+sentinel+cluster", "redis-sentinel-cluster",
|
|
"rediss+sentinel+cluster", "rediss-sentinel-cluster",
|
|
"redis+sentinels+cluster", "redis-sentinels-cluster",
|
|
"redis+sentinel+clusters", "redis-sentinel-clusters",
|
|
)
|
|
tlsSchemes = sets.NewHash(
|
|
"rediss",
|
|
"rediss+cluster", "rediss-cluster",
|
|
"redis+clusters", "redis-clusters",
|
|
"rediss+sentinel", "rediss-sentinel",
|
|
"redis+sentinels", "redis-sentinels",
|
|
"rediss+sentinel+cluster", "rediss-sentinel-cluster",
|
|
"redis+sentinels+cluster", "redis-sentinels-cluster",
|
|
"redis+sentinel+clusters", "redis-sentinel-clusters",
|
|
)
|
|
)
|
|
|
|
// NewClientFromURL creates a new redis client by parsing the raw URL.
|
|
func NewClientFromURL(rawURL string, tlsConfig *tls.Config) (redis.UniversalClient, error) {
|
|
u, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch {
|
|
case standardSchemes.Has(u.Scheme):
|
|
opts, err := redis.ParseURL(rawURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// when using TLS, the TLS config will not be set to nil, in which case we replace it with our own
|
|
if opts.TLSConfig != nil {
|
|
opts.TLSConfig = tlsConfig
|
|
}
|
|
return redis.NewClient(opts), nil
|
|
|
|
case clusterSchemes.Has(u.Scheme):
|
|
opts, err := ParseClusterURL(rawURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if opts.TLSConfig != nil {
|
|
opts.TLSConfig = tlsConfig
|
|
}
|
|
return redis.NewClusterClient(opts), nil
|
|
|
|
case sentinelSchemes.Has(u.Scheme):
|
|
opts, err := ParseSentinelURL(rawURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if opts.TLSConfig != nil {
|
|
opts.TLSConfig = tlsConfig
|
|
}
|
|
return redis.NewFailoverClient(opts), nil
|
|
|
|
case sentinelClusterSchemes.Has(u.Scheme):
|
|
opts, err := ParseSentinelURL(rawURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if opts.TLSConfig != nil {
|
|
opts.TLSConfig = tlsConfig
|
|
}
|
|
return redis.NewFailoverClusterClient(opts), nil
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported URL scheme: %s", u.Scheme)
|
|
}
|
|
}
|
|
|
|
// ParseClusterURL parses a redis-cluster URL. Format is:
|
|
//
|
|
// redis+cluster://[username:password@]host:port[,host2:port2,...]/[?param1=value1[¶m2=value=2&...]]
|
|
//
|
|
// Additionally TLS is supported with rediss+cluster, or redis+clusters. Supported query params:
|
|
//
|
|
// max_redirects: int
|
|
// read_only: bool
|
|
// route_by_latency: bool
|
|
// route_randomly: bool
|
|
// max_retries: int
|
|
// min_retry_backoff: duration
|
|
// max_retry_backoff: duration
|
|
// dial_timeout: duration
|
|
// read_timeout: duration
|
|
// write_timeout: duration
|
|
// pool_size: int
|
|
// min_idle_conns: int
|
|
// max_conn_age: duration
|
|
// pool_timeout: duration
|
|
// idle_timeout: duration
|
|
// idle_check_frequency: duration
|
|
func ParseClusterURL(rawurl string) (*redis.ClusterOptions, error) {
|
|
u, err := url.Parse(rawurl)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
opts := new(redis.ClusterOptions)
|
|
|
|
hostParts := strings.Split(u.Host, ",")
|
|
for _, hostPart := range hostParts {
|
|
host, port, err := net.SplitHostPort(hostPart)
|
|
if err != nil {
|
|
host = hostPart
|
|
port = "6379"
|
|
}
|
|
opts.Addrs = append(opts.Addrs,
|
|
net.JoinHostPort(host, port))
|
|
}
|
|
|
|
q := u.Query()
|
|
if err := parseIntParam(&opts.MaxRedirects, q, "max_redirects"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseBoolParam(&opts.ReadOnly, q, "read_only"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseBoolParam(&opts.RouteByLatency, q, "route_by_latency"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseBoolParam(&opts.RouteRandomly, q, "route_randomly"); err != nil {
|
|
return nil, err
|
|
}
|
|
if ui := u.User; ui != nil {
|
|
opts.Username = ui.Username()
|
|
opts.Password, _ = ui.Password()
|
|
}
|
|
if err := parseIntParam(&opts.MaxRetries, q, "max_retries"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.MinRetryBackoff, q, "min_retry_backoff"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.MaxRetryBackoff, q, "max_retry_backoff"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.DialTimeout, q, "dial_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.ReadTimeout, q, "read_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.WriteTimeout, q, "write_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseIntParam(&opts.PoolSize, q, "pool_size"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseIntParam(&opts.MinIdleConns, q, "min_idle_conns"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.MaxConnAge, q, "max_conn_age"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.PoolTimeout, q, "pool_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.IdleTimeout, q, "idle_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.IdleCheckFrequency, q, "idle_check_frequency"); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if tlsSchemes.Has(u.Scheme) {
|
|
opts.TLSConfig = &tls.Config{} //nolint
|
|
}
|
|
|
|
return opts, nil
|
|
}
|
|
|
|
// ParseSentinelURL parses a redis-sentinel URL. Format is based on https://github.com/exponea/redis-sentinel-url:
|
|
//
|
|
// redis+sentinel://[:password@]host:port[,host2:port2,...][/service_name[/db]][?param1=value1[¶m2=value=2&...]]
|
|
//
|
|
// Additionally TLS is supported with rediss+sentinel, or redis+sentinels. Supported query params:
|
|
//
|
|
// slave_only: bool
|
|
// use_disconnected_slaves: bool
|
|
// query_sentinel_randomly: bool
|
|
// username: string (username for redis connection)
|
|
// password: string (password for redis connection)
|
|
// max_retries: int
|
|
// min_retry_backoff: duration
|
|
// max_retry_backoff: duration
|
|
// dial_timeout: duration
|
|
// read_timeout: duration
|
|
// write_timeout: duration
|
|
// pool_size: int
|
|
// min_idle_conns: int
|
|
// max_conn_age: duration
|
|
// pool_timeout: duration
|
|
// idle_timeout: duration
|
|
// idle_check_frequency: duration
|
|
func ParseSentinelURL(rawurl string) (*redis.FailoverOptions, error) {
|
|
u, err := url.Parse(rawurl)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
opts := new(redis.FailoverOptions)
|
|
|
|
pathParts := strings.Split(u.Path, "/")
|
|
if len(pathParts) > 1 {
|
|
opts.MasterName = pathParts[1]
|
|
}
|
|
if len(pathParts) > 2 {
|
|
opts.DB, err = strconv.Atoi(pathParts[2])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid database: %w", err)
|
|
}
|
|
}
|
|
|
|
hostParts := strings.Split(u.Host, ",")
|
|
for _, hostPart := range hostParts {
|
|
host, port, err := net.SplitHostPort(hostPart)
|
|
if err != nil {
|
|
host = hostPart
|
|
port = "26379" // "By default Sentinel runs using TCP port 26379"
|
|
}
|
|
opts.SentinelAddrs = append(opts.SentinelAddrs,
|
|
net.JoinHostPort(host, port))
|
|
}
|
|
|
|
if u.User != nil {
|
|
opts.SentinelPassword, _ = u.User.Password()
|
|
}
|
|
|
|
q := u.Query()
|
|
if err := parseBoolParam(&opts.SlaveOnly, q, "slave_only"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseBoolParam(&opts.RouteByLatency, q, "route_by_latency"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseBoolParam(&opts.RouteRandomly, q, "route_randomly"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseBoolParam(&opts.UseDisconnectedSlaves, q, "use_disconnected_slaves"); err != nil {
|
|
return nil, err
|
|
}
|
|
opts.Username = q.Get("username")
|
|
opts.Password = q.Get("password")
|
|
if err := parseIntParam(&opts.MaxRetries, q, "max_retries"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.MinRetryBackoff, q, "min_retry_backoff"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.MaxRetryBackoff, q, "max_retry_backoff"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.DialTimeout, q, "dial_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.ReadTimeout, q, "read_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.WriteTimeout, q, "write_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseIntParam(&opts.PoolSize, q, "pool_size"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseIntParam(&opts.MinIdleConns, q, "min_idle_conns"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.MaxConnAge, q, "max_conn_age"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.PoolTimeout, q, "pool_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.IdleTimeout, q, "idle_timeout"); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := parseDurationParam(&opts.IdleCheckFrequency, q, "idle_check_frequency"); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if tlsSchemes.Has(u.Scheme) {
|
|
opts.TLSConfig = &tls.Config{} //nolint
|
|
}
|
|
|
|
return opts, nil
|
|
}
|
|
|
|
// parseBoolParam looks up the query parameter called name in values. If it is
// present and non-empty, its strconv.ParseBool interpretation is stored in
// *dst. An absent or empty parameter leaves *dst untouched and returns nil. A
// malformed value also leaves *dst untouched and returns an error of the form
// "invalid <name>: <cause>".
func parseBoolParam(dst *bool, values url.Values, name string) error {
	if raw := values.Get(name); raw != "" {
		parsed, parseErr := strconv.ParseBool(raw)
		if parseErr != nil {
			return fmt.Errorf("invalid %s: %w", name, parseErr)
		}
		*dst = parsed
	}
	return nil
}
|
|
|
|
// parseIntParam looks up the query parameter called name in values. If it is
// present and non-empty, its base-10 strconv.Atoi value is stored in *dst. An
// absent or empty parameter is a no-op. A malformed value leaves *dst
// untouched and returns an error of the form "invalid <name>: <cause>".
func parseIntParam(dst *int, values url.Values, name string) error {
	raw := values.Get(name)
	switch {
	case raw == "":
		return nil
	default:
		n, convErr := strconv.Atoi(raw)
		if convErr != nil {
			return fmt.Errorf("invalid %s: %w", name, convErr)
		}
		*dst = n
		return nil
	}
}
|
|
|
|
// parseDurationParam looks up the query parameter called name in values. If it
// is present and non-empty, it is parsed with time.ParseDuration (e.g. "1m30s")
// and stored in *dst. An absent or empty parameter is a no-op. A malformed
// value leaves *dst untouched and returns an error of the form
// "invalid <name>: <cause>".
func parseDurationParam(dst *time.Duration, values url.Values, name string) error {
	raw := values.Get(name)
	if raw != "" {
		dur, parseErr := time.ParseDuration(raw)
		if parseErr == nil {
			*dst = dur
			return nil
		}
		return fmt.Errorf("invalid %s: %w", name, parseErr)
	}
	return nil
}
|