package user

import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/bsm/redislock"
	"github.com/pquerna/otp/totp"
	"github.com/spf13/viper"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"golang.org/x/crypto/argon2"

	"git.1in9.net/raider/wroofauth/internal/database"
	"git.1in9.net/raider/wroofauth/internal/parameters"
)

var (
	// ErrInvalidHash is returned when a stored password hash is not of the form
	// "$argon2id$v=<version>$m=<memory>,t=<iterations>,p=<parallelism>$<salt>$<hash>"
	// or carries parameters that cannot be used to verify a password.
	ErrInvalidHash = errors.New("the encoded hash is not in the correct format")
	// ErrIncompatibleVersion is returned when a stored hash records an argon2
	// version different from argon2.Version.
	ErrIncompatibleVersion = errors.New("incompatible version of argon2")
	// ErrWrongPassword is returned by CheckPassword when the password does not match.
	ErrWrongPassword = errors.New("wrong password")
	// ErrPasswordLoginDeactivated is returned by CheckPassword when the user has
	// no password hash set.
	ErrPasswordLoginDeactivated = errors.New("password login deactivated for this user")
	// ErrNoSuchTOTP is returned by EnableTOTP and ValidateTOTP when no suitable
	// TOTP factor accepts the supplied code.
	ErrNoSuchTOTP = errors.New("no such totp")
)

// UserSecondFactor is a single second authentication factor attached to a User.
type UserSecondFactor struct {
	Name      string            `bson:"name"`      // caller-supplied label (see AddTOTP)
	Type      string            `bson:"type"`      // factor kind; AddTOTP stores "totp"
	Enabled   bool              `bson:"enabled"`   // only enabled factors count for Needs2FA / ValidateTOTP
	Arguments map[string]string `bson:"arguments"` // type-specific data; TOTP factors keep their "secret" here
}

// User is an account that can authenticate by password and optional second factors.
type User struct {
	ID                      primitive.ObjectID  `bson:"_id"`
	Username                *string             `bson:"username,omitempty"`
	Email                   *string             `bson:"email,omitempty"`
	PasswordHash            *string             `bson:"password,omitempty"` // encoded argon2id hash; nil disables password login
	PasswordChangeTimestamp time.Time           `bson:"passwordChangeTimestamp,omitempty"`
	SecondFactors           []*UserSecondFactor `bson:"secondFactors"`
	SecondFactorOverride    bool                `bson:"secondFactorOverride,omitempty"` // Can be used to temporarily disable 2fa
}

// New returns a User with a freshly generated ObjectID and no other fields set.
func New() User {
	return User{
		ID: primitive.NewObjectID(),
	}
}

// GetType returns the constant entity type name "user".
func (u *User) GetType() string {
	return "user"
}

// GetID returns the user's ObjectID.
func (u *User) GetID() primitive.ObjectID {
	return u.ID
}

// Lock obtains a lock from database.Locker keyed on the user's hex ID, with a
// 100ms TTL and database.LockBackoff as the retry strategy. The caller is
// responsible for releasing (or refreshing) the returned lock.
func (u *User) Lock(ctx context.Context) (*redislock.Lock, error) {
	return database.Locker.Obtain(ctx, u.ID.Hex(), 100*time.Millisecond, &redislock.Options{
		RetryStrategy: database.LockBackoff,
	})
}

// GetFriendlyIdentifier returns the username if set, otherwise the email if
// set, otherwise the hex ObjectID.
func (u *User) GetFriendlyIdentifier() string {
	if u.Username != nil {
		return *u.Username
	}
	if u.Email != nil {
		return *u.Email
	}
	return u.ID.Hex() // User with no identification, fall back to just echoing the ID
}

// countActive2FA returns how many second factors are enabled. Nil entries
// (e.g. decoded from a BSON null) are ignored rather than dereferenced.
func (u *User) countActive2FA() int {
	count := 0
	for _, factor := range u.SecondFactors {
		if factor == nil || !factor.Enabled {
			continue
		}
		count++
	}
	return count
}

// Needs2FA reports whether a login must be completed with a second factor:
// true when at least one factor is enabled and SecondFactorOverride is off.
func (u *User) Needs2FA() bool {
	return !u.SecondFactorOverride && u.countActive2FA() > 0
}

// CheckPassword verifies password against the stored argon2id hash.
//
// It returns nil on a match, ErrPasswordLoginDeactivated when no hash is set,
// ErrWrongPassword on a mismatch, or the error from decoding the stored hash
// (ErrInvalidHash, ErrIncompatibleVersion, or a parse/base64 error).
func (u *User) CheckPassword(password string) error {
	if u.PasswordHash == nil {
		return ErrPasswordLoginDeactivated
	}

	p, salt, hash, err := decodeHash(*u.PasswordHash)
	if err != nil {
		return err
	}

	otherHash := argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength)

	// Constant-time comparison so timing does not leak how many bytes matched.
	if subtle.ConstantTimeCompare(hash, otherHash) == 1 {
		return nil
	}
	return ErrWrongPassword
}

// SetPassword hashes password with argon2id using parameters.GetPasswordParams()
// and a fresh random salt, stores the encoded hash in PasswordHash and sets
// PasswordChangeTimestamp to now. Only the in-memory User is modified.
func (u *User) SetPassword(password string) error {
	p := parameters.GetPasswordParams()

	salt, err := parameters.GenerateRandomBytes(p.SaltLength)
	if err != nil {
		return err
	}

	hash := argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength)

	b64Salt := base64.RawStdEncoding.EncodeToString(salt)
	b64Hash := base64.RawStdEncoding.EncodeToString(hash)

	encodedHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, p.Memory, p.Iterations, p.Parallelism, b64Salt, b64Hash)

	u.PasswordHash = &encodedHash
	u.PasswordChangeTimestamp = time.Now()

	return nil
}

// AddTOTP generates a new TOTP key (issuer from the "totp.issuer" config value,
// account name from GetFriendlyIdentifier) and appends it as a *disabled*
// factor called name. It returns key.String() for the client to enroll; the
// factor only becomes active after a successful EnableTOTP.
func (u *User) AddTOTP(name string) (string, error) {
	key, err := totp.Generate(totp.GenerateOpts{
		Issuer:      viper.GetString("totp.issuer"),
		AccountName: u.GetFriendlyIdentifier(),
	})
	if err != nil {
		return "", err
	}

	u.SecondFactors = append(u.SecondFactors, &UserSecondFactor{
		Name:    name,
		Type:    "totp",
		Enabled: false,
		Arguments: map[string]string{
			"secret": key.Secret(),
		},
	})

	return key.String(), nil
}

// EnableTOTP enables the first disabled TOTP factor whose secret accepts code.
// It returns ErrNoSuchTOTP when no disabled TOTP factor accepts the code.
func (u *User) EnableTOTP(code string) error {
	for _, factor := range u.SecondFactors {
		if factor == nil || factor.Enabled || factor.Type != "totp" {
			continue
		}
		if !totp.Validate(code, factor.Arguments["secret"]) {
			continue
		}
		factor.Enabled = true
		return nil
	}
	return ErrNoSuchTOTP
}

// ValidateTOTP returns nil if any enabled TOTP factor accepts code, and
// ErrNoSuchTOTP otherwise.
func (u *User) ValidateTOTP(code string) error {
	for _, factor := range u.SecondFactors {
		if factor == nil || !factor.Enabled || factor.Type != "totp" {
			continue
		}
		if totp.Validate(code, factor.Arguments["secret"]) {
			return nil
		}
	}
	return ErrNoSuchTOTP
}

// decodeHash parses a hash produced by SetPassword:
//
//	$argon2id$v=<version>$m=<memory>,t=<iterations>,p=<parallelism>$<b64 salt>$<b64 hash>
//
// Salt and hash are unpadded standard base64. SaltLength and KeyLength in the
// returned params are taken from the decoded salt and hash.
func decodeHash(encodedHash string) (p *parameters.PasswordParams, salt, hash []byte, err error) {
	values := strings.Split(encodedHash, "$")
	// The string starts with "$" (so values[0] is empty) and must be tagged
	// with argon2id, the only algorithm CheckPassword derives with.
	if len(values) != 6 || values[0] != "" || values[1] != "argon2id" {
		return nil, nil, nil, ErrInvalidHash
	}

	var version int
	if _, err = fmt.Sscanf(values[2], "v=%d", &version); err != nil {
		return nil, nil, nil, err
	}
	if version != argon2.Version {
		return nil, nil, nil, ErrIncompatibleVersion
	}

	p = &parameters.PasswordParams{}
	if _, err = fmt.Sscanf(values[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism); err != nil {
		return nil, nil, nil, err
	}
	// argon2.IDKey panics when the number of rounds or the parallelism degree
	// is below 1; reject such hashes instead of crashing the login path.
	if p.Iterations == 0 || p.Parallelism == 0 {
		return nil, nil, nil, ErrInvalidHash
	}

	salt, err = base64.RawStdEncoding.Strict().DecodeString(values[4])
	if err != nil {
		return nil, nil, nil, err
	}
	p.SaltLength = uint32(len(salt))

	hash, err = base64.RawStdEncoding.Strict().DecodeString(values[5])
	if err != nil {
		return nil, nil, nil, err
	}
	// A zero-length key cannot be verified meaningfully: ConstantTimeCompare
	// reports two empty slices as equal.
	if len(hash) == 0 {
		return nil, nil, nil, ErrInvalidHash
	}
	p.KeyLength = uint32(len(hash))

	return p, salt, hash, nil
}