mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 10:56:28 +02:00
device enrollment: fix ip address (#3430)
This commit is contained in:
parent
d1037d784a
commit
a938a23ea2
3 changed files with 44 additions and 5 deletions
|
@ -8,7 +8,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
@ -523,9 +522,7 @@ func getOrCreateDeviceEnrollment(
|
||||||
deviceEnrollment.CredentialId = deviceCredentialID
|
deviceEnrollment.CredentialId = deviceCredentialID
|
||||||
deviceEnrollment.EnrolledAt = timestamppb.Now()
|
deviceEnrollment.EnrolledAt = timestamppb.Now()
|
||||||
deviceEnrollment.UserAgent = r.UserAgent()
|
deviceEnrollment.UserAgent = r.UserAgent()
|
||||||
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
deviceEnrollment.IpAddress = httputil.GetClientIPAddress(r)
|
||||||
deviceEnrollment.IpAddress = ip
|
|
||||||
}
|
|
||||||
|
|
||||||
err := device.PutEnrollment(ctx, state.Client, deviceEnrollment)
|
err := device.PutEnrollment(ctx, state.Client, deviceEnrollment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
package httputil
|
package httputil
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// StatusDeviceUnauthorized is the status code returned when a client's
|
// StatusDeviceUnauthorized is the status code returned when a client's
|
||||||
|
@ -39,3 +42,16 @@ func StatusText(code int) string {
|
||||||
}
|
}
|
||||||
return http.StatusText(code)
|
return http.StatusText(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetClientIPAddress gets a client's IP address for an HTTP request.
|
||||||
|
func GetClientIPAddress(r *http.Request) string {
|
||||||
|
if ip := r.Header.Get("X-Envoy-External-Address"); ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
return "127.0.0.1"
|
||||||
|
}
|
||||||
|
|
26
internal/httputil/httputil_test.go
Normal file
26
internal/httputil/httputil_test.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package httputil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetClientIPAddress(t *testing.T) {
|
||||||
|
r1, err := http.NewRequest("GET", "https://example.com", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "127.0.0.1", GetClientIPAddress(r1))
|
||||||
|
|
||||||
|
r2, err := http.NewRequest("GET", "https://example.com", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
r2.RemoteAddr = "127.0.0.2:1234"
|
||||||
|
assert.Equal(t, "127.0.0.2", GetClientIPAddress(r2))
|
||||||
|
|
||||||
|
r3, err := http.NewRequest("GET", "https://example.com", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
r3.RemoteAddr = "127.0.0.3:1234"
|
||||||
|
r3.Header.Set("X-Envoy-External-Address", "127.0.0.3")
|
||||||
|
assert.Equal(t, "127.0.0.3", GetClientIPAddress(r3))
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue