diff --git a/config/policy.go b/config/policy.go index 5a5d0193a..22c2fa97c 100644 --- a/config/policy.go +++ b/config/policy.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "encoding/base64" "encoding/hex" + "errors" "fmt" "net/http" "net/url" @@ -415,19 +416,16 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) { Body: pb.Response.GetBody(), } } else { - p.To = make(WeightedURLs, len(pb.To)) - for i, u := range pb.To { - u, err := urlutil.ParseAndValidateURL(u) - if err != nil { - return nil, err + var err error + p.To, err = ParseWeightedUrls(pb.To...) + if err != nil && !errors.Is(err, errEmptyUrls) { + return nil, fmt.Errorf("error parsing to URLs: %w", err) + } + + if len(pb.LoadBalancingWeights) == len(p.To) { + for i, w := range pb.LoadBalancingWeights { + p.To[i].LbWeight = w } - w := WeightedURL{ - URL: *u, - } - if len(pb.LoadBalancingWeights) == len(pb.To) { - w.LbWeight = pb.LoadBalancingWeights[i] - } - p.To[i] = w } } diff --git a/config/policy_test.go b/config/policy_test.go index 384a68e73..a137b4c9f 100644 --- a/config/policy_test.go +++ b/config/policy_test.go @@ -16,6 +16,7 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/config" ) func Test_PolicyValidate(t *testing.T) { @@ -213,6 +214,23 @@ func TestPolicy_Checksum(t *testing.T) { } } +func TestNewPolicyFromProto(t *testing.T) { + t.Parallel() + + p, err := NewPolicyFromProto(&config.Route{ + To: []string{"http://127.0.0.1:1234,1", "http://127.0.0.1:1234,2"}, + }) + assert.NoError(t, err) + assert.Equal(t, mustParseWeightedURLs(t, "http://127.0.0.1:1234,1", "http://127.0.0.1:1234,2"), p.To) + + p, err = NewPolicyFromProto(&config.Route{ + To: []string{"http://127.0.0.1:1234,1", "http://127.0.0.1:1234,2"}, + LoadBalancingWeights: []uint32{3, 4}, + }) + assert.NoError(t, err) + assert.Equal(t, mustParseWeightedURLs(t, "http://127.0.0.1:1234,3", "http://127.0.0.1:1234,4"), p.To) +} + func TestPolicy_FromToPb(t *testing.T) { t.Parallel() @@ -442,7 +460,7 @@ func TestPolicy_IsTCPUpstream(t *testing.T) { assert.False(t, p3.IsTCPUpstream()) } -func mustParseWeightedURLs(t testing.TB, urls ...string) []WeightedURL { +func mustParseWeightedURLs(t testing.TB, urls ...string) WeightedURLs { wu, err := ParseWeightedUrls(urls...) require.NoError(t, err) return wu