postgres: fix CIDR query (#3389)

This commit is contained in:
Caleb Doxsey 2022-06-03 12:32:01 -06:00 committed by GitHub
parent 2b11ef10f5
commit dafead3122
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 4 deletions

View file

@ -2,6 +2,7 @@ package postgres
import ( import (
"fmt" "fmt"
"net/netip"
"strings" "strings"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
@ -39,8 +40,12 @@ func addFilterExpressionToQuery(query *string, args *[]interface{}, expr storage
*args = append(*args, expr.Value) *args = append(*args, expr.Value)
return nil return nil
case "$index": case "$index":
*query += schemaName + "." + recordsTableName + ".index_cidr >>= " + fmt.Sprintf("$%d", len(*args)+1) if isCIDR(expr.Value) {
*args = append(*args, expr.Value) *query += schemaName + "." + recordsTableName + ".index_cidr >>= " + fmt.Sprintf("$%d", len(*args)+1)
*args = append(*args, expr.Value)
} else {
*query += " false "
}
return nil return nil
default: default:
return fmt.Errorf("unsupported equals filter: %v", expr.Fields) return fmt.Errorf("unsupported equals filter: %v", expr.Fields)
@ -49,3 +54,13 @@ func addFilterExpressionToQuery(query *string, args *[]interface{}, expr storage
return fmt.Errorf("unsupported filter expression: %T", expr) return fmt.Errorf("unsupported filter expression: %T", expr)
} }
} }
func isCIDR(value string) bool {
if _, err := netip.ParsePrefix(value); err == nil {
return true
}
if _, err := netip.ParseAddr(value); err == nil {
return true
}
return false
}

View file

@ -21,12 +21,16 @@ func TestAddFilterExpressionToQuery(t *testing.T) {
Fields: []string{"$index"}, Fields: []string{"$index"},
Value: "v2", Value: "v2",
}, },
storage.EqualsFilterExpression{
Fields: []string{"$index"},
Value: "10.0.0.0/8",
},
}, },
storage.EqualsFilterExpression{ storage.EqualsFilterExpression{
Fields: []string{"type"}, Fields: []string{"type"},
Value: "v3", Value: "v3",
}, },
}) })
assert.Equal(t, "( ( pomerium.records.id = $1 OR pomerium.records.index_cidr >>= $2 ) AND pomerium.records.type = $3 )", query) assert.Equal(t, "( ( pomerium.records.id = $1 OR false OR pomerium.records.index_cidr >>= $2 ) AND pomerium.records.type = $3 )", query)
assert.Equal(t, []any{"v1", "v2", "v3"}, args) assert.Equal(t, []any{"v1", "10.0.0.0/8", "v3"}, args)
} }