device enrollment: fix ip address (#3430)

This commit is contained in:
Caleb Doxsey 2022-06-16 11:30:38 -06:00 committed by GitHub
parent d1037d784a
commit a938a23ea2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 5 deletions

View file

@ -8,7 +8,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
@ -523,9 +522,7 @@ func getOrCreateDeviceEnrollment(
deviceEnrollment.CredentialId = deviceCredentialID
deviceEnrollment.EnrolledAt = timestamppb.Now()
deviceEnrollment.UserAgent = r.UserAgent()
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
deviceEnrollment.IpAddress = ip
}
deviceEnrollment.IpAddress = httputil.GetClientIPAddress(r)
err := device.PutEnrollment(ctx, state.Client, deviceEnrollment)
if err != nil {

View file

@ -1,6 +1,9 @@
package httputil
import "net/http"
import (
"net"
"net/http"
)
const (
// StatusDeviceUnauthorized is the status code returned when a client's
@ -39,3 +42,16 @@ func StatusText(code int) string {
}
return http.StatusText(code)
}
// GetClientIPAddress gets a client's IP address for an HTTP request.
func GetClientIPAddress(r *http.Request) string {
if ip := r.Header.Get("X-Envoy-External-Address"); ip != "" {
return ip
}
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return ip
}
return "127.0.0.1"
}

View file

@ -0,0 +1,26 @@
package httputil
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetClientIPAddress(t *testing.T) {
r1, err := http.NewRequest("GET", "https://example.com", nil)
require.NoError(t, err)
assert.Equal(t, "127.0.0.1", GetClientIPAddress(r1))
r2, err := http.NewRequest("GET", "https://example.com", nil)
require.NoError(t, err)
r2.RemoteAddr = "127.0.0.2:1234"
assert.Equal(t, "127.0.0.2", GetClientIPAddress(r2))
r3, err := http.NewRequest("GET", "https://example.com", nil)
require.NoError(t, err)
r3.RemoteAddr = "127.0.0.3:1234"
r3.Header.Set("X-Envoy-External-Address", "127.0.0.3")
assert.Equal(t, "127.0.0.3", GetClientIPAddress(r3))
}