first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
199
server/auth/authenticator.go
Normal file
199
server/auth/authenticator.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// Authenticator provides shared authentication and authorization logic.
|
||||
// Used by gRPC interceptor, Connect interceptor, and file server to ensure
|
||||
// consistent authentication behavior across all API endpoints.
|
||||
//
|
||||
// Authentication methods:
|
||||
// - JWT access tokens: Short-lived tokens (15 minutes) for API access
|
||||
// - Personal Access Tokens (PAT): Long-lived tokens for programmatic access
|
||||
//
|
||||
// This struct is safe for concurrent use.
|
||||
type Authenticator struct {
|
||||
store *store.Store
|
||||
secret string
|
||||
}
|
||||
|
||||
// NewAuthenticator creates a new Authenticator instance.
|
||||
func NewAuthenticator(store *store.Store, secret string) *Authenticator {
|
||||
return &Authenticator{
|
||||
store: store,
|
||||
secret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticateByAccessTokenV2 validates a short-lived access token.
|
||||
// Returns claims without database query (stateless validation).
|
||||
func (a *Authenticator) AuthenticateByAccessTokenV2(accessToken string) (*UserClaims, error) {
|
||||
claims, err := ParseAccessTokenV2(accessToken, []byte(a.secret))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid access token")
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(claims.Subject)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid user ID in token")
|
||||
}
|
||||
|
||||
return &UserClaims{
|
||||
UserID: userID,
|
||||
Username: claims.Username,
|
||||
Role: claims.Role,
|
||||
Status: claims.Status,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthenticateByRefreshToken validates a refresh token against the database.
|
||||
func (a *Authenticator) AuthenticateByRefreshToken(ctx context.Context, refreshToken string) (*store.User, string, error) {
|
||||
claims, err := ParseRefreshToken(refreshToken, []byte(a.secret))
|
||||
if err != nil {
|
||||
return nil, "", errors.Wrap(err, "invalid refresh token")
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(claims.Subject)
|
||||
if err != nil {
|
||||
return nil, "", errors.Wrap(err, "invalid user ID in token")
|
||||
}
|
||||
|
||||
// Check token exists in database (revocation check)
|
||||
token, err := a.store.GetUserRefreshTokenByID(ctx, userID, claims.TokenID)
|
||||
if err != nil {
|
||||
return nil, "", errors.Wrap(err, "failed to get refresh token")
|
||||
}
|
||||
if token == nil {
|
||||
return nil, "", errors.New("refresh token revoked")
|
||||
}
|
||||
|
||||
// Check token not expired
|
||||
if token.ExpiresAt != nil && token.ExpiresAt.AsTime().Before(time.Now()) {
|
||||
return nil, "", errors.New("refresh token expired")
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
if err != nil {
|
||||
return nil, "", errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, "", errors.New("user not found")
|
||||
}
|
||||
if user.RowStatus == store.Archived {
|
||||
return nil, "", errors.New("user is archived")
|
||||
}
|
||||
|
||||
return user, claims.TokenID, nil
|
||||
}
|
||||
|
||||
// AuthenticateByPAT validates a Personal Access Token.
|
||||
func (a *Authenticator) AuthenticateByPAT(ctx context.Context, token string) (*store.User, *storepb.PersonalAccessTokensUserSetting_PersonalAccessToken, error) {
|
||||
if !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
|
||||
return nil, nil, errors.New("invalid PAT format")
|
||||
}
|
||||
|
||||
tokenHash := HashPersonalAccessToken(token)
|
||||
result, err := a.store.GetUserByPATHash(ctx, tokenHash)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "invalid PAT")
|
||||
}
|
||||
|
||||
// Check expiry
|
||||
if result.PAT.ExpiresAt != nil && result.PAT.ExpiresAt.AsTime().Before(time.Now()) {
|
||||
return nil, nil, errors.New("PAT expired")
|
||||
}
|
||||
|
||||
// Check user status
|
||||
if result.User.RowStatus == store.Archived {
|
||||
return nil, nil, errors.New("user is archived")
|
||||
}
|
||||
|
||||
return result.User, result.PAT, nil
|
||||
}
|
||||
|
||||
// AuthResult contains the result of an authentication attempt.
|
||||
type AuthResult struct {
|
||||
User *store.User // Set for PAT authentication
|
||||
Claims *UserClaims // Set for Access Token V2 (stateless)
|
||||
AccessToken string // Non-empty if authenticated via JWT
|
||||
}
|
||||
|
||||
// AuthenticateToUser resolves the current request to a *store.User, checking the
|
||||
// Authorization header first (access token or PAT), then falling back to the
|
||||
// refresh token cookie. Returns (nil, nil) when no credentials are present.
|
||||
func (a *Authenticator) AuthenticateToUser(ctx context.Context, authHeader, cookieHeader string) (*store.User, error) {
|
||||
// Try Bearer token first.
|
||||
if authHeader != "" {
|
||||
token := ExtractBearerToken(authHeader)
|
||||
if token != "" {
|
||||
if !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
|
||||
claims, err := a.AuthenticateByAccessTokenV2(token)
|
||||
if err == nil && claims != nil {
|
||||
return a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
|
||||
}
|
||||
} else {
|
||||
user, _, err := a.AuthenticateByPAT(ctx, token)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: refresh token cookie.
|
||||
if cookieHeader != "" {
|
||||
refreshToken := ExtractRefreshTokenFromCookie(cookieHeader)
|
||||
if refreshToken != "" {
|
||||
user, _, err := a.AuthenticateByRefreshToken(ctx, refreshToken)
|
||||
return user, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Authenticate tries to authenticate using the provided credentials.
|
||||
// Priority: 1. Access Token V2, 2. PAT
|
||||
// Returns nil if no valid credentials are provided.
|
||||
func (a *Authenticator) Authenticate(ctx context.Context, authHeader string) *AuthResult {
|
||||
token := ExtractBearerToken(authHeader)
|
||||
|
||||
// Try Access Token V2 (stateless)
|
||||
if token != "" && !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
|
||||
claims, err := a.AuthenticateByAccessTokenV2(token)
|
||||
if err == nil && claims != nil {
|
||||
return &AuthResult{
|
||||
Claims: claims,
|
||||
AccessToken: token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try PAT
|
||||
if token != "" && strings.HasPrefix(token, PersonalAccessTokenPrefix) {
|
||||
user, pat, err := a.AuthenticateByPAT(ctx, token)
|
||||
if err == nil && user != nil {
|
||||
// Update last used (fire-and-forget with logging)
|
||||
go func() {
|
||||
if err := a.store.UpdatePATLastUsed(context.Background(), user.ID, pat.TokenId, timestamppb.Now()); err != nil {
|
||||
slog.Warn("failed to update PAT last used time", "error", err, "userID", user.ID)
|
||||
}
|
||||
}()
|
||||
return &AuthResult{User: user, AccessToken: token}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
99
server/auth/context.go
Normal file
99
server/auth/context.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// ContextKey is the key type for context values.
|
||||
// Using a custom type prevents collisions with other packages.
|
||||
type ContextKey int
|
||||
|
||||
const (
|
||||
// UserIDContextKey stores the authenticated user's ID.
|
||||
// Set for all authenticated requests.
|
||||
// Use GetUserID(ctx) to retrieve this value.
|
||||
UserIDContextKey ContextKey = iota
|
||||
|
||||
// AccessTokenContextKey stores the JWT token for token-based auth.
|
||||
// Only set when authenticated via Bearer token.
|
||||
AccessTokenContextKey
|
||||
|
||||
// UserClaimsContextKey stores the claims from access token.
|
||||
UserClaimsContextKey
|
||||
|
||||
// RefreshTokenIDContextKey stores the refresh token ID.
|
||||
RefreshTokenIDContextKey
|
||||
)
|
||||
|
||||
// GetUserID retrieves the authenticated user's ID from the context.
|
||||
// Returns 0 if no user ID is set (unauthenticated request).
|
||||
func GetUserID(ctx context.Context) int32 {
|
||||
if v, ok := ctx.Value(UserIDContextKey).(int32); ok {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetAccessToken retrieves the JWT access token from the context.
|
||||
// Returns empty string if not authenticated via bearer token.
|
||||
func GetAccessToken(ctx context.Context) string {
|
||||
if v, ok := ctx.Value(AccessTokenContextKey).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetUserInContext sets the authenticated user's information in the context.
|
||||
// This is a simpler alternative to AuthorizeAndSetContext for cases where
|
||||
// authorization is handled separately (e.g., HTTP middleware).
|
||||
//
|
||||
// Parameters:
|
||||
// - user: The authenticated user
|
||||
// - accessToken: Set if authenticated via JWT token (empty string otherwise)
|
||||
func SetUserInContext(ctx context.Context, user *store.User, accessToken string) context.Context {
|
||||
ctx = context.WithValue(ctx, UserIDContextKey, user.ID)
|
||||
if accessToken != "" {
|
||||
ctx = context.WithValue(ctx, AccessTokenContextKey, accessToken)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// UserClaims represents authenticated user info from access token.
|
||||
type UserClaims struct {
|
||||
UserID int32
|
||||
Username string
|
||||
Role string
|
||||
Status string
|
||||
}
|
||||
|
||||
// GetUserClaims retrieves the user claims from context.
|
||||
// Returns nil if not authenticated via access token.
|
||||
func GetUserClaims(ctx context.Context) *UserClaims {
|
||||
if v, ok := ctx.Value(UserClaimsContextKey).(*UserClaims); ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetUserClaimsInContext sets the user claims in context.
|
||||
func SetUserClaimsInContext(ctx context.Context, claims *UserClaims) context.Context {
|
||||
return context.WithValue(ctx, UserClaimsContextKey, claims)
|
||||
}
|
||||
|
||||
// ApplyToContext sets the authenticated identity from an AuthResult into the context.
|
||||
// This is the canonical way to propagate auth state after a successful Authenticate call.
|
||||
// Safe to call with a nil result (no-op).
|
||||
func ApplyToContext(ctx context.Context, result *AuthResult) context.Context {
|
||||
if result == nil {
|
||||
return ctx
|
||||
}
|
||||
if result.Claims != nil {
|
||||
ctx = SetUserClaimsInContext(ctx, result.Claims)
|
||||
ctx = context.WithValue(ctx, UserIDContextKey, result.Claims.UserID)
|
||||
} else if result.User != nil {
|
||||
ctx = SetUserInContext(ctx, result.User, result.AccessToken)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
33
server/auth/extract.go
Normal file
33
server/auth/extract.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExtractBearerToken extracts the JWT token from an Authorization header value.
|
||||
// Expected format: "Bearer {token}"
|
||||
// Returns empty string if no valid bearer token is found.
|
||||
func ExtractBearerToken(authHeader string) string {
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Fields(authHeader)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
|
||||
return ""
|
||||
}
|
||||
return parts[1]
|
||||
}
|
||||
|
||||
// ExtractRefreshTokenFromCookie extracts the refresh token from cookie header.
|
||||
func ExtractRefreshTokenFromCookie(cookieHeader string) string {
|
||||
if cookieHeader == "" {
|
||||
return ""
|
||||
}
|
||||
req := &http.Request{Header: http.Header{"Cookie": []string{cookieHeader}}}
|
||||
cookie, err := req.Cookie(RefreshTokenCookieName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
249
server/auth/token.go
Normal file
249
server/auth/token.go
Normal file
@@ -0,0 +1,249 @@
|
||||
// Package auth provides authentication and authorization for the Memos server.
|
||||
//
|
||||
// This package is used by:
|
||||
// - server/router/api/v1: gRPC and Connect API interceptors
|
||||
// - server/router/fileserver: HTTP file server authentication
|
||||
//
|
||||
// Authentication methods supported:
|
||||
// - JWT access tokens: Short-lived tokens (15 minutes) for API access
|
||||
// - JWT refresh tokens: Long-lived tokens (30 days) for obtaining new access tokens
|
||||
// - Personal Access Tokens (PAT): Long-lived tokens for programmatic access
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// Issuer is the issuer claim in JWT tokens.
|
||||
// This identifies tokens as issued by Memos.
|
||||
Issuer = "memos"
|
||||
|
||||
// KeyID is the key identifier used in JWT header.
|
||||
// Version "v1" allows for future key rotation while maintaining backward compatibility.
|
||||
// If signing mechanism changes, add "v2", "v3", etc. and verify both versions.
|
||||
KeyID = "v1"
|
||||
|
||||
// AccessTokenAudienceName is the audience claim for JWT access tokens.
|
||||
// This ensures tokens are only used for API access, not other purposes.
|
||||
AccessTokenAudienceName = "user.access-token"
|
||||
|
||||
// AccessTokenDuration is the lifetime of access tokens (15 minutes).
|
||||
AccessTokenDuration = 15 * time.Minute
|
||||
|
||||
// RefreshTokenDuration is the lifetime of refresh tokens (30 days).
|
||||
RefreshTokenDuration = 30 * 24 * time.Hour
|
||||
|
||||
// RefreshTokenAudienceName is the audience claim for refresh tokens.
|
||||
RefreshTokenAudienceName = "user.refresh-token"
|
||||
|
||||
// RefreshTokenCookieName is the cookie name for refresh tokens.
|
||||
RefreshTokenCookieName = "memos_refresh"
|
||||
|
||||
// PersonalAccessTokenPrefix is the prefix for PAT tokens.
|
||||
PersonalAccessTokenPrefix = "memos_pat_"
|
||||
)
|
||||
|
||||
// ClaimsMessage represents the claims structure in a JWT token.
|
||||
//
|
||||
// JWT Claims include:
|
||||
// - name: Username (custom claim)
|
||||
// - iss: Issuer = "memos"
|
||||
// - aud: Audience = "user.access-token"
|
||||
// - sub: Subject = user ID
|
||||
// - iat: Issued at time
|
||||
// - exp: Expiration time (optional, may be empty for never-expiring tokens).
|
||||
type ClaimsMessage struct {
|
||||
Name string `json:"name"` // Username
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// AccessTokenClaims contains claims for short-lived access tokens.
|
||||
// These tokens are validated by signature only (stateless).
|
||||
type AccessTokenClaims struct {
|
||||
Type string `json:"type"` // "access"
|
||||
Role string `json:"role"` // User role
|
||||
Status string `json:"status"` // User status
|
||||
Username string `json:"username"` // Username for display
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// RefreshTokenClaims contains claims for long-lived refresh tokens.
|
||||
// These tokens are validated against the database for revocation.
|
||||
type RefreshTokenClaims struct {
|
||||
Type string `json:"type"` // "refresh"
|
||||
TokenID string `json:"tid"` // Token ID for revocation lookup
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken generates a JWT access token for a user.
|
||||
//
|
||||
// Parameters:
|
||||
// - username: The user's username (stored in "name" claim)
|
||||
// - userID: The user's ID (stored in "sub" claim)
|
||||
// - expirationTime: When the token expires (pass zero time for no expiration)
|
||||
// - secret: Server secret used to sign the token
|
||||
//
|
||||
// Returns a signed JWT string or an error.
|
||||
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret []byte) (string, error) {
|
||||
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, secret)
|
||||
}
|
||||
|
||||
// generateToken generates a JWT token with the given claims.
|
||||
//
|
||||
// Token structure:
|
||||
// Header: {"alg": "HS256", "kid": "v1", "typ": "JWT"}
|
||||
// Claims: {"name": username, "iss": "memos", "aud": [audience], "sub": userID, "iat": now, "exp": expiry}
|
||||
// Signature: HMACSHA256(base64UrlEncode(header) + "." + base64UrlEncode(payload), secret).
|
||||
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
|
||||
registeredClaims := jwt.RegisteredClaims{
|
||||
Issuer: Issuer,
|
||||
Audience: jwt.ClaimStrings{audience},
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Subject: fmt.Sprint(userID),
|
||||
}
|
||||
if !expirationTime.IsZero() {
|
||||
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
|
||||
}
|
||||
|
||||
// Declare the token with the HS256 algorithm used for signing, and the claims.
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
|
||||
Name: username,
|
||||
RegisteredClaims: registeredClaims,
|
||||
})
|
||||
token.Header["kid"] = KeyID
|
||||
|
||||
// Create the JWT string.
|
||||
tokenString, err := token.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// GenerateAccessTokenV2 generates a short-lived access token with user claims.
|
||||
func GenerateAccessTokenV2(userID int32, username, role, status string, secret []byte) (string, time.Time, error) {
|
||||
expiresAt := time.Now().Add(AccessTokenDuration)
|
||||
|
||||
claims := &AccessTokenClaims{
|
||||
Type: "access",
|
||||
Role: role,
|
||||
Status: status,
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: Issuer,
|
||||
Audience: jwt.ClaimStrings{AccessTokenAudienceName},
|
||||
Subject: fmt.Sprint(userID),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token.Header["kid"] = KeyID
|
||||
|
||||
tokenString, err := token.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
return tokenString, expiresAt, nil
|
||||
}
|
||||
|
||||
// GenerateRefreshToken generates a long-lived refresh token.
|
||||
func GenerateRefreshToken(userID int32, tokenID string, secret []byte) (string, time.Time, error) {
|
||||
expiresAt := time.Now().Add(RefreshTokenDuration)
|
||||
|
||||
claims := &RefreshTokenClaims{
|
||||
Type: "refresh",
|
||||
TokenID: tokenID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: Issuer,
|
||||
Audience: jwt.ClaimStrings{RefreshTokenAudienceName},
|
||||
Subject: fmt.Sprint(userID),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token.Header["kid"] = KeyID
|
||||
|
||||
tokenString, err := token.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
return tokenString, expiresAt, nil
|
||||
}
|
||||
|
||||
// GeneratePersonalAccessToken generates a random PAT string.
|
||||
func GeneratePersonalAccessToken() string {
|
||||
randomStr, err := util.RandomString(32)
|
||||
if err != nil {
|
||||
// Fallback to UUID if RandomString fails
|
||||
return PersonalAccessTokenPrefix + util.GenUUID()
|
||||
}
|
||||
return PersonalAccessTokenPrefix + randomStr
|
||||
}
|
||||
|
||||
// HashPersonalAccessToken returns SHA-256 hash of a PAT.
|
||||
func HashPersonalAccessToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// verifyJWTKeyFunc returns a jwt.Keyfunc that validates the signing method and key ID.
|
||||
func verifyJWTKeyFunc(secret []byte) jwt.Keyfunc {
|
||||
return func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, errors.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
kid, ok := t.Header["kid"].(string)
|
||||
if !ok || kid != KeyID {
|
||||
return nil, errors.Errorf("unexpected kid: %v", t.Header["kid"])
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ParseAccessTokenV2 parses and validates a short-lived access token.
|
||||
func ParseAccessTokenV2(tokenString string, secret []byte) (*AccessTokenClaims, error) {
|
||||
claims := &AccessTokenClaims{}
|
||||
_, err := jwt.ParseWithClaims(tokenString, claims, verifyJWTKeyFunc(secret),
|
||||
jwt.WithIssuer(Issuer),
|
||||
jwt.WithAudience(AccessTokenAudienceName),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if claims.Type != "access" {
|
||||
return nil, errors.New("invalid token type: expected access token")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// ParseRefreshToken parses and validates a refresh token.
|
||||
func ParseRefreshToken(tokenString string, secret []byte) (*RefreshTokenClaims, error) {
|
||||
claims := &RefreshTokenClaims{}
|
||||
_, err := jwt.ParseWithClaims(tokenString, claims, verifyJWTKeyFunc(secret),
|
||||
jwt.WithIssuer(Issuer),
|
||||
jwt.WithAudience(RefreshTokenAudienceName),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if claims.Type != "refresh" {
|
||||
return nil, errors.New("invalid token type: expected refresh token")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
306
server/auth/token_test.go
Normal file
306
server/auth/token_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateAccessTokenV2(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("generates valid access token", func(t *testing.T) {
|
||||
token, expiresAt, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, expiresAt.After(time.Now()))
|
||||
assert.True(t, expiresAt.Before(time.Now().Add(AccessTokenDuration+time.Minute)))
|
||||
})
|
||||
|
||||
t.Run("generates different tokens for same user", func(t *testing.T) {
|
||||
token1, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(2 * time.Second) // Ensure different timestamps (tokens have 1s precision)
|
||||
|
||||
token2, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, token1, token2, "tokens should be different due to different timestamps")
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAccessTokenV2(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("parses valid access token", func(t *testing.T) {
|
||||
token, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := ParseAccessTokenV2(token, secret)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1", claims.Subject)
|
||||
assert.Equal(t, "testuser", claims.Username)
|
||||
assert.Equal(t, "USER", claims.Role)
|
||||
assert.Equal(t, "ACTIVE", claims.Status)
|
||||
assert.Equal(t, "access", claims.Type)
|
||||
})
|
||||
|
||||
t.Run("fails with wrong secret", func(t *testing.T) {
|
||||
token, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongSecret := []byte("wrong-secret")
|
||||
_, err = ParseAccessTokenV2(token, wrongSecret)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid token", func(t *testing.T) {
|
||||
_, err := ParseAccessTokenV2("invalid-token", secret)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with refresh token", func(t *testing.T) {
|
||||
// Generate a refresh token and try to parse it as access token
|
||||
// Should fail because audience mismatch is caught before type check
|
||||
refreshToken, _, err := GenerateRefreshToken(1, "token-id", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ParseAccessTokenV2(refreshToken, secret)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid audience")
|
||||
})
|
||||
|
||||
t.Run("parses token with different roles", func(t *testing.T) {
|
||||
roles := []string{"USER", "ADMIN"}
|
||||
for _, role := range roles {
|
||||
token, _, err := GenerateAccessTokenV2(1, "testuser", role, "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := ParseAccessTokenV2(token, secret)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, role, claims.Role)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateRefreshToken(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("generates valid refresh token", func(t *testing.T) {
|
||||
token, expiresAt, err := GenerateRefreshToken(1, "token-id-123", secret)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, expiresAt.After(time.Now().Add(29*24*time.Hour)))
|
||||
})
|
||||
|
||||
t.Run("generates different tokens for different token IDs", func(t *testing.T) {
|
||||
token1, _, err := GenerateRefreshToken(1, "token-id-1", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
token2, _, err := GenerateRefreshToken(1, "token-id-2", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, token1, token2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseRefreshToken(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("parses valid refresh token", func(t *testing.T) {
|
||||
token, _, err := GenerateRefreshToken(1, "token-id-123", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := ParseRefreshToken(token, secret)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1", claims.Subject)
|
||||
assert.Equal(t, "token-id-123", claims.TokenID)
|
||||
assert.Equal(t, "refresh", claims.Type)
|
||||
})
|
||||
|
||||
t.Run("fails with wrong secret", func(t *testing.T) {
|
||||
token, _, err := GenerateRefreshToken(1, "token-id-123", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongSecret := []byte("wrong-secret")
|
||||
_, err = ParseRefreshToken(token, wrongSecret)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid token", func(t *testing.T) {
|
||||
_, err := ParseRefreshToken("invalid-token", secret)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with access token", func(t *testing.T) {
|
||||
// Generate an access token and try to parse it as refresh token
|
||||
// Should fail because audience mismatch is caught before type check
|
||||
accessToken, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ParseRefreshToken(accessToken, secret)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid audience")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeneratePersonalAccessToken(t *testing.T) {
|
||||
t.Run("generates token with correct prefix", func(t *testing.T) {
|
||||
token := GeneratePersonalAccessToken()
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, len(token) > len(PersonalAccessTokenPrefix))
|
||||
assert.Equal(t, PersonalAccessTokenPrefix, token[:len(PersonalAccessTokenPrefix)])
|
||||
})
|
||||
|
||||
t.Run("generates unique tokens", func(t *testing.T) {
|
||||
token1 := GeneratePersonalAccessToken()
|
||||
token2 := GeneratePersonalAccessToken()
|
||||
assert.NotEqual(t, token1, token2)
|
||||
})
|
||||
|
||||
t.Run("generates token of sufficient length", func(t *testing.T) {
|
||||
token := GeneratePersonalAccessToken()
|
||||
// Prefix is "memos_pat_" (10 chars) + 32 random chars = at least 42 chars
|
||||
assert.True(t, len(token) >= 42, "token should be at least 42 characters")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHashPersonalAccessToken(t *testing.T) {
|
||||
t.Run("generates SHA-256 hash", func(t *testing.T) {
|
||||
token := "memos_pat_abc123"
|
||||
hash := HashPersonalAccessToken(token)
|
||||
assert.NotEmpty(t, hash)
|
||||
assert.Len(t, hash, 64, "SHA-256 hex should be 64 characters")
|
||||
})
|
||||
|
||||
t.Run("same input produces same hash", func(t *testing.T) {
|
||||
token := "memos_pat_abc123"
|
||||
hash1 := HashPersonalAccessToken(token)
|
||||
hash2 := HashPersonalAccessToken(token)
|
||||
assert.Equal(t, hash1, hash2)
|
||||
})
|
||||
|
||||
t.Run("different inputs produce different hashes", func(t *testing.T) {
|
||||
token1 := "memos_pat_abc123"
|
||||
token2 := "memos_pat_xyz789"
|
||||
hash1 := HashPersonalAccessToken(token1)
|
||||
hash2 := HashPersonalAccessToken(token2)
|
||||
assert.NotEqual(t, hash1, hash2)
|
||||
})
|
||||
|
||||
t.Run("hash is deterministic", func(t *testing.T) {
|
||||
token := GeneratePersonalAccessToken()
|
||||
hash1 := HashPersonalAccessToken(token)
|
||||
hash2 := HashPersonalAccessToken(token)
|
||||
assert.Equal(t, hash1, hash2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccessTokenV2Integration(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("full lifecycle: generate, parse, validate", func(t *testing.T) {
|
||||
userID := int32(42)
|
||||
username := "john_doe"
|
||||
role := "ADMIN"
|
||||
status := "ACTIVE"
|
||||
|
||||
// Generate token
|
||||
token, expiresAt, err := GenerateAccessTokenV2(userID, username, role, status, secret)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Parse token
|
||||
claims, err := ParseAccessTokenV2(token, secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate claims
|
||||
assert.Equal(t, "42", claims.Subject)
|
||||
assert.Equal(t, username, claims.Username)
|
||||
assert.Equal(t, role, claims.Role)
|
||||
assert.Equal(t, status, claims.Status)
|
||||
assert.Equal(t, "access", claims.Type)
|
||||
assert.Equal(t, Issuer, claims.Issuer)
|
||||
assert.NotNil(t, claims.IssuedAt)
|
||||
assert.NotNil(t, claims.ExpiresAt)
|
||||
|
||||
// Validate expiration
|
||||
assert.True(t, claims.ExpiresAt.Equal(expiresAt) || claims.ExpiresAt.Before(expiresAt))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenIntegration(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("full lifecycle: generate, parse, validate", func(t *testing.T) {
|
||||
userID := int32(42)
|
||||
tokenID := "unique-token-id-456"
|
||||
|
||||
// Generate token
|
||||
token, expiresAt, err := GenerateRefreshToken(userID, tokenID, secret)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Parse token
|
||||
claims, err := ParseRefreshToken(token, secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate claims
|
||||
assert.Equal(t, "42", claims.Subject)
|
||||
assert.Equal(t, tokenID, claims.TokenID)
|
||||
assert.Equal(t, "refresh", claims.Type)
|
||||
assert.Equal(t, Issuer, claims.Issuer)
|
||||
assert.NotNil(t, claims.IssuedAt)
|
||||
assert.NotNil(t, claims.ExpiresAt)
|
||||
|
||||
// Validate expiration
|
||||
assert.True(t, claims.ExpiresAt.Equal(expiresAt) || claims.ExpiresAt.Before(expiresAt))
|
||||
})
|
||||
}
|
||||
|
||||
func TestPersonalAccessTokenIntegration(t *testing.T) {
|
||||
t.Run("full lifecycle: generate, hash, verify", func(t *testing.T) {
|
||||
// Generate token
|
||||
token := GeneratePersonalAccessToken()
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, len(token) > len(PersonalAccessTokenPrefix))
|
||||
|
||||
// Hash token
|
||||
hash := HashPersonalAccessToken(token)
|
||||
assert.Len(t, hash, 64)
|
||||
|
||||
// Verify same token produces same hash
|
||||
hashAgain := HashPersonalAccessToken(token)
|
||||
assert.Equal(t, hash, hashAgain)
|
||||
|
||||
// Verify different token produces different hash
|
||||
token2 := GeneratePersonalAccessToken()
|
||||
hash2 := HashPersonalAccessToken(token2)
|
||||
assert.NotEqual(t, hash, hash2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenExpiration(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("access token expires after AccessTokenDuration", func(t *testing.T) {
|
||||
_, expiresAt, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedExpiry := time.Now().Add(AccessTokenDuration)
|
||||
delta := expiresAt.Sub(expectedExpiry)
|
||||
assert.True(t, delta < time.Second, "expiration should be within 1 second of expected")
|
||||
})
|
||||
|
||||
t.Run("refresh token expires after RefreshTokenDuration", func(t *testing.T) {
|
||||
_, expiresAt, err := GenerateRefreshToken(1, "token-id", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedExpiry := time.Now().Add(RefreshTokenDuration)
|
||||
delta := expiresAt.Sub(expectedExpiry)
|
||||
assert.True(t, delta < time.Second, "expiration should be within 1 second of expected")
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user