From dafead3122b65e021bfe54b6e34893ea524090bc Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 3 Jun 2022 12:32:01 -0600 Subject: [PATCH] postgres: fix CIDR query (#3389) --- pkg/storage/postgres/filter.go | 19 +++++++++++++++++-- pkg/storage/postgres/filter_test.go | 8 ++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/pkg/storage/postgres/filter.go b/pkg/storage/postgres/filter.go index 6df5d7596..61dc1fe8f 100644 --- a/pkg/storage/postgres/filter.go +++ b/pkg/storage/postgres/filter.go @@ -2,6 +2,7 @@ package postgres import ( "fmt" + "net/netip" "strings" "github.com/pomerium/pomerium/pkg/storage" @@ -39,8 +40,12 @@ func addFilterExpressionToQuery(query *string, args *[]interface{}, expr storage *args = append(*args, expr.Value) return nil case "$index": - *query += schemaName + "." + recordsTableName + ".index_cidr >>= " + fmt.Sprintf("$%d", len(*args)+1) - *args = append(*args, expr.Value) + if isCIDR(expr.Value) { + *query += schemaName + "." + recordsTableName + ".index_cidr >>= " + fmt.Sprintf("$%d", len(*args)+1) + *args = append(*args, expr.Value) + } else { + *query += " false " + } return nil default: return fmt.Errorf("unsupported equals filter: %v", expr.Fields) @@ -49,3 +54,13 @@ func addFilterExpressionToQuery(query *string, args *[]interface{}, expr storage return fmt.Errorf("unsupported filter expression: %T", expr) } } + +func isCIDR(value string) bool { + if _, err := netip.ParsePrefix(value); err == nil { + return true + } + if _, err := netip.ParseAddr(value); err == nil { + return true + } + return false +} diff --git a/pkg/storage/postgres/filter_test.go b/pkg/storage/postgres/filter_test.go index fce2056a0..3c67449ac 100644 --- a/pkg/storage/postgres/filter_test.go +++ b/pkg/storage/postgres/filter_test.go @@ -21,12 +21,16 @@ func TestAddFilterExpressionToQuery(t *testing.T) { Fields: []string{"$index"}, Value: "v2", }, + storage.EqualsFilterExpression{ + Fields: []string{"$index"}, + Value: "10.0.0.0/8", + }, }, storage.EqualsFilterExpression{ Fields: []string{"type"}, Value: "v3", }, }) - assert.Equal(t, "( ( pomerium.records.id = $1 OR pomerium.records.index_cidr >>= $2 ) AND pomerium.records.type = $3 )", query) - assert.Equal(t, []any{"v1", "v2", "v3"}, args) + assert.Equal(t, "( ( pomerium.records.id = $1 OR false OR pomerium.records.index_cidr >>= $2 ) AND pomerium.records.type = $3 )", query) + assert.Equal(t, []any{"v1", "10.0.0.0/8", "v3"}, args) }