package redisutil

import (
	"crypto/tls"
	"fmt"
	"net"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/go-redis/redis/v8"
	"github.com/scylladb/go-set"
)

// URL scheme sets. NewClientFromURL looks the URL scheme up in these sets to
// decide which kind of client to build, and the URL parsers in this file use
// tlsSchemes to decide whether to attach a TLS configuration.
var (
	// standardSchemes are the schemes handed to redis.ParseURL
	// (single-node client).
	standardSchemes = set.NewStringSet("redis", "rediss", "unix")
	// clusterSchemes are the schemes parsed by ParseClusterURL
	// (cluster client).
	clusterSchemes = set.NewStringSet(
		"redis+cluster", "redis-cluster",
		"rediss+cluster", "rediss-cluster",
		"redis+clusters", "redis-clusters",
	)
	// sentinelSchemes are the schemes parsed by ParseSentinelURL and used to
	// build a failover client.
	sentinelSchemes = set.NewStringSet(
		"redis+sentinel", "redis-sentinel",
		"rediss+sentinel", "rediss-sentinel",
		"redis+sentinels", "redis-sentinels",
	)
	// sentinelClusterSchemes are the schemes parsed by ParseSentinelURL and
	// used to build a failover cluster client.
	sentinelClusterSchemes = set.NewStringSet(
		"redis+sentinel+cluster", "redis-sentinel-cluster",
		"rediss+sentinel+cluster", "rediss-sentinel-cluster",
		"redis+sentinels+cluster", "redis-sentinels-cluster",
		"redis+sentinel+clusters", "redis-sentinel-clusters",
	)
	// tlsSchemes is every scheme above that requests TLS: either "rediss" or a
	// variant whose name ends in a pluralising "s" (for example
	// "redis+clusters"). ParseClusterURL and ParseSentinelURL set a TLS
	// configuration on the returned options when the scheme is in this set.
	tlsSchemes = set.NewStringSet(
		"rediss",
		"rediss+cluster", "rediss-cluster",
		"redis+clusters", "redis-clusters",
		"rediss+sentinel", "rediss-sentinel",
		"redis+sentinels", "redis-sentinels",
		"rediss+sentinel+cluster", "rediss-sentinel-cluster",
		"redis+sentinels+cluster", "redis-sentinels-cluster",
		"redis+sentinel+clusters", "redis-sentinel-clusters",
	)
)

// NewClientFromURL creates a new redis client by parsing the raw URL.
func NewClientFromURL(rawURL string, tlsConfig *tls.Config) (redis.UniversalClient, error) { u, err := url.Parse(rawURL) if err != nil { return nil, err } switch { case standardSchemes.Has(u.Scheme): opts, err := redis.ParseURL(rawURL) if err != nil { return nil, err } // when using TLS, the TLS config will not be set to nil, in which case we replace it with our own if opts.TLSConfig != nil { opts.TLSConfig = tlsConfig } return redis.NewClient(opts), nil case clusterSchemes.Has(u.Scheme): opts, err := ParseClusterURL(rawURL) if err != nil { return nil, err } if opts.TLSConfig != nil { opts.TLSConfig = tlsConfig } return redis.NewClusterClient(opts), nil case sentinelSchemes.Has(u.Scheme): opts, err := ParseSentinelURL(rawURL) if err != nil { return nil, err } if opts.TLSConfig != nil { opts.TLSConfig = tlsConfig } return redis.NewFailoverClient(opts), nil case sentinelClusterSchemes.Has(u.Scheme): opts, err := ParseSentinelURL(rawURL) if err != nil { return nil, err } if opts.TLSConfig != nil { opts.TLSConfig = tlsConfig } return redis.NewFailoverClusterClient(opts), nil default: return nil, fmt.Errorf("unsupported URL scheme: %s", u.Scheme) } } // ParseClusterURL parses a redis-cluster URL. Format is: // // redis+cluster://[username:password@]host:port[,host2:port2,...]/[?param1=value1[¶m2=value=2&...]] // // Additionally TLS is supported with rediss+cluster, or redis+clusters. 
Supported query params: // // max_redirects: int // read_only: bool // route_by_latency: bool // route_randomly: bool // max_retries: int // min_retry_backoff: duration // max_retry_backoff: duration // dial_timeout: duration // read_timeout: duration // write_timeout: duration // pool_size: int // min_idle_conns: int // max_conn_age: duration // pool_timeout: duration // idle_timeout: duration // idle_check_frequency: duration // func ParseClusterURL(rawurl string) (*redis.ClusterOptions, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err } opts := new(redis.ClusterOptions) hostParts := strings.Split(u.Host, ",") for _, hostPart := range hostParts { host, port, err := net.SplitHostPort(hostPart) if err != nil { host = hostPart port = "6379" } opts.Addrs = append(opts.Addrs, net.JoinHostPort(host, port)) } q := u.Query() if err := parseIntParam(&opts.MaxRedirects, q, "max_redirects"); err != nil { return nil, err } if err := parseBoolParam(&opts.ReadOnly, q, "read_only"); err != nil { return nil, err } if err := parseBoolParam(&opts.RouteByLatency, q, "route_by_latency"); err != nil { return nil, err } if err := parseBoolParam(&opts.RouteRandomly, q, "route_randomly"); err != nil { return nil, err } if ui := u.User; ui != nil { opts.Username = ui.Username() opts.Password, _ = ui.Password() } if err := parseIntParam(&opts.MaxRetries, q, "max_retries"); err != nil { return nil, err } if err := parseDurationParam(&opts.MinRetryBackoff, q, "min_retry_backoff"); err != nil { return nil, err } if err := parseDurationParam(&opts.MaxRetryBackoff, q, "max_retry_backoff"); err != nil { return nil, err } if err := parseDurationParam(&opts.DialTimeout, q, "dial_timeout"); err != nil { return nil, err } if err := parseDurationParam(&opts.ReadTimeout, q, "read_timeout"); err != nil { return nil, err } if err := parseDurationParam(&opts.WriteTimeout, q, "write_timeout"); err != nil { return nil, err } if err := parseIntParam(&opts.PoolSize, q, "pool_size"); err 
!= nil { return nil, err } if err := parseIntParam(&opts.MinIdleConns, q, "min_idle_conns"); err != nil { return nil, err } if err := parseDurationParam(&opts.MaxConnAge, q, "max_conn_age"); err != nil { return nil, err } if err := parseDurationParam(&opts.PoolTimeout, q, "pool_timeout"); err != nil { return nil, err } if err := parseDurationParam(&opts.IdleTimeout, q, "idle_timeout"); err != nil { return nil, err } if err := parseDurationParam(&opts.IdleCheckFrequency, q, "idle_check_frequency"); err != nil { return nil, err } if tlsSchemes.Has(u.Scheme) { opts.TLSConfig = &tls.Config{} //nolint } return opts, nil } // ParseSentinelURL parses a redis-sentinel URL. Format is based on https://github.com/exponea/redis-sentinel-url: // // redis+sentinel://[:password@]host:port[,host2:port2,...][/service_name[/db]][?param1=value1[¶m2=value=2&...]] // // Additionally TLS is supported with rediss+sentinel, or redis+sentinels. Supported query params: // // slave_only: bool // use_disconnected_slaves: bool // query_sentinel_randomly: bool // username: string (username for redis connection) // password: string (password for redis connection) // max_retries: int // min_retry_backoff: duration // max_retry_backoff: duration // dial_timeout: duration // read_timeout: duration // write_timeout: duration // pool_size: int // min_idle_conns: int // max_conn_age: duration // pool_timeout: duration // idle_timeout: duration // idle_check_frequency: duration // func ParseSentinelURL(rawurl string) (*redis.FailoverOptions, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err } opts := new(redis.FailoverOptions) pathParts := strings.Split(u.Path, "/") if len(pathParts) > 1 { opts.MasterName = pathParts[1] } if len(pathParts) > 2 { opts.DB, err = strconv.Atoi(pathParts[2]) if err != nil { return nil, fmt.Errorf("invalid database: %w", err) } } hostParts := strings.Split(u.Host, ",") for _, hostPart := range hostParts { host, port, err := net.SplitHostPort(hostPart) if 
err != nil { host = hostPart port = "26379" // "By default Sentinel runs using TCP port 26379" } opts.SentinelAddrs = append(opts.SentinelAddrs, net.JoinHostPort(host, port)) } if u.User != nil { opts.SentinelPassword, _ = u.User.Password() } q := u.Query() if err := parseBoolParam(&opts.SlaveOnly, q, "slave_only"); err != nil { return nil, err } if err := parseBoolParam(&opts.RouteByLatency, q, "route_by_latency"); err != nil { return nil, err } if err := parseBoolParam(&opts.RouteRandomly, q, "route_randomly"); err != nil { return nil, err } if err := parseBoolParam(&opts.UseDisconnectedSlaves, q, "use_disconnected_slaves"); err != nil { return nil, err } opts.Username = q.Get("username") opts.Password = q.Get("password") if err := parseIntParam(&opts.MaxRetries, q, "max_retries"); err != nil { return nil, err } if err := parseDurationParam(&opts.MinRetryBackoff, q, "min_retry_backoff"); err != nil { return nil, err } if err := parseDurationParam(&opts.MaxRetryBackoff, q, "max_retry_backoff"); err != nil { return nil, err } if err := parseDurationParam(&opts.DialTimeout, q, "dial_timeout"); err != nil { return nil, err } if err := parseDurationParam(&opts.ReadTimeout, q, "read_timeout"); err != nil { return nil, err } if err := parseDurationParam(&opts.WriteTimeout, q, "write_timeout"); err != nil { return nil, err } if err := parseIntParam(&opts.PoolSize, q, "pool_size"); err != nil { return nil, err } if err := parseIntParam(&opts.MinIdleConns, q, "min_idle_conns"); err != nil { return nil, err } if err := parseDurationParam(&opts.MaxConnAge, q, "max_conn_age"); err != nil { return nil, err } if err := parseDurationParam(&opts.PoolTimeout, q, "pool_timeout"); err != nil { return nil, err } if err := parseDurationParam(&opts.IdleTimeout, q, "idle_timeout"); err != nil { return nil, err } if err := parseDurationParam(&opts.IdleCheckFrequency, q, "idle_check_frequency"); err != nil { return nil, err } if tlsSchemes.Has(u.Scheme) { opts.TLSConfig = &tls.Config{} 
//nolint } return opts, nil } func parseBoolParam(dst *bool, values url.Values, name string) error { v := values.Get(name) if v == "" { return nil } b, err := strconv.ParseBool(v) if err != nil { return fmt.Errorf("invalid %s: %w", name, err) } *dst = b return nil } func parseIntParam(dst *int, values url.Values, name string) error { v := values.Get(name) if v == "" { return nil } i, err := strconv.Atoi(v) if err != nil { return fmt.Errorf("invalid %s: %w", name, err) } *dst = i return nil } func parseDurationParam(dst *time.Duration, values url.Values, name string) error { v := values.Get(name) if v == "" { return nil } d, err := time.ParseDuration(v) if err != nil { return fmt.Errorf("invalid %s: %w", name, err) } *dst = d return nil }