package queryparam

import (
	"errors"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/pomerium/pomerium/internal/encoding"
	"github.com/pomerium/pomerium/internal/encoding/mock"
	"github.com/pomerium/pomerium/internal/sessions"
)

func TestNewQueryParamStore(t *testing.T) {
	tests := []struct {
		name  string
		State *sessions.State

		enc     encoding.MarshalUnmarshaler
		qp      string
		wantErr bool
		wantURL *url.URL
	}{
		{"simple good", &sessions.State{}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
		{"marshall error", &sessions.State{}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := NewStore(tt.enc, tt.qp)

			r := httptest.NewRequest("GET", "/", nil)
			w := httptest.NewRecorder()

			if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
				t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
			}

			if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" {
				t.Errorf("NewStore() = %v", diff)
			}
			got.ClearSession(w, r)
			if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" {
				t.Errorf("NewStore() = %v", diff)
			}
		})
	}
}