pomerium/examples/mutual-tls/main.go
2020-07-17 14:23:06 -04:00

125 lines
3.4 KiB
Go

package main
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"log"
"net"
"net/http"
"os"
)
func main() {
port := "8443"
if fromEnv := os.Getenv("PORT"); fromEnv != "" {
port = fromEnv
}
tlsCert := os.Getenv("TLS_CERT")
tlsKey := os.Getenv("TLS_KEY")
clientCA := os.Getenv("CLIENT_CA")
if tlsCert == "" {
log.Fatal("TLS_CERT environment variable must be set")
}
if tlsKey == "" {
log.Fatal("TLS_KEY environment variable must be set")
}
if clientCA == "" {
log.Fatal("CLIENT_CA environment variable must be set")
}
mux := http.NewServeMux()
mux.HandleFunc("/", hello)
srv := &http.Server{Handler: mux}
ln, err := newClientCertTLSListener(":"+port, tlsCert, tlsKey, clientCA)
if err != nil {
log.Fatalf("failed creating tls listener: %v", err)
}
log.Printf("listening on port %s", port)
log.Fatal(srv.Serve(ln))
}
func hello(w http.ResponseWriter, r *http.Request) {
log.Printf("Serving request: %s", r.URL.Path)
fmt.Fprintf(w, "Hello, world!\n")
fmt.Fprintf(w, "%s %s %s\n", r.Method, r.URL, r.Proto)
fmt.Fprintf(w, "TLS\n\tServerName: %s\n\tVersion: %d \n\t CipherSuite:%d \n", r.TLS.ServerName, r.TLS.Version, r.TLS.CipherSuite)
for _, cert := range r.TLS.PeerCertificates {
fmt.Fprintf(w, "TLSPeerCertificate: Subject %+v\n", cert.Subject)
}
if headerIP := r.Header.Get("X-Forwarded-For"); headerIP != "" {
fmt.Fprintf(w, "Client IP (X-Forwarded-For): %s\n", headerIP)
}
fmt.Fprintf(w, "Headers\n")
for k, v := range r.Header {
fmt.Fprintf(w, "\t[%s]:\n\t\t%s\n", k, v)
}
}
func newClientCertTLSListener(addr, tlsCert, tlsKey, clientCA string) (net.Listener, error) {
caPool, err := decodeCertPoolFromPEM(clientCA)
if err != nil {
return nil, err
}
cert, err := decodeCertificate(tlsCert, tlsKey)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: caPool,
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
PreferServerCipherSuites: true,
CurvePreferences: []tls.CurveID{
tls.X25519,
tls.CurveP256,
},
Certificates: []tls.Certificate{*cert},
NextProtos: []string{"h2"},
}
tlsConfig.BuildNameToCertificate()
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
return tls.NewListener(ln, tlsConfig), nil
}
func decodeCertPoolFromPEM(encPemCerts string) (*x509.CertPool, error) {
pemCerts, err := base64.StdEncoding.DecodeString(encPemCerts)
if err != nil {
return nil, fmt.Errorf("couldn't decode pem %v: %v", pemCerts, err)
}
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(pemCerts); !ok {
return nil, fmt.Errorf("failed to append certs from pem")
}
return certPool, nil
}
func decodeCertificate(cert, key string) (*tls.Certificate, error) {
decodedCert, err := base64.StdEncoding.DecodeString(cert)
if err != nil {
return nil, fmt.Errorf("failed to decode certificate cert %v: %v", decodedCert, err)
}
decodedKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, fmt.Errorf("failed to decode certificate key %v: %v", decodedKey, err)
}
x509, err := tls.X509KeyPair(decodedCert, decodedKey)
return &x509, err
}