216 lines
4.8 KiB
Go
216 lines
4.8 KiB
Go
package user
|
|
|
|
import (
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pquerna/otp/totp"
|
|
"github.com/spf13/viper"
|
|
"golang.org/x/crypto/argon2"
|
|
|
|
"git.1in9.net/raider/wroofauth/internal/parameters"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
)
|
|
|
|
// Sentinel errors returned by the password helpers of User.
// Callers should compare against them with errors.Is.
var (
	// ErrInvalidHash is returned when a stored password hash cannot be parsed.
	ErrInvalidHash = errors.New("the encoded hash is not in the correct format")

	// ErrIncompatibleVersion is returned when a stored hash names an argon2
	// version other than the one compiled into this binary.
	ErrIncompatibleVersion = errors.New("incompatible version of argon2")

	// ErrWrongPassword is returned when a password does not match the stored hash.
	ErrWrongPassword = errors.New("wrong password")

	// ErrPasswordLoginDeactivated is returned when the user has no password
	// hash stored at all.
	ErrPasswordLoginDeactivated = errors.New("password login deactivated for this user")
)
|
|
|
|
// UserSecondFactor is one second-factor credential attached to a User.
//
// Type selects the mechanism; this file only creates "totp" factors, whose
// shared secret is kept in Arguments under the "secret" key. Field order and
// bson tags define the stored document layout and must stay stable.
type UserSecondFactor struct {
	Name      string            `bson:"name"`      // label given when the factor was added (AddTOTP's name)
	Type      string            `bson:"type"`      // mechanism identifier, e.g. "totp"
	Enabled   bool              `bson:"enabled"`   // false until confirmed via EnableTOTP
	Arguments map[string]string `bson:"arguments"` // mechanism-specific parameters
}
|
|
|
|
type User struct {
|
|
ID primitive.ObjectID `bson:"_id"`
|
|
|
|
Username *string `bson:"username,omitempty"`
|
|
Email *string `bson:"email,omitempty"`
|
|
PasswordHash *string `bson:"password,omitempty"`
|
|
PasswordChangeTimestamp time.Time `bson:"passwordChangeTimestamp,omitempty"`
|
|
|
|
SecondFactors []*UserSecondFactor `bson:"secondFactors"`
|
|
SecondFactorOverride bool `bson:"secondFactorOverride,omitempty"` // Can be used to temporarily disable 2fa
|
|
}
|
|
|
|
func (u *User) GetType() string {
|
|
return "user"
|
|
}
|
|
|
|
func (u *User) GetID() primitive.ObjectID {
|
|
return u.ID
|
|
}
|
|
|
|
func (u *User) GetFriendlyIdentifier() string {
|
|
if u.Username != nil {
|
|
return *u.Username
|
|
}
|
|
|
|
if u.Email != nil {
|
|
return *u.Email
|
|
}
|
|
|
|
return u.ID.Hex() // User with no identification, fall back to just echoing the ID
|
|
}
|
|
|
|
func (u *User) countActive2FA() int {
|
|
count := 0
|
|
|
|
for _, factor := range u.SecondFactors {
|
|
if !factor.Enabled {
|
|
continue
|
|
}
|
|
|
|
count++
|
|
}
|
|
|
|
return count
|
|
}
|
|
|
|
func (u *User) Needs2FA() bool {
|
|
return !u.SecondFactorOverride && u.countActive2FA() > 0
|
|
}
|
|
|
|
func (u *User) CheckPassword(password string) error {
|
|
if u.PasswordHash == nil {
|
|
return ErrPasswordLoginDeactivated
|
|
}
|
|
|
|
p, salt, hash, err := decodeHash(*u.PasswordHash)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
otherHash := argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength)
|
|
|
|
if subtle.ConstantTimeCompare(hash, otherHash) == 1 {
|
|
return nil
|
|
}
|
|
return ErrWrongPassword
|
|
}
|
|
|
|
func (u *User) SetPassword(password string) error {
|
|
p := parameters.GetPasswordParams()
|
|
|
|
salt, err := parameters.GenerateRandomBytes(p.SaltLength)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hash := argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength)
|
|
|
|
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
|
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
|
|
|
encodedHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, p.Memory, p.Iterations, p.Parallelism, b64Salt, b64Hash)
|
|
|
|
u.PasswordHash = &encodedHash
|
|
u.PasswordChangeTimestamp = time.Now()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (u *User) AddTOTP(name string) (string, error) {
|
|
key, err := totp.Generate(totp.GenerateOpts{
|
|
Issuer: viper.GetString("totp.issuer"),
|
|
AccountName: u.GetFriendlyIdentifier(),
|
|
})
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
u.SecondFactors = append(u.SecondFactors, &UserSecondFactor{
|
|
Name: name,
|
|
Type: "totp",
|
|
Enabled: false,
|
|
Arguments: map[string]string{
|
|
"secret": key.Secret(),
|
|
},
|
|
})
|
|
|
|
return key.String(), nil
|
|
}
|
|
|
|
func (u *User) EnableTOTP(code string) error {
|
|
for _, factor := range u.SecondFactors {
|
|
if factor.Enabled || factor.Type != "totp" {
|
|
continue
|
|
}
|
|
|
|
success := totp.Validate(code, factor.Arguments["secret"])
|
|
|
|
if !success {
|
|
continue
|
|
}
|
|
|
|
factor.Enabled = true
|
|
|
|
return nil
|
|
}
|
|
|
|
return errors.New("no such totp")
|
|
}
|
|
|
|
func (u *User) ValidateTOTP(code string) error {
|
|
for _, factor := range u.SecondFactors {
|
|
if !factor.Enabled || factor.Type != "totp" {
|
|
continue
|
|
}
|
|
|
|
success := totp.Validate(code, factor.Arguments["secret"])
|
|
|
|
if !success {
|
|
continue
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return errors.New("no such totp")
|
|
}
|
|
|
|
func decodeHash(encodedHash string) (p *parameters.PasswordParams, salt, hash []byte, err error) {
|
|
values := strings.Split(encodedHash, "$")
|
|
if len(values) != 6 {
|
|
return nil, nil, nil, ErrInvalidHash
|
|
}
|
|
|
|
var version int
|
|
_, err = fmt.Sscanf(values[2], "v=%d", &version)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
if version != argon2.Version {
|
|
return nil, nil, nil, ErrIncompatibleVersion
|
|
}
|
|
|
|
p = ¶meters.PasswordParams{}
|
|
_, err = fmt.Sscanf(values[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
salt, err = base64.RawStdEncoding.Strict().DecodeString(values[4])
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
p.SaltLength = uint32(len(salt))
|
|
|
|
hash, err = base64.RawStdEncoding.Strict().DecodeString(values[5])
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
p.KeyLength = uint32(len(hash))
|
|
|
|
return p, salt, hash, nil
|
|
}
|