From 4dd5357fe30cadb8e7f4dffa0ed444a6bba71a64 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Fri, 25 Apr 2025 14:47:11 -0400 Subject: [PATCH] mcp: extend code usage (#5588) --- internal/mcp/code.go | 17 +++- internal/mcp/code_test.go | 117 +++++++++++++++++++--- internal/mcp/handler_authorization.go | 1 + internal/oauth21/buf.lock | 4 +- internal/oauth21/gen/code.pb.go | 137 ++++++++++++++++++++------ internal/oauth21/proto/code.proto | 20 +++- 6 files changed, 244 insertions(+), 52 deletions(-) diff --git a/internal/mcp/code.go b/internal/mcp/code.go index d3a317f04..8dced0ac6 100644 --- a/internal/mcp/code.go +++ b/internal/mcp/code.go @@ -14,7 +14,14 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" ) +const ( + CodeTypeAuthorization = oauth21proto.CodeType_CODE_TYPE_AUTHORIZATION + CodeTypeRefresh = oauth21proto.CodeType_CODE_TYPE_REFRESH + CodeTypeAccess = oauth21proto.CodeType_CODE_TYPE_ACCESS +) + func CreateCode( + typ oauth21proto.CodeType, id string, expires time.Time, ad string, @@ -27,6 +34,7 @@ func CreateCode( v := oauth21proto.Code{ Id: id, ExpiresAt: timestamppb.New(expires), + GrantType: typ, } err := protovalidate.Validate(&v) @@ -44,6 +52,7 @@ func CreateCode( } func DecryptCode( + typ oauth21proto.CodeType, code string, cipher cipher.AEAD, ad string, @@ -62,8 +71,12 @@ func DecryptCode( if err != nil { return nil, fmt.Errorf("unmarshal: %w", err) } - if v.ExpiresAt == nil { - return nil, fmt.Errorf("expiration is nil") + err = protovalidate.Validate(&v) + if err != nil { + return nil, fmt.Errorf("validate: %w", err) + } + if v.GrantType != typ { + return nil, fmt.Errorf("code type mismatch: expected %v, got %v", typ, v.GrantType) } if v.ExpiresAt.AsTime().Before(now) { return nil, fmt.Errorf("code expired") diff --git a/internal/mcp/code_test.go b/internal/mcp/code_test.go index a8120c09a..c4261e8e8 100644 --- a/internal/mcp/code_test.go +++ b/internal/mcp/code_test.go @@ -24,6 +24,7 @@ func TestCreateCode(t *testing.T) { tests := []struct { name string + typ oauth21proto.CodeType id string expires time.Time ad string @@ -32,7 +33,26 @@ func TestCreateCode(t *testing.T) { errMessage string }{ { - name: "valid code", + name: "valid authorization code", + typ: CodeTypeAuthorization, + id: "test-id", + expires: time.Now().Add(time.Hour), + ad: "test-ad", + cipher: testCipher, + wantErr: false, + }, + { + name: "valid refresh code", + typ: CodeTypeRefresh, + id: "test-id", + expires: time.Now().Add(time.Hour), + ad: "test-ad", + cipher: testCipher, + wantErr: false, + }, + { + name: "valid access code", + typ: CodeTypeAccess, id: "test-id", expires: time.Now().Add(time.Hour), ad: "test-ad", @@ -41,6 +61,7 @@ func TestCreateCode(t *testing.T) { }, { name: "empty id", + typ: CodeTypeAuthorization, id: "", expires: time.Now().Add(time.Hour), ad: "test-ad", @@ -50,6 +71,7 @@ func TestCreateCode(t *testing.T) { }, { name: "empty expires", + typ: CodeTypeAuthorization, id: "test-id", expires: time.Time{}, ad: "test-ad", @@ -57,11 +79,31 @@ func TestCreateCode(t *testing.T) { wantErr: true, errMessage: "validate", }, + { + name: "invalid code type", + typ: 0, // Unspecified type + id: "test-id", + expires: time.Now().Add(time.Hour), + ad: "test-ad", + cipher: testCipher, + wantErr: true, + errMessage: "validate", + }, + { + name: "undefined code type", + typ: 99, // Undefined type + id: "test-id", + expires: time.Now().Add(time.Hour), + ad: "test-ad", + cipher: testCipher, + wantErr: true, + errMessage: "validate", + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - code, err := CreateCode(tc.id, tc.expires, tc.ad, tc.cipher) + code, err := CreateCode(tc.typ, tc.id, tc.expires, tc.ad, tc.cipher) if tc.wantErr { assert.Error(t, err) @@ -73,9 +115,10 @@ func TestCreateCode(t *testing.T) { assert.NoError(t, err) assert.NotEmpty(t, code) - decodedCode, err := DecryptCode(code, tc.cipher, tc.ad, time.Now()) + decodedCode, err := DecryptCode(tc.typ, code, tc.cipher, tc.ad, time.Now()) require.NoError(t, err) assert.Equal(t, tc.id, decodedCode.Id) + assert.Equal(t, tc.typ, decodedCode.GrantType) assert.True(t, proto.Equal(timestamppb.New(tc.expires), decodedCode.ExpiresAt)) } }) @@ -91,14 +134,18 @@ func TestDecryptCode(t *testing.T) { future := now.Add(time.Hour) past := now.Add(-time.Hour) - validCode, err := CreateCode("test-id", future, "test-ad", testCipher) + validCode, err := CreateCode(CodeTypeAuthorization, "test-id", future, "test-ad", testCipher) require.NoError(t, err) - expiredCode, err := CreateCode("expired-id", past, "test-ad", testCipher) + validRefreshCode, err := CreateCode(CodeTypeRefresh, "refresh-id", future, "test-ad", testCipher) + require.NoError(t, err) + + expiredCode, err := CreateCode(CodeTypeAuthorization, "expired-id", past, "test-ad", testCipher) require.NoError(t, err) codeNoExpiry := &oauth21proto.Code{ - Id: "no-expiry", + Id: "no-expiry", + GrantType: CodeTypeAuthorization, } codeBytes, err := proto.Marshal(codeNoExpiry) require.NoError(t, err) @@ -107,6 +154,7 @@ func TestDecryptCode(t *testing.T) { tests := []struct { name string + typ oauth21proto.CodeType code string cipher cipher.AEAD ad string @@ -117,15 +165,37 @@ func TestDecryptCode(t *testing.T) { }{ { name: "valid code", + typ: CodeTypeAuthorization, code: validCode, cipher: testCipher, ad: "test-ad", now: now, - want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future)}, + want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeAuthorization}, wantErr: false, }, + { + name: "valid refresh code", + typ: CodeTypeRefresh, + code: validRefreshCode, + cipher: testCipher, + ad: "test-ad", + now: now, + want: &oauth21proto.Code{Id: "refresh-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeRefresh}, + wantErr: false, + }, + { + name: "wrong code type", + typ: CodeTypeAccess, // Using wrong type + code: validCode, // This was created with Authorization type + cipher: testCipher, + ad: "test-ad", + now: now, + wantErr: true, + errMessage: "code type mismatch", + }, { name: "expired code", + typ: CodeTypeAuthorization, code: expiredCode, cipher: testCipher, ad: "test-ad", @@ -135,15 +205,17 @@ func TestDecryptCode(t *testing.T) { }, { name: "nil expiration", + typ: CodeTypeAuthorization, code: codeNoExpiryStr, cipher: testCipher, ad: "test-ad", now: now, wantErr: true, - errMessage: "expiration is nil", + errMessage: "expires_at: value is required", }, { name: "invalid base64", + typ: CodeTypeAuthorization, code: "not-base64", cipher: testCipher, ad: "test-ad", @@ -153,27 +225,48 @@ func TestDecryptCode(t *testing.T) { }, { name: "wrong authentication data", + typ: CodeTypeAuthorization, code: validCode, cipher: testCipher, ad: "wrong-ad", now: now, wantErr: true, - errMessage: "decrypt", + errMessage: "message authentication failed", + }, + { + name: "unspecified code type", + typ: 0, // Unspecified type + code: validCode, + cipher: testCipher, + ad: "test-ad", + now: now, + wantErr: true, + errMessage: "code type mismatch", + }, + { + name: "undefined code type", + typ: 99, // undefined type + code: validCode, + cipher: testCipher, + ad: "test-ad", + now: now, + wantErr: true, + errMessage: "code type mismatch", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got, err := DecryptCode(tc.code, tc.cipher, tc.ad, tc.now) + got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now) if tc.wantErr { - assert.Error(t, err) + require.Error(t, err) if tc.errMessage != "" { assert.Contains(t, err.Error(), tc.errMessage) } assert.Nil(t, got) } else { - assert.NoError(t, err) + require.NoError(t, err) require.NotNil(t, got) diff := cmp.Diff(tc.want, got, protocmp.Transform()) diff --git a/internal/mcp/handler_authorization.go b/internal/mcp/handler_authorization.go index de4613ce0..714ae5ae6 100644 --- a/internal/mcp/handler_authorization.go +++ b/internal/mcp/handler_authorization.go @@ -81,6 +81,7 @@ func (srv *Handler) AuthorizationResponse( req *oauth21proto.AuthorizationRequest, ) { code, err := CreateCode( + CodeTypeAuthorization, id, time.Now().Add(time.Minute*10), req.ClientId, diff --git a/internal/oauth21/buf.lock b/internal/oauth21/buf.lock index 09123d359..ee39cd62d 100644 --- a/internal/oauth21/buf.lock +++ b/internal/oauth21/buf.lock @@ -2,5 +2,5 @@ version: v2 deps: - name: buf.build/bufbuild/protovalidate - commit: 7712fb530c574b95bc1d57c0877543c3 - digest: b5:b3e9c9428384357e3b73e4d5a4614328b0a4b1595b10163bbe9483fa16204749274c41797bd49b0d716479c855aa35c1172a94f471fa120ba8369637fd138829 + commit: 8976f5be98c146529b1cc15cd2012b60 + digest: b5:5d513af91a439d9e78cacac0c9455c7cb885a8737d30405d0b91974fe05276d19c07a876a51a107213a3d01b83ecc912996cdad4cddf7231f91379079cf7488d diff --git a/internal/oauth21/gen/code.pb.go b/internal/oauth21/gen/code.pb.go index b8273c263..e9643099d 100644 --- a/internal/oauth21/gen/code.pb.go +++ b/internal/oauth21/gen/code.pb.go @@ -23,11 +23,64 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type CodeType int32 + +const ( + CodeType_CODE_TYPE_UNSPECIFIED CodeType = 0 + CodeType_CODE_TYPE_AUTHORIZATION CodeType = 1 + CodeType_CODE_TYPE_ACCESS CodeType = 2 + CodeType_CODE_TYPE_REFRESH CodeType = 3 +) + +// Enum value maps for CodeType. +var ( + CodeType_name = map[int32]string{ + 0: "CODE_TYPE_UNSPECIFIED", + 1: "CODE_TYPE_AUTHORIZATION", + 2: "CODE_TYPE_ACCESS", + 3: "CODE_TYPE_REFRESH", + } + CodeType_value = map[string]int32{ + "CODE_TYPE_UNSPECIFIED": 0, + "CODE_TYPE_AUTHORIZATION": 1, + "CODE_TYPE_ACCESS": 2, + "CODE_TYPE_REFRESH": 3, + } +) + +func (x CodeType) Enum() *CodeType { + p := new(CodeType) + *p = x + return p +} + +func (x CodeType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (CodeType) Descriptor() protoreflect.EnumDescriptor { + return file_code_proto_enumTypes[0].Descriptor() +} + +func (CodeType) Type() protoreflect.EnumType { + return &file_code_proto_enumTypes[0] +} + +func (x CodeType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use CodeType.Descriptor instead. +func (CodeType) EnumDescriptor() ([]byte, []int) { + return file_code_proto_rawDescGZIP(), []int{0} +} + // Code is a code used in the authorization code flow. type Code struct { state protoimpl.MessageState `protogen:"open.v1"` Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expires_at,json=expiresAt,proto3" json:"expires_at,omitempty"` + GrantType CodeType `protobuf:"varint,3,opt,name=grant_type,json=grantType,proto3,enum=oauth21.CodeType" json:"grant_type,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -76,32 +129,50 @@ func (x *Code) GetExpiresAt() *timestamppb.Timestamp { return nil } +func (x *Code) GetGrantType() CodeType { + if x != nil { + return x.GrantType + } + return CodeType_CODE_TYPE_UNSPECIFIED +} + var File_code_proto protoreflect.FileDescriptor var file_code_proto_rawDesc = string([]byte{ 0x0a, 0x0a, 0x63, 0x6f, 0x64, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, 0x61, - 0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69, - 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x22, 0x65, 0x0a, 0x04, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x02, 0x69, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0a, 0xba, 0x48, 0x07, 0xc8, 0x01, 0x01, 0x72, - 0x02, 0x10, 0x01, 0x52, 0x02, 0x69, 0x64, 0x12, 0x41, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, - 0x65, 0x73, 0x5f, 0x61, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, - 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, - 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, 0x52, - 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x42, 0x9c, 0x01, 0x0a, 0x0b, 0x63, - 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x09, 0x43, 0x6f, 0x64, 0x65, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x2f, 0x62, 0x75, 0x66, - 0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x76, - 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2f, 0x67, 0x65, 0x6e, 0xa2, - 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, - 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, - 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, - 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, + 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x22, 0xa4, 0x01, 0x0a, 0x04, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x02, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0a, 0xba, 0x48, 0x07, 0xc8, 0x01, 0x01, + 0x72, 0x02, 0x10, 0x01, 0x52, 0x02, 0x69, 0x64, 0x12, 0x41, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, + 0x72, 0x65, 0x73, 0x5f, 0x61, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, + 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, + 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x3d, 0x0a, 0x0a, 0x67, + 0x72, 0x61, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x11, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x42, 0x0b, 0xba, 0x48, 0x08, 0xc8, 0x01, 0x01, 0x82, 0x01, 0x02, 0x10, 0x01, 0x52, + 0x09, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x54, 0x79, 0x70, 0x65, 0x2a, 0x6f, 0x0a, 0x08, 0x43, 0x6f, + 0x64, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x19, 0x0a, 0x15, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x54, + 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, + 0x00, 0x12, 0x1b, 0x0a, 0x17, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x41, + 0x55, 0x54, 0x48, 0x4f, 0x52, 0x49, 0x5a, 0x41, 0x54, 0x49, 0x4f, 0x4e, 0x10, 0x01, 0x12, 0x14, + 0x0a, 0x10, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x41, 0x43, 0x43, 0x45, + 0x53, 0x53, 0x10, 0x02, 0x12, 0x15, 0x0a, 0x11, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x54, 0x59, 0x50, + 0x45, 0x5f, 0x52, 0x45, 0x46, 0x52, 0x45, 0x53, 0x48, 0x10, 0x03, 0x42, 0x9c, 0x01, 0x0a, 0x0b, + 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x09, 0x43, 0x6f, 0x64, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x2f, 0x62, 0x75, + 0x66, 0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x73, 0x74, + 0x61, 0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2f, 0x67, 0x65, 0x6e, + 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, + 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, + 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, }) var ( @@ -116,18 +187,21 @@ func file_code_proto_rawDescGZIP() []byte { return file_code_proto_rawDescData } +var file_code_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_code_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_code_proto_goTypes = []any{ - (*Code)(nil), // 0: oauth21.Code - (*timestamppb.Timestamp)(nil), // 1: google.protobuf.Timestamp + (CodeType)(0), // 0: oauth21.CodeType + (*Code)(nil), // 1: oauth21.Code + (*timestamppb.Timestamp)(nil), // 2: google.protobuf.Timestamp } var file_code_proto_depIdxs = []int32{ - 1, // 0: oauth21.Code.expires_at:type_name -> google.protobuf.Timestamp - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 2, // 0: oauth21.Code.expires_at:type_name -> google.protobuf.Timestamp + 0, // 1: oauth21.Code.grant_type:type_name -> oauth21.CodeType + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name } func init() { file_code_proto_init() } @@ -140,13 +214,14 @@ func file_code_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_code_proto_rawDesc), len(file_code_proto_rawDesc)), - NumEnums: 0, + NumEnums: 1, NumMessages: 1, NumExtensions: 0, NumServices: 0, }, GoTypes: file_code_proto_goTypes, DependencyIndexes: file_code_proto_depIdxs, + EnumInfos: file_code_proto_enumTypes, MessageInfos: file_code_proto_msgTypes, }.Build() File_code_proto = out.File diff --git a/internal/oauth21/proto/code.proto b/internal/oauth21/proto/code.proto index 7f7ac2dbe..45551fc72 100644 --- a/internal/oauth21/proto/code.proto +++ b/internal/oauth21/proto/code.proto @@ -2,19 +2,29 @@ syntax = "proto3"; package oauth21; -import "google/protobuf/timestamp.proto"; import "buf/validate/validate.proto"; +import "google/protobuf/timestamp.proto"; option go_package = "github.com/pomerium/pomerium/internal/oauth21/gen"; +enum CodeType { + CODE_TYPE_UNSPECIFIED = 0; + CODE_TYPE_AUTHORIZATION = 1; + CODE_TYPE_ACCESS = 2; + CODE_TYPE_REFRESH = 3; +} + // Code is a code used in the authorization code flow. message Code { string id = 1 [ (buf.validate.field).required = true, - (buf.validate.field).string = { - min_len : 1, + (buf.validate.field).string = { + min_len: 1, } ]; - google.protobuf.Timestamp expires_at = 2 - [ (buf.validate.field).required = true ]; + google.protobuf.Timestamp expires_at = 2 [(buf.validate.field).required = true]; + CodeType grant_type = 3 [ + (buf.validate.field).required = true, + (buf.validate.field).enum.defined_only = true + ]; }