first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled

This commit is contained in:
2026-03-04 06:30:47 +00:00
commit bb402d4ccc
777 changed files with 135661 additions and 0 deletions

View File

@@ -0,0 +1,199 @@
package auth
import (
"context"
"log/slog"
"strings"
"time"
"github.com/pkg/errors"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// Authenticator provides shared authentication and authorization logic.
// Used by gRPC interceptor, Connect interceptor, and file server to ensure
// consistent authentication behavior across all API endpoints.
//
// Authentication methods:
// - JWT access tokens: Short-lived tokens (15 minutes) for API access
// - Personal Access Tokens (PAT): Long-lived tokens for programmatic access
//
// This struct is safe for concurrent use.
type Authenticator struct {
store *store.Store
secret string
}
// NewAuthenticator creates a new Authenticator instance.
func NewAuthenticator(store *store.Store, secret string) *Authenticator {
return &Authenticator{
store: store,
secret: secret,
}
}
// AuthenticateByAccessTokenV2 validates a short-lived access token.
// Returns claims without database query (stateless validation).
func (a *Authenticator) AuthenticateByAccessTokenV2(accessToken string) (*UserClaims, error) {
claims, err := ParseAccessTokenV2(accessToken, []byte(a.secret))
if err != nil {
return nil, errors.Wrap(err, "invalid access token")
}
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return nil, errors.Wrap(err, "invalid user ID in token")
}
return &UserClaims{
UserID: userID,
Username: claims.Username,
Role: claims.Role,
Status: claims.Status,
}, nil
}
// AuthenticateByRefreshToken validates a refresh token against the database.
func (a *Authenticator) AuthenticateByRefreshToken(ctx context.Context, refreshToken string) (*store.User, string, error) {
claims, err := ParseRefreshToken(refreshToken, []byte(a.secret))
if err != nil {
return nil, "", errors.Wrap(err, "invalid refresh token")
}
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return nil, "", errors.Wrap(err, "invalid user ID in token")
}
// Check token exists in database (revocation check)
token, err := a.store.GetUserRefreshTokenByID(ctx, userID, claims.TokenID)
if err != nil {
return nil, "", errors.Wrap(err, "failed to get refresh token")
}
if token == nil {
return nil, "", errors.New("refresh token revoked")
}
// Check token not expired
if token.ExpiresAt != nil && token.ExpiresAt.AsTime().Before(time.Now()) {
return nil, "", errors.New("refresh token expired")
}
// Get user
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &userID})
if err != nil {
return nil, "", errors.Wrap(err, "failed to get user")
}
if user == nil {
return nil, "", errors.New("user not found")
}
if user.RowStatus == store.Archived {
return nil, "", errors.New("user is archived")
}
return user, claims.TokenID, nil
}
// AuthenticateByPAT validates a Personal Access Token.
func (a *Authenticator) AuthenticateByPAT(ctx context.Context, token string) (*store.User, *storepb.PersonalAccessTokensUserSetting_PersonalAccessToken, error) {
if !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
return nil, nil, errors.New("invalid PAT format")
}
tokenHash := HashPersonalAccessToken(token)
result, err := a.store.GetUserByPATHash(ctx, tokenHash)
if err != nil {
return nil, nil, errors.Wrap(err, "invalid PAT")
}
// Check expiry
if result.PAT.ExpiresAt != nil && result.PAT.ExpiresAt.AsTime().Before(time.Now()) {
return nil, nil, errors.New("PAT expired")
}
// Check user status
if result.User.RowStatus == store.Archived {
return nil, nil, errors.New("user is archived")
}
return result.User, result.PAT, nil
}
// AuthResult contains the result of an authentication attempt.
type AuthResult struct {
User *store.User // Set for PAT authentication
Claims *UserClaims // Set for Access Token V2 (stateless)
AccessToken string // Non-empty if authenticated via JWT
}
// AuthenticateToUser resolves the current request to a *store.User, checking the
// Authorization header first (access token or PAT), then falling back to the
// refresh token cookie. Returns (nil, nil) when no credentials are present.
func (a *Authenticator) AuthenticateToUser(ctx context.Context, authHeader, cookieHeader string) (*store.User, error) {
// Try Bearer token first.
if authHeader != "" {
token := ExtractBearerToken(authHeader)
if token != "" {
if !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
claims, err := a.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
return a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
}
} else {
user, _, err := a.AuthenticateByPAT(ctx, token)
if err == nil {
return user, nil
}
}
}
}
// Fallback: refresh token cookie.
if cookieHeader != "" {
refreshToken := ExtractRefreshTokenFromCookie(cookieHeader)
if refreshToken != "" {
user, _, err := a.AuthenticateByRefreshToken(ctx, refreshToken)
return user, err
}
}
return nil, nil
}
// Authenticate tries to authenticate using the provided credentials.
// Priority: 1. Access Token V2, 2. PAT
// Returns nil if no valid credentials are provided.
func (a *Authenticator) Authenticate(ctx context.Context, authHeader string) *AuthResult {
token := ExtractBearerToken(authHeader)
// Try Access Token V2 (stateless)
if token != "" && !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
claims, err := a.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
return &AuthResult{
Claims: claims,
AccessToken: token,
}
}
}
// Try PAT
if token != "" && strings.HasPrefix(token, PersonalAccessTokenPrefix) {
user, pat, err := a.AuthenticateByPAT(ctx, token)
if err == nil && user != nil {
// Update last used (fire-and-forget with logging)
go func() {
if err := a.store.UpdatePATLastUsed(context.Background(), user.ID, pat.TokenId, timestamppb.Now()); err != nil {
slog.Warn("failed to update PAT last used time", "error", err, "userID", user.ID)
}
}()
return &AuthResult{User: user, AccessToken: token}
}
}
return nil
}

99
server/auth/context.go Normal file
View File

@@ -0,0 +1,99 @@
package auth
import (
"context"
"github.com/usememos/memos/store"
)
// ContextKey is the key type for context values.
// Using a custom type prevents collisions with other packages.
type ContextKey int
const (
// UserIDContextKey stores the authenticated user's ID.
// Set for all authenticated requests.
// Use GetUserID(ctx) to retrieve this value.
UserIDContextKey ContextKey = iota
// AccessTokenContextKey stores the JWT token for token-based auth.
// Only set when authenticated via Bearer token.
AccessTokenContextKey
// UserClaimsContextKey stores the claims from access token.
UserClaimsContextKey
// RefreshTokenIDContextKey stores the refresh token ID.
RefreshTokenIDContextKey
)
// GetUserID retrieves the authenticated user's ID from the context.
// Returns 0 if no user ID is set (unauthenticated request).
func GetUserID(ctx context.Context) int32 {
if v, ok := ctx.Value(UserIDContextKey).(int32); ok {
return v
}
return 0
}
// GetAccessToken retrieves the JWT access token from the context.
// Returns empty string if not authenticated via bearer token.
func GetAccessToken(ctx context.Context) string {
if v, ok := ctx.Value(AccessTokenContextKey).(string); ok {
return v
}
return ""
}
// SetUserInContext sets the authenticated user's information in the context.
// This is a simpler alternative to AuthorizeAndSetContext for cases where
// authorization is handled separately (e.g., HTTP middleware).
//
// Parameters:
// - user: The authenticated user
// - accessToken: Set if authenticated via JWT token (empty string otherwise)
func SetUserInContext(ctx context.Context, user *store.User, accessToken string) context.Context {
ctx = context.WithValue(ctx, UserIDContextKey, user.ID)
if accessToken != "" {
ctx = context.WithValue(ctx, AccessTokenContextKey, accessToken)
}
return ctx
}
// UserClaims represents authenticated user info from access token.
type UserClaims struct {
UserID int32
Username string
Role string
Status string
}
// GetUserClaims retrieves the user claims from context.
// Returns nil if not authenticated via access token.
func GetUserClaims(ctx context.Context) *UserClaims {
if v, ok := ctx.Value(UserClaimsContextKey).(*UserClaims); ok {
return v
}
return nil
}
// SetUserClaimsInContext sets the user claims in context.
func SetUserClaimsInContext(ctx context.Context, claims *UserClaims) context.Context {
return context.WithValue(ctx, UserClaimsContextKey, claims)
}
// ApplyToContext sets the authenticated identity from an AuthResult into the context.
// This is the canonical way to propagate auth state after a successful Authenticate call.
// Safe to call with a nil result (no-op).
func ApplyToContext(ctx context.Context, result *AuthResult) context.Context {
if result == nil {
return ctx
}
if result.Claims != nil {
ctx = SetUserClaimsInContext(ctx, result.Claims)
ctx = context.WithValue(ctx, UserIDContextKey, result.Claims.UserID)
} else if result.User != nil {
ctx = SetUserInContext(ctx, result.User, result.AccessToken)
}
return ctx
}

33
server/auth/extract.go Normal file
View File

@@ -0,0 +1,33 @@
package auth
import (
"net/http"
"strings"
)
// ExtractBearerToken extracts the JWT token from an Authorization header value.
// Expected format: "Bearer {token}"
// Returns empty string if no valid bearer token is found.
func ExtractBearerToken(authHeader string) string {
if authHeader == "" {
return ""
}
parts := strings.Fields(authHeader)
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
return ""
}
return parts[1]
}
// ExtractRefreshTokenFromCookie extracts the refresh token from cookie header.
func ExtractRefreshTokenFromCookie(cookieHeader string) string {
if cookieHeader == "" {
return ""
}
req := &http.Request{Header: http.Header{"Cookie": []string{cookieHeader}}}
cookie, err := req.Cookie(RefreshTokenCookieName)
if err != nil {
return ""
}
return cookie.Value
}

249
server/auth/token.go Normal file
View File

@@ -0,0 +1,249 @@
// Package auth provides authentication and authorization for the Memos server.
//
// This package is used by:
// - server/router/api/v1: gRPC and Connect API interceptors
// - server/router/fileserver: HTTP file server authentication
//
// Authentication methods supported:
// - JWT access tokens: Short-lived tokens (15 minutes) for API access
// - JWT refresh tokens: Long-lived tokens (30 days) for obtaining new access tokens
// - Personal Access Tokens (PAT): Long-lived tokens for programmatic access
package auth
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
const (
// Issuer is the issuer claim in JWT tokens.
// This identifies tokens as issued by Memos.
Issuer = "memos"
// KeyID is the key identifier used in JWT header.
// Version "v1" allows for future key rotation while maintaining backward compatibility.
// If signing mechanism changes, add "v2", "v3", etc. and verify both versions.
KeyID = "v1"
// AccessTokenAudienceName is the audience claim for JWT access tokens.
// This ensures tokens are only used for API access, not other purposes.
AccessTokenAudienceName = "user.access-token"
// AccessTokenDuration is the lifetime of access tokens (15 minutes).
AccessTokenDuration = 15 * time.Minute
// RefreshTokenDuration is the lifetime of refresh tokens (30 days).
RefreshTokenDuration = 30 * 24 * time.Hour
// RefreshTokenAudienceName is the audience claim for refresh tokens.
RefreshTokenAudienceName = "user.refresh-token"
// RefreshTokenCookieName is the cookie name for refresh tokens.
RefreshTokenCookieName = "memos_refresh"
// PersonalAccessTokenPrefix is the prefix for PAT tokens.
PersonalAccessTokenPrefix = "memos_pat_"
)
// ClaimsMessage represents the claims structure in a JWT token.
//
// JWT Claims include:
// - name: Username (custom claim)
// - iss: Issuer = "memos"
// - aud: Audience = "user.access-token"
// - sub: Subject = user ID
// - iat: Issued at time
// - exp: Expiration time (optional, may be empty for never-expiring tokens).
type ClaimsMessage struct {
Name string `json:"name"` // Username
jwt.RegisteredClaims
}
// AccessTokenClaims contains claims for short-lived access tokens.
// These tokens are validated by signature only (stateless).
type AccessTokenClaims struct {
Type string `json:"type"` // "access"
Role string `json:"role"` // User role
Status string `json:"status"` // User status
Username string `json:"username"` // Username for display
jwt.RegisteredClaims
}
// RefreshTokenClaims contains claims for long-lived refresh tokens.
// These tokens are validated against the database for revocation.
type RefreshTokenClaims struct {
Type string `json:"type"` // "refresh"
TokenID string `json:"tid"` // Token ID for revocation lookup
jwt.RegisteredClaims
}
// GenerateAccessToken generates a JWT access token for a user.
//
// Parameters:
// - username: The user's username (stored in "name" claim)
// - userID: The user's ID (stored in "sub" claim)
// - expirationTime: When the token expires (pass zero time for no expiration)
// - secret: Server secret used to sign the token
//
// Returns a signed JWT string or an error.
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret []byte) (string, error) {
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, secret)
}
// generateToken generates a JWT token with the given claims.
//
// Token structure:
// Header: {"alg": "HS256", "kid": "v1", "typ": "JWT"}
// Claims: {"name": username, "iss": "memos", "aud": [audience], "sub": userID, "iat": now, "exp": expiry}
// Signature: HMACSHA256(base64UrlEncode(header) + "." + base64UrlEncode(payload), secret).
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
registeredClaims := jwt.RegisteredClaims{
Issuer: Issuer,
Audience: jwt.ClaimStrings{audience},
IssuedAt: jwt.NewNumericDate(time.Now()),
Subject: fmt.Sprint(userID),
}
if !expirationTime.IsZero() {
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
}
// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
Name: username,
RegisteredClaims: registeredClaims,
})
token.Header["kid"] = KeyID
// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}
// GenerateAccessTokenV2 generates a short-lived access token with user claims.
func GenerateAccessTokenV2(userID int32, username, role, status string, secret []byte) (string, time.Time, error) {
expiresAt := time.Now().Add(AccessTokenDuration)
claims := &AccessTokenClaims{
Type: "access",
Role: role,
Status: status,
Username: username,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: Issuer,
Audience: jwt.ClaimStrings{AccessTokenAudienceName},
Subject: fmt.Sprint(userID),
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["kid"] = KeyID
tokenString, err := token.SignedString(secret)
if err != nil {
return "", time.Time{}, err
}
return tokenString, expiresAt, nil
}
// GenerateRefreshToken generates a long-lived refresh token.
func GenerateRefreshToken(userID int32, tokenID string, secret []byte) (string, time.Time, error) {
expiresAt := time.Now().Add(RefreshTokenDuration)
claims := &RefreshTokenClaims{
Type: "refresh",
TokenID: tokenID,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: Issuer,
Audience: jwt.ClaimStrings{RefreshTokenAudienceName},
Subject: fmt.Sprint(userID),
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["kid"] = KeyID
tokenString, err := token.SignedString(secret)
if err != nil {
return "", time.Time{}, err
}
return tokenString, expiresAt, nil
}
// GeneratePersonalAccessToken generates a random PAT string.
func GeneratePersonalAccessToken() string {
randomStr, err := util.RandomString(32)
if err != nil {
// Fallback to UUID if RandomString fails
return PersonalAccessTokenPrefix + util.GenUUID()
}
return PersonalAccessTokenPrefix + randomStr
}
// HashPersonalAccessToken returns SHA-256 hash of a PAT.
func HashPersonalAccessToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
// verifyJWTKeyFunc returns a jwt.Keyfunc that validates the signing method and key ID.
func verifyJWTKeyFunc(secret []byte) jwt.Keyfunc {
return func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected signing method: %v", t.Header["alg"])
}
kid, ok := t.Header["kid"].(string)
if !ok || kid != KeyID {
return nil, errors.Errorf("unexpected kid: %v", t.Header["kid"])
}
return secret, nil
}
}
// ParseAccessTokenV2 parses and validates a short-lived access token.
func ParseAccessTokenV2(tokenString string, secret []byte) (*AccessTokenClaims, error) {
claims := &AccessTokenClaims{}
_, err := jwt.ParseWithClaims(tokenString, claims, verifyJWTKeyFunc(secret),
jwt.WithIssuer(Issuer),
jwt.WithAudience(AccessTokenAudienceName),
)
if err != nil {
return nil, err
}
if claims.Type != "access" {
return nil, errors.New("invalid token type: expected access token")
}
return claims, nil
}
// ParseRefreshToken parses and validates a refresh token.
func ParseRefreshToken(tokenString string, secret []byte) (*RefreshTokenClaims, error) {
claims := &RefreshTokenClaims{}
_, err := jwt.ParseWithClaims(tokenString, claims, verifyJWTKeyFunc(secret),
jwt.WithIssuer(Issuer),
jwt.WithAudience(RefreshTokenAudienceName),
)
if err != nil {
return nil, err
}
if claims.Type != "refresh" {
return nil, errors.New("invalid token type: expected refresh token")
}
return claims, nil
}

306
server/auth/token_test.go Normal file
View File

@@ -0,0 +1,306 @@
package auth
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGenerateAccessTokenV2(t *testing.T) {
secret := []byte("test-secret")
t.Run("generates valid access token", func(t *testing.T) {
token, expiresAt, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
require.NoError(t, err)
assert.NotEmpty(t, token)
assert.True(t, expiresAt.After(time.Now()))
assert.True(t, expiresAt.Before(time.Now().Add(AccessTokenDuration+time.Minute)))
})
t.Run("generates different tokens for same user", func(t *testing.T) {
token1, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
require.NoError(t, err)
time.Sleep(2 * time.Second) // Ensure different timestamps (tokens have 1s precision)
token2, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
require.NoError(t, err)
assert.NotEqual(t, token1, token2, "tokens should be different due to different timestamps")
})
}
func TestParseAccessTokenV2(t *testing.T) {
secret := []byte("test-secret")
t.Run("parses valid access token", func(t *testing.T) {
token, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
require.NoError(t, err)
claims, err := ParseAccessTokenV2(token, secret)
require.NoError(t, err)
assert.Equal(t, "1", claims.Subject)
assert.Equal(t, "testuser", claims.Username)
assert.Equal(t, "USER", claims.Role)
assert.Equal(t, "ACTIVE", claims.Status)
assert.Equal(t, "access", claims.Type)
})
t.Run("fails with wrong secret", func(t *testing.T) {
token, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
require.NoError(t, err)
wrongSecret := []byte("wrong-secret")
_, err = ParseAccessTokenV2(token, wrongSecret)
assert.Error(t, err)
})
t.Run("fails with invalid token", func(t *testing.T) {
_, err := ParseAccessTokenV2("invalid-token", secret)
assert.Error(t, err)
})
t.Run("fails with refresh token", func(t *testing.T) {
// Generate a refresh token and try to parse it as access token
// Should fail because audience mismatch is caught before type check
refreshToken, _, err := GenerateRefreshToken(1, "token-id", secret)
require.NoError(t, err)
_, err = ParseAccessTokenV2(refreshToken, secret)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid audience")
})
t.Run("parses token with different roles", func(t *testing.T) {
roles := []string{"USER", "ADMIN"}
for _, role := range roles {
token, _, err := GenerateAccessTokenV2(1, "testuser", role, "ACTIVE", secret)
require.NoError(t, err)
claims, err := ParseAccessTokenV2(token, secret)
require.NoError(t, err)
assert.Equal(t, role, claims.Role)
}
})
}
func TestGenerateRefreshToken(t *testing.T) {
secret := []byte("test-secret")
t.Run("generates valid refresh token", func(t *testing.T) {
token, expiresAt, err := GenerateRefreshToken(1, "token-id-123", secret)
require.NoError(t, err)
assert.NotEmpty(t, token)
assert.True(t, expiresAt.After(time.Now().Add(29*24*time.Hour)))
})
t.Run("generates different tokens for different token IDs", func(t *testing.T) {
token1, _, err := GenerateRefreshToken(1, "token-id-1", secret)
require.NoError(t, err)
token2, _, err := GenerateRefreshToken(1, "token-id-2", secret)
require.NoError(t, err)
assert.NotEqual(t, token1, token2)
})
}
func TestParseRefreshToken(t *testing.T) {
secret := []byte("test-secret")
t.Run("parses valid refresh token", func(t *testing.T) {
token, _, err := GenerateRefreshToken(1, "token-id-123", secret)
require.NoError(t, err)
claims, err := ParseRefreshToken(token, secret)
require.NoError(t, err)
assert.Equal(t, "1", claims.Subject)
assert.Equal(t, "token-id-123", claims.TokenID)
assert.Equal(t, "refresh", claims.Type)
})
t.Run("fails with wrong secret", func(t *testing.T) {
token, _, err := GenerateRefreshToken(1, "token-id-123", secret)
require.NoError(t, err)
wrongSecret := []byte("wrong-secret")
_, err = ParseRefreshToken(token, wrongSecret)
assert.Error(t, err)
})
t.Run("fails with invalid token", func(t *testing.T) {
_, err := ParseRefreshToken("invalid-token", secret)
assert.Error(t, err)
})
t.Run("fails with access token", func(t *testing.T) {
// Generate an access token and try to parse it as refresh token
// Should fail because audience mismatch is caught before type check
accessToken, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
require.NoError(t, err)
_, err = ParseRefreshToken(accessToken, secret)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid audience")
})
}
func TestGeneratePersonalAccessToken(t *testing.T) {
t.Run("generates token with correct prefix", func(t *testing.T) {
token := GeneratePersonalAccessToken()
assert.NotEmpty(t, token)
assert.True(t, len(token) > len(PersonalAccessTokenPrefix))
assert.Equal(t, PersonalAccessTokenPrefix, token[:len(PersonalAccessTokenPrefix)])
})
t.Run("generates unique tokens", func(t *testing.T) {
token1 := GeneratePersonalAccessToken()
token2 := GeneratePersonalAccessToken()
assert.NotEqual(t, token1, token2)
})
t.Run("generates token of sufficient length", func(t *testing.T) {
token := GeneratePersonalAccessToken()
// Prefix is "memos_pat_" (10 chars) + 32 random chars = at least 42 chars
assert.True(t, len(token) >= 42, "token should be at least 42 characters")
})
}
func TestHashPersonalAccessToken(t *testing.T) {
t.Run("generates SHA-256 hash", func(t *testing.T) {
token := "memos_pat_abc123"
hash := HashPersonalAccessToken(token)
assert.NotEmpty(t, hash)
assert.Len(t, hash, 64, "SHA-256 hex should be 64 characters")
})
t.Run("same input produces same hash", func(t *testing.T) {
token := "memos_pat_abc123"
hash1 := HashPersonalAccessToken(token)
hash2 := HashPersonalAccessToken(token)
assert.Equal(t, hash1, hash2)
})
t.Run("different inputs produce different hashes", func(t *testing.T) {
token1 := "memos_pat_abc123"
token2 := "memos_pat_xyz789"
hash1 := HashPersonalAccessToken(token1)
hash2 := HashPersonalAccessToken(token2)
assert.NotEqual(t, hash1, hash2)
})
t.Run("hash is deterministic", func(t *testing.T) {
token := GeneratePersonalAccessToken()
hash1 := HashPersonalAccessToken(token)
hash2 := HashPersonalAccessToken(token)
assert.Equal(t, hash1, hash2)
})
}
func TestAccessTokenV2Integration(t *testing.T) {
secret := []byte("test-secret")
t.Run("full lifecycle: generate, parse, validate", func(t *testing.T) {
userID := int32(42)
username := "john_doe"
role := "ADMIN"
status := "ACTIVE"
// Generate token
token, expiresAt, err := GenerateAccessTokenV2(userID, username, role, status, secret)
require.NoError(t, err)
assert.NotEmpty(t, token)
// Parse token
claims, err := ParseAccessTokenV2(token, secret)
require.NoError(t, err)
// Validate claims
assert.Equal(t, "42", claims.Subject)
assert.Equal(t, username, claims.Username)
assert.Equal(t, role, claims.Role)
assert.Equal(t, status, claims.Status)
assert.Equal(t, "access", claims.Type)
assert.Equal(t, Issuer, claims.Issuer)
assert.NotNil(t, claims.IssuedAt)
assert.NotNil(t, claims.ExpiresAt)
// Validate expiration
assert.True(t, claims.ExpiresAt.Equal(expiresAt) || claims.ExpiresAt.Before(expiresAt))
})
}
func TestRefreshTokenIntegration(t *testing.T) {
secret := []byte("test-secret")
t.Run("full lifecycle: generate, parse, validate", func(t *testing.T) {
userID := int32(42)
tokenID := "unique-token-id-456"
// Generate token
token, expiresAt, err := GenerateRefreshToken(userID, tokenID, secret)
require.NoError(t, err)
assert.NotEmpty(t, token)
// Parse token
claims, err := ParseRefreshToken(token, secret)
require.NoError(t, err)
// Validate claims
assert.Equal(t, "42", claims.Subject)
assert.Equal(t, tokenID, claims.TokenID)
assert.Equal(t, "refresh", claims.Type)
assert.Equal(t, Issuer, claims.Issuer)
assert.NotNil(t, claims.IssuedAt)
assert.NotNil(t, claims.ExpiresAt)
// Validate expiration
assert.True(t, claims.ExpiresAt.Equal(expiresAt) || claims.ExpiresAt.Before(expiresAt))
})
}
func TestPersonalAccessTokenIntegration(t *testing.T) {
t.Run("full lifecycle: generate, hash, verify", func(t *testing.T) {
// Generate token
token := GeneratePersonalAccessToken()
assert.NotEmpty(t, token)
assert.True(t, len(token) > len(PersonalAccessTokenPrefix))
// Hash token
hash := HashPersonalAccessToken(token)
assert.Len(t, hash, 64)
// Verify same token produces same hash
hashAgain := HashPersonalAccessToken(token)
assert.Equal(t, hash, hashAgain)
// Verify different token produces different hash
token2 := GeneratePersonalAccessToken()
hash2 := HashPersonalAccessToken(token2)
assert.NotEqual(t, hash, hash2)
})
}
func TestTokenExpiration(t *testing.T) {
secret := []byte("test-secret")
t.Run("access token expires after AccessTokenDuration", func(t *testing.T) {
_, expiresAt, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
require.NoError(t, err)
expectedExpiry := time.Now().Add(AccessTokenDuration)
delta := expiresAt.Sub(expectedExpiry)
assert.True(t, delta < time.Second, "expiration should be within 1 second of expected")
})
t.Run("refresh token expires after RefreshTokenDuration", func(t *testing.T) {
_, expiresAt, err := GenerateRefreshToken(1, "token-id", secret)
require.NoError(t, err)
expectedExpiry := time.Now().Add(RefreshTokenDuration)
delta := expiresAt.Sub(expectedExpiry)
assert.True(t, delta < time.Second, "expiration should be within 1 second of expected")
})
}