wroofauth/internal/entities/user/user.go
2023-10-19 11:44:03 +00:00

216 lines
4.8 KiB
Go

package user
import (
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
"github.com/pquerna/otp/totp"
"github.com/spf13/viper"
"golang.org/x/crypto/argon2"
"git.1in9.net/raider/wroofauth/internal/parameters"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Sentinel errors reported by the password handling of User. They are
// package-level values so callers can compare against them directly or with
// errors.Is.
var (
	// ErrInvalidHash is returned when a stored password hash cannot be parsed.
	ErrInvalidHash = errors.New("the encoded hash is not in the correct format")

	// ErrIncompatibleVersion is returned when a stored hash declares an argon2
	// version different from the one this binary was built with.
	ErrIncompatibleVersion = errors.New("incompatible version of argon2")

	// ErrWrongPassword is returned by CheckPassword when the supplied password
	// does not match the stored hash.
	ErrWrongPassword = errors.New("wrong password")

	// ErrPasswordLoginDeactivated is returned by CheckPassword when the user has
	// no stored password hash at all.
	ErrPasswordLoginDeactivated = errors.New("password login deactivated for this user")
)
// UserSecondFactor is one second authentication factor registered for a user.
type UserSecondFactor struct {
	// Name is the caller-supplied display name of the factor.
	Name string `bson:"name"`

	// Type identifies the kind of factor; this file only creates "totp".
	Type string `bson:"type"`

	// Enabled is false while the factor is pending confirmation and true once
	// it may be used to authenticate.
	Enabled bool `bson:"enabled"`

	// Arguments carries type-specific settings. TOTP factors keep their shared
	// secret under the "secret" key.
	Arguments map[string]string `bson:"arguments"`
}
// User is a user account entity. The bson tags define its stored form; the
// _id field is a MongoDB ObjectID.
type User struct {
	// ID is the document's _id.
	ID primitive.ObjectID `bson:"_id"`

	// Username is optional and nil when the account has none.
	Username *string `bson:"username,omitempty"`

	// Email is optional and nil when the account has none.
	Email *string `bson:"email,omitempty"`

	// PasswordHash is the encoded argon2id hash written by SetPassword. A nil
	// value disables password login for this user.
	PasswordHash *string `bson:"password,omitempty"`

	// PasswordChangeTimestamp records when SetPassword last replaced the hash.
	PasswordChangeTimestamp time.Time `bson:"passwordChangeTimestamp,omitempty"`

	// SecondFactors lists every registered second factor, enabled or pending.
	SecondFactors []*UserSecondFactor `bson:"secondFactors"`

	// SecondFactorOverride, when true, makes Needs2FA report false even if
	// factors are enabled. It can be used to temporarily disable 2FA.
	SecondFactorOverride bool `bson:"secondFactorOverride,omitempty"`
}
// GetType returns the entity type name for u, which is always "user".
func (u *User) GetType() string {
	const entityType = "user"
	return entityType
}
// GetID returns the ObjectID stored in u.ID.
func (u *User) GetID() primitive.ObjectID {
	id := u.ID
	return id
}
// GetFriendlyIdentifier returns a human-readable label for u. It uses the
// username if one is set, otherwise the email address, and otherwise the hex
// form of the ID.
func (u *User) GetFriendlyIdentifier() string {
	switch {
	case u.Username != nil:
		return *u.Username
	case u.Email != nil:
		return *u.Email
	default:
		// No identifying fields at all: echo the ID so callers always get a label.
		return u.ID.Hex()
	}
}
// countActive2FA returns the number of entries in u.SecondFactors whose
// Enabled flag is set.
func (u *User) countActive2FA() int {
	active := 0
	for i := range u.SecondFactors {
		if u.SecondFactors[i].Enabled {
			active++
		}
	}
	return active
}
// Needs2FA reports whether u must present a second factor. This is true when
// at least one factor is enabled and SecondFactorOverride is not set.
func (u *User) Needs2FA() bool {
	if u.SecondFactorOverride {
		// The override switches 2FA off regardless of registered factors.
		return false
	}
	return u.countActive2FA() > 0
}
// CheckPassword verifies password against u's stored argon2id hash.
//
// It returns:
//   - nil when the password matches;
//   - ErrWrongPassword when it does not;
//   - ErrPasswordLoginDeactivated when u has no stored hash;
//   - any error produced while decoding the stored hash.
func (u *User) CheckPassword(password string) error {
	if u.PasswordHash == nil {
		return ErrPasswordLoginDeactivated
	}

	params, salt, want, err := decodeHash(*u.PasswordHash)
	if err != nil {
		return err
	}

	// Re-derive the key with the exact parameters recorded alongside the hash.
	got := argon2.IDKey([]byte(password), salt, params.Iterations, params.Memory, params.Parallelism, params.KeyLength)

	// Compare in constant time so timing does not reveal how much matched.
	if subtle.ConstantTimeCompare(want, got) != 1 {
		return ErrWrongPassword
	}
	return nil
}
// SetPassword hashes password with argon2id and stores the result on u.
//
// The hash uses the parameters from parameters.GetPasswordParams and a salt
// from parameters.GenerateRandomBytes. The encoded result is
//
//	$argon2id$v=<version>$m=<memory>,t=<iterations>,p=<parallelism>$<salt>$<hash>
//
// with salt and hash in unpadded standard base64. It is stored in
// u.PasswordHash, and u.PasswordChangeTimestamp is set to the current time.
// The only error returned comes from salt generation; u is left untouched in
// that case.
func (u *User) SetPassword(password string) error {
	params := parameters.GetPasswordParams()

	salt, err := parameters.GenerateRandomBytes(params.SaltLength)
	if err != nil {
		return err
	}

	key := argon2.IDKey([]byte(password), salt, params.Iterations, params.Memory, params.Parallelism, params.KeyLength)

	enc := base64.RawStdEncoding
	encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
		argon2.Version, params.Memory, params.Iterations, params.Parallelism,
		enc.EncodeToString(salt), enc.EncodeToString(key))

	u.PasswordHash = &encoded
	u.PasswordChangeTimestamp = time.Now()
	return nil
}
// AddTOTP generates a new TOTP key for u and registers it as a disabled
// "totp" second factor called name.
//
// The key's issuer comes from the "totp.issuer" configuration value, and its
// account name comes from GetFriendlyIdentifier. The shared secret is stored
// under Arguments["secret"]. The returned string is key.String(), the key's
// URL form used for enrollment. The factor stays disabled until a code is
// confirmed (see EnableTOTP).
func (u *User) AddTOTP(name string) (string, error) {
	opts := totp.GenerateOpts{
		Issuer:      viper.GetString("totp.issuer"),
		AccountName: u.GetFriendlyIdentifier(),
	}

	key, err := totp.Generate(opts)
	if err != nil {
		return "", err
	}

	pending := &UserSecondFactor{
		Name:      name,
		Type:      "totp",
		Enabled:   false,
		Arguments: map[string]string{"secret": key.Secret()},
	}
	u.SecondFactors = append(u.SecondFactors, pending)

	return key.String(), nil
}
// EnableTOTP enables the first disabled "totp" factor of u whose secret
// accepts code. Only pending TOTP factors are checked. If none of them accepts
// the code, an error is returned and nothing is changed.
func (u *User) EnableTOTP(code string) error {
	for _, f := range u.SecondFactors {
		isPendingTOTP := !f.Enabled && f.Type == "totp"
		if isPendingTOTP && totp.Validate(code, f.Arguments["secret"]) {
			f.Enabled = true
			return nil
		}
	}
	return errors.New("no such totp")
}
// ValidateTOTP returns nil if code is accepted by any enabled "totp" factor of
// u, and an error otherwise. Disabled (pending) factors are ignored.
func (u *User) ValidateTOTP(code string) error {
	for _, f := range u.SecondFactors {
		if f.Enabled && f.Type == "totp" && totp.Validate(code, f.Arguments["secret"]) {
			return nil
		}
	}
	return errors.New("no such totp")
}
// decodeHash parses an encoded password hash of the form written by SetPassword:
//
//	$argon2id$v=<version>$m=<memory>,t=<iterations>,p=<parallelism>$<salt>$<hash>
//
// Salt and hash are unpadded standard base64. decodeHash returns the argon2
// parameters, the decoded salt and the decoded hash. SaltLength and KeyLength
// are filled in from the decoded lengths.
//
// Errors:
//   - ErrInvalidHash when the layout is wrong, the algorithm is not argon2id,
//     or the parameters cannot be used safely with argon2.IDKey (zero
//     iterations, zero parallelism, or an empty hash);
//   - ErrIncompatibleVersion when the version differs from argon2.Version;
//   - the underlying fmt or base64 error when a field cannot be parsed.
func decodeHash(encodedHash string) (p *parameters.PasswordParams, salt, hash []byte, err error) {
	values := strings.Split(encodedHash, "$")
	// values[0] is the empty text before the leading '$'. values[1] must name
	// argon2id, because CheckPassword always re-derives with argon2.IDKey;
	// accepting another variant would silently compute the wrong key.
	if len(values) != 6 || values[0] != "" || values[1] != "argon2id" {
		return nil, nil, nil, ErrInvalidHash
	}

	var version int
	if _, err = fmt.Sscanf(values[2], "v=%d", &version); err != nil {
		return nil, nil, nil, err
	}
	if version != argon2.Version {
		return nil, nil, nil, ErrIncompatibleVersion
	}

	p = &parameters.PasswordParams{}
	if _, err = fmt.Sscanf(values[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism); err != nil {
		return nil, nil, nil, err
	}
	// argon2.IDKey panics on zero rounds or zero parallelism. Reject such a
	// record here rather than crashing the login path.
	if p.Iterations == 0 || p.Parallelism == 0 {
		return nil, nil, nil, ErrInvalidHash
	}

	salt, err = base64.RawStdEncoding.Strict().DecodeString(values[4])
	if err != nil {
		return nil, nil, nil, err
	}
	p.SaltLength = uint32(len(salt))

	hash, err = base64.RawStdEncoding.Strict().DecodeString(values[5])
	if err != nil {
		return nil, nil, nil, err
	}
	// An empty hash would yield KeyLength 0. subtle.ConstantTimeCompare treats
	// two empty slices as equal, so such a record could at worst accept any
	// password; treat it as malformed.
	if len(hash) == 0 {
		return nil, nil, nil, ErrInvalidHash
	}
	p.KeyLength = uint32(len(hash))

	return p, salt, hash, nil
}