first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled

This commit is contained in:
2026-03-04 06:30:47 +00:00
commit bb402d4ccc
777 changed files with 135661 additions and 0 deletions

13
store/test/README.md Normal file
View File

@@ -0,0 +1,13 @@
# Store tests
## How to test store with MySQL?
1. Create a database in your MySQL server.
2. Run the following command with two environment variables set:
```go
DRIVER=mysql DSN=root@/memos_test go test -v ./test/store/...
```
- `DRIVER` should be set to `mysql`.
- `DSN` should be set to the DSN of your MySQL server.

380
store/test/activity_test.go Normal file
View File

@@ -0,0 +1,380 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestActivityStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
create := &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
}
activity, err := ts.CreateActivity(ctx, create)
require.NoError(t, err)
require.NotNil(t, activity)
activities, err := ts.ListActivities(ctx, &store.FindActivity{
ID: &activity.ID,
})
require.NoError(t, err)
require.Equal(t, 1, len(activities))
require.Equal(t, activity, activities[0])
ts.Close()
}
func TestActivityGetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
// Get activity by ID
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, activity.ID, found.ID)
// Get non-existent activity
nonExistentID := int32(99999)
notFound, err := ts.GetActivity(ctx, &store.FindActivity{ID: &nonExistentID})
require.NoError(t, err)
require.Nil(t, notFound)
ts.Close()
}
func TestActivityListMultiple(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create multiple activities
_, err = ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
_, err = ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
// List all activities
allActivities, err := ts.ListActivities(ctx, &store.FindActivity{})
require.NoError(t, err)
require.Equal(t, 2, len(allActivities))
// List by type
commentType := store.ActivityTypeMemoComment
commentActivities, err := ts.ListActivities(ctx, &store.FindActivity{Type: &commentType})
require.NoError(t, err)
require.Equal(t, 2, len(commentActivities))
require.Equal(t, store.ActivityTypeMemoComment, commentActivities[0].Type)
ts.Close()
}
func TestActivityListByType(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create activities with MEMO_COMMENT type
_, err = ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
_, err = ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
// List by type
activityType := store.ActivityTypeMemoComment
activities, err := ts.ListActivities(ctx, &store.FindActivity{Type: &activityType})
require.NoError(t, err)
require.Len(t, activities, 2)
for _, activity := range activities {
require.Equal(t, store.ActivityTypeMemoComment, activity.Type)
}
ts.Close()
}
func TestActivityPayloadMemoComment(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create activity with MemoComment payload
memoID := int32(123)
relatedMemoID := int32(456)
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memoID,
RelatedMemoId: relatedMemoID,
},
},
})
require.NoError(t, err)
require.NotNil(t, activity.Payload)
require.NotNil(t, activity.Payload.MemoComment)
require.Equal(t, memoID, activity.Payload.MemoComment.MemoId)
require.Equal(t, relatedMemoID, activity.Payload.MemoComment.RelatedMemoId)
// Verify payload is preserved when listing
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.NotNil(t, found.Payload.MemoComment)
require.Equal(t, memoID, found.Payload.MemoComment.MemoId)
require.Equal(t, relatedMemoID, found.Payload.MemoComment.RelatedMemoId)
ts.Close()
}
func TestActivityEmptyPayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create activity with empty payload
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.NotNil(t, activity.Payload)
// Verify empty payload is handled correctly
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.NotNil(t, found.Payload)
require.Nil(t, found.Payload.MemoComment)
ts.Close()
}
func TestActivityLevel(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create activity with INFO level
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.Equal(t, store.ActivityLevelInfo, activity.Level)
// Verify level is preserved when listing
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.Equal(t, store.ActivityLevelInfo, found.Level)
ts.Close()
}
func TestActivityCreatorID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create activity for user1
activity1, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user1.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.Equal(t, user1.ID, activity1.CreatorID)
// Create activity for user2
activity2, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user2.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.Equal(t, user2.ID, activity2.CreatorID)
// List all and verify creator IDs
activities, err := ts.ListActivities(ctx, &store.FindActivity{})
require.NoError(t, err)
require.Len(t, activities, 2)
ts.Close()
}
func TestActivityCreatedTs(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.NotZero(t, activity.CreatedTs)
// Verify timestamp is preserved when listing
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.Equal(t, activity.CreatedTs, found.CreatedTs)
ts.Close()
}
func TestActivityListEmpty(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// List activities when none exist
activities, err := ts.ListActivities(ctx, &store.FindActivity{})
require.NoError(t, err)
require.Len(t, activities, 0)
ts.Close()
}
func TestActivityListWithIDAndType(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
// List with both ID and Type filters
activityType := store.ActivityTypeMemoComment
activities, err := ts.ListActivities(ctx, &store.FindActivity{
ID: &activity.ID,
Type: &activityType,
})
require.NoError(t, err)
require.Len(t, activities, 1)
require.Equal(t, activity.ID, activities[0].ID)
ts.Close()
}
func TestActivityPayloadComplexMemoComment(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a memo first to use its ID
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "test-memo-for-activity",
CreatorID: user.ID,
Content: "Test memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create comment memo
commentMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "comment-memo",
CreatorID: user.ID,
Content: "This is a comment",
Visibility: store.Public,
})
require.NoError(t, err)
// Create activity with real memo IDs
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memo.ID,
RelatedMemoId: commentMemo.ID,
},
},
})
require.NoError(t, err)
require.Equal(t, memo.ID, activity.Payload.MemoComment.MemoId)
require.Equal(t, commentMemo.ID, activity.Payload.MemoComment.RelatedMemoId)
// Verify payload is preserved
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.Equal(t, memo.ID, found.Payload.MemoComment.MemoId)
require.Equal(t, commentMemo.ID, found.Payload.MemoComment.RelatedMemoId)
ts.Close()
}

View File

@@ -0,0 +1,375 @@
package test
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// =============================================================================
// Filename Field Tests
// Schema: filename (string, supports contains)
// =============================================================================
func TestAttachmentFilterFilenameContains(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.pdf").MimeType("application/pdf"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
// Test: filename.contains("report") - single match
attachments := tc.ListWithFilter(`filename.contains("report")`)
require.Len(t, attachments, 1)
require.Contains(t, attachments[0].Filename, "report")
// Test: filename.contains(".pdf") - multiple matches
attachments = tc.ListWithFilter(`filename.contains(".pdf")`)
require.Len(t, attachments, 2)
// Test: filename.contains("nonexistent") - no matches
attachments = tc.ListWithFilter(`filename.contains("nonexistent")`)
require.Len(t, attachments, 0)
}
func TestAttachmentFilterFilenameSpecialCharacters(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).
Filename("file_with-special.chars@2024.pdf").MimeType("application/pdf"))
// Test: filename.contains with underscore
attachments := tc.ListWithFilter(`filename.contains("_with")`)
require.Len(t, attachments, 1)
// Test: filename.contains with @
attachments = tc.ListWithFilter(`filename.contains("@2024")`)
require.Len(t, attachments, 1)
}
func TestAttachmentFilterFilenameUnicode(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).
Filename("document_报告.pdf").MimeType("application/pdf"))
attachments := tc.ListWithFilter(`filename.contains("报告")`)
require.Len(t, attachments, 1)
}
// =============================================================================
// Mime Type Field Tests
// Schema: mime_type (string, ==, !=)
// =============================================================================
func TestAttachmentFilterMimeTypeEquals(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("photo.jpeg").MimeType("image/jpeg"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
// Test: mime_type == "image/png"
attachments := tc.ListWithFilter(`mime_type == "image/png"`)
require.Len(t, attachments, 1)
require.Equal(t, "image/png", attachments[0].Type)
// Test: mime_type == "application/pdf"
attachments = tc.ListWithFilter(`mime_type == "application/pdf"`)
require.Len(t, attachments, 1)
require.Equal(t, "application/pdf", attachments[0].Type)
}
func TestAttachmentFilterMimeTypeNotEquals(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
attachments := tc.ListWithFilter(`mime_type != "image/png"`)
require.Len(t, attachments, 1)
require.Equal(t, "application/pdf", attachments[0].Type)
}
func TestAttachmentFilterMimeTypeInList(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("photo.jpeg").MimeType("image/jpeg"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
// Test: mime_type in ["image/png", "image/jpeg"] - matches images
attachments := tc.ListWithFilter(`mime_type in ["image/png", "image/jpeg"]`)
require.Len(t, attachments, 2)
// Test: mime_type in ["video/mp4"] - no matches
attachments = tc.ListWithFilter(`mime_type in ["video/mp4"]`)
require.Len(t, attachments, 0)
}
// =============================================================================
// Create Time Field Tests
// Schema: create_time (timestamp, all comparison operators)
// Functions: now(), arithmetic (+, -, *)
// =============================================================================
func TestAttachmentFilterCreateTimeComparison(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
now := time.Now().Unix()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
// Test: create_time < future (should match)
attachments := tc.ListWithFilter(`create_time < ` + formatInt64(now+3600))
require.Len(t, attachments, 1)
// Test: create_time > past (should match)
attachments = tc.ListWithFilter(`create_time > ` + formatInt64(now-3600))
require.Len(t, attachments, 1)
// Test: create_time > future (should not match)
attachments = tc.ListWithFilter(`create_time > ` + formatInt64(now+3600))
require.Len(t, attachments, 0)
}
func TestAttachmentFilterCreateTimeWithNow(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
// Test: create_time < now() + 5 (buffer for container clock drift)
attachments := tc.ListWithFilter(`create_time < now() + 5`)
require.Len(t, attachments, 1)
// Test: create_time > now() + 5 (should not match)
attachments = tc.ListWithFilter(`create_time > now() + 5`)
require.Len(t, attachments, 0)
}
func TestAttachmentFilterCreateTimeArithmetic(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
// Test: create_time >= now() - 3600 (attachments created in last hour)
attachments := tc.ListWithFilter(`create_time >= now() - 3600`)
require.Len(t, attachments, 1)
// Test: create_time < now() - 86400 (attachments older than 1 day - should be empty)
attachments = tc.ListWithFilter(`create_time < now() - 86400`)
require.Len(t, attachments, 0)
// Test: Multiplication - create_time >= now() - 60 * 60
attachments = tc.ListWithFilter(`create_time >= now() - 60 * 60`)
require.Len(t, attachments, 1)
}
func TestAttachmentFilterAllComparisonOperators(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
// Test: < (less than)
attachments := tc.ListWithFilter(`create_time < now() + 3600`)
require.Len(t, attachments, 1)
// Test: <= (less than or equal) with buffer for clock drift
attachments = tc.ListWithFilter(`create_time < now() + 5`)
require.Len(t, attachments, 1)
// Test: > (greater than)
attachments = tc.ListWithFilter(`create_time > now() - 3600`)
require.Len(t, attachments, 1)
// Test: >= (greater than or equal)
attachments = tc.ListWithFilter(`create_time >= now() - 60`)
require.Len(t, attachments, 1)
}
// =============================================================================
// Memo ID Field Tests
// Schema: memo_id (int, ==, !=)
// =============================================================================
func TestAttachmentFilterMemoIdEquals(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContextWithUser(t)
defer tc.Close()
memo1 := tc.CreateMemo("memo-1", "Memo 1")
memo2 := tc.CreateMemo("memo-2", "Memo 2")
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("memo1_attachment.png").MimeType("image/png").MemoID(&memo1.ID))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("memo2_attachment.png").MimeType("image/png").MemoID(&memo2.ID))
attachments := tc.ListWithFilter(`memo_id == ` + formatInt32(memo1.ID))
require.Len(t, attachments, 1)
require.Equal(t, &memo1.ID, attachments[0].MemoID)
}
func TestAttachmentFilterMemoIdNotEquals(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContextWithUser(t)
defer tc.Close()
memo1 := tc.CreateMemo("memo-1", "Memo 1")
memo2 := tc.CreateMemo("memo-2", "Memo 2")
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("memo1_attachment.png").MimeType("image/png").MemoID(&memo1.ID))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("memo2_attachment.png").MimeType("image/png").MemoID(&memo2.ID))
attachments := tc.ListWithFilter(`memo_id != ` + formatInt32(memo1.ID))
require.Len(t, attachments, 1)
require.Equal(t, &memo2.ID, attachments[0].MemoID)
}
// =============================================================================
// Logical Operator Tests
// Operators: && (AND), || (OR), ! (NOT)
// =============================================================================
func TestAttachmentFilterLogicalAnd(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("photo.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.pdf").MimeType("application/pdf"))
attachments := tc.ListWithFilter(`mime_type == "image/png" && filename.contains("image")`)
require.Len(t, attachments, 1)
require.Equal(t, "image.png", attachments[0].Filename)
}
func TestAttachmentFilterLogicalOr(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("video.mp4").MimeType("video/mp4"))
attachments := tc.ListWithFilter(`mime_type == "image/png" || mime_type == "application/pdf"`)
require.Len(t, attachments, 2)
}
func TestAttachmentFilterLogicalNot(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
attachments := tc.ListWithFilter(`!(mime_type == "image/png")`)
require.Len(t, attachments, 1)
require.Equal(t, "application/pdf", attachments[0].Type)
}
func TestAttachmentFilterComplexLogical(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.pdf").MimeType("application/pdf"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("other.png").MimeType("image/png"))
attachments := tc.ListWithFilter(`(mime_type == "image/png" || mime_type == "application/pdf") && filename.contains("report")`)
require.Len(t, attachments, 2)
}
// =============================================================================
// Multiple Filters Tests
// =============================================================================
func TestAttachmentFilterMultipleFilters(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("other.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.pdf").MimeType("application/pdf"))
// Test: Multiple filters (applied as AND)
attachments := tc.ListWithFilters(`filename.contains("report")`, `mime_type == "image/png"`)
require.Len(t, attachments, 1)
require.Contains(t, attachments[0].Filename, "report")
require.Equal(t, "image/png", attachments[0].Type)
}
// =============================================================================
// Edge Cases
// =============================================================================
func TestAttachmentFilterNoMatches(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
attachments := tc.ListWithFilter(`filename.contains("nonexistent12345")`)
require.Len(t, attachments, 0)
}
func TestAttachmentFilterNullMemoId(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContextWithUser(t)
defer tc.Close()
memo := tc.CreateMemo("memo-1", "Memo 1")
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("with_memo.png").MimeType("image/png").MemoID(&memo.ID))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("no_memo.png").MimeType("image/png"))
// Test: memo_id == null
attachments := tc.ListWithFilter(`memo_id == null`)
require.Len(t, attachments, 1)
require.Equal(t, "no_memo.png", attachments[0].Filename)
require.Nil(t, attachments[0].MemoID)
// Test: memo_id != null
attachments = tc.ListWithFilter(`memo_id != null`)
require.Len(t, attachments, 1)
require.Equal(t, "with_memo.png", attachments[0].Filename)
require.NotNil(t, attachments[0].MemoID)
require.Equal(t, memo.ID, *attachments[0].MemoID)
}
func TestAttachmentFilterEmptyFilename(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("other.pdf").MimeType("application/pdf"))
// Test: filename.contains("") - should match all
attachments := tc.ListWithFilter(`filename.contains("")`)
require.Len(t, attachments, 2)
}

View File

@@ -0,0 +1,247 @@
package test
import (
"context"
"fmt"
"testing"
"github.com/lithammer/shortuuid/v4"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestAttachmentStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
_, err := ts.CreateAttachment(ctx, &store.Attachment{
UID: shortuuid.New(),
CreatorID: 101,
Filename: "test.epub",
Blob: []byte("test"),
Type: "application/epub+zip",
Size: 637607,
})
require.NoError(t, err)
correctFilename := "test.epub"
incorrectFilename := "test.png"
attachment, err := ts.GetAttachment(ctx, &store.FindAttachment{
Filename: &correctFilename,
})
require.NoError(t, err)
require.Equal(t, correctFilename, attachment.Filename)
require.Equal(t, int32(1), attachment.ID)
notFoundAttachment, err := ts.GetAttachment(ctx, &store.FindAttachment{
Filename: &incorrectFilename,
})
require.NoError(t, err)
require.Nil(t, notFoundAttachment)
var correctCreatorID int32 = 101
var incorrectCreatorID int32 = 102
_, err = ts.GetAttachment(ctx, &store.FindAttachment{
CreatorID: &correctCreatorID,
})
require.NoError(t, err)
notFoundAttachment, err = ts.GetAttachment(ctx, &store.FindAttachment{
CreatorID: &incorrectCreatorID,
})
require.NoError(t, err)
require.Nil(t, notFoundAttachment)
err = ts.DeleteAttachment(ctx, &store.DeleteAttachment{
ID: 1,
})
require.NoError(t, err)
err = ts.DeleteAttachment(ctx, &store.DeleteAttachment{
ID: 2,
})
require.ErrorContains(t, err, "attachment not found")
ts.Close()
}
func TestAttachmentStoreWithFilter(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
_, err := ts.CreateAttachment(ctx, &store.Attachment{
UID: shortuuid.New(),
CreatorID: 101,
Filename: "test.png",
Blob: []byte("test"),
Type: "image/png",
Size: 1000,
})
require.NoError(t, err)
_, err = ts.CreateAttachment(ctx, &store.Attachment{
UID: shortuuid.New(),
CreatorID: 101,
Filename: "test.jpg",
Blob: []byte("test"),
Type: "image/jpeg",
Size: 2000,
})
require.NoError(t, err)
_, err = ts.CreateAttachment(ctx, &store.Attachment{
UID: shortuuid.New(),
CreatorID: 101,
Filename: "test.pdf",
Blob: []byte("test"),
Type: "application/pdf",
Size: 3000,
})
require.NoError(t, err)
attachments, err := ts.ListAttachments(ctx, &store.FindAttachment{
CreatorID: &[]int32{101}[0],
Filters: []string{`mime_type == "image/png"`},
})
require.NoError(t, err)
require.Len(t, attachments, 1)
require.Equal(t, "image/png", attachments[0].Type)
attachments, err = ts.ListAttachments(ctx, &store.FindAttachment{
CreatorID: &[]int32{101}[0],
Filters: []string{`mime_type in ["image/png", "image/jpeg"]`},
})
require.NoError(t, err)
require.Len(t, attachments, 2)
attachments, err = ts.ListAttachments(ctx, &store.FindAttachment{
CreatorID: &[]int32{101}[0],
Filters: []string{`filename.contains("test")`},
})
require.NoError(t, err)
require.Len(t, attachments, 3)
ts.Close()
}
func TestAttachmentUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
attachment, err := ts.CreateAttachment(ctx, &store.Attachment{
UID: shortuuid.New(),
CreatorID: 101,
Filename: "original.png",
Blob: []byte("test"),
Type: "image/png",
Size: 1000,
})
require.NoError(t, err)
// Update filename
newFilename := "updated.png"
err = ts.UpdateAttachment(ctx, &store.UpdateAttachment{
ID: attachment.ID,
Filename: &newFilename,
})
require.NoError(t, err)
// Verify update
found, err := ts.GetAttachment(ctx, &store.FindAttachment{ID: &attachment.ID})
require.NoError(t, err)
require.Equal(t, newFilename, found.Filename)
ts.Close()
}
func TestAttachmentGetByUID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
uid := shortuuid.New()
_, err := ts.CreateAttachment(ctx, &store.Attachment{
UID: uid,
CreatorID: 101,
Filename: "test.png",
Blob: []byte("test"),
Type: "image/png",
Size: 1000,
})
require.NoError(t, err)
// Get by UID
found, err := ts.GetAttachment(ctx, &store.FindAttachment{UID: &uid})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, uid, found.UID)
// Get non-existent UID
nonExistentUID := "non-existent-uid"
notFound, err := ts.GetAttachment(ctx, &store.FindAttachment{UID: &nonExistentUID})
require.NoError(t, err)
require.Nil(t, notFound)
ts.Close()
}
func TestAttachmentListWithPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create 5 attachments
for i := 0; i < 5; i++ {
_, err := ts.CreateAttachment(ctx, &store.Attachment{
UID: shortuuid.New(),
CreatorID: 101,
Filename: fmt.Sprintf("test%d.png", i),
Blob: []byte("test"),
Type: "image/png",
Size: int64(1000 + i),
})
require.NoError(t, err)
}
// Test limit
limit := 3
attachments, err := ts.ListAttachments(ctx, &store.FindAttachment{
CreatorID: &[]int32{101}[0],
Limit: &limit,
})
require.NoError(t, err)
require.Equal(t, 3, len(attachments))
// Test offset
offset := 2
offsetAttachments, err := ts.ListAttachments(ctx, &store.FindAttachment{
CreatorID: &[]int32{101}[0],
Limit: &limit,
Offset: &offset,
})
require.NoError(t, err)
require.Equal(t, 3, len(offsetAttachments))
ts.Close()
}
func TestAttachmentInvalidUID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create with invalid UID (contains spaces)
_, err := ts.CreateAttachment(ctx, &store.Attachment{
UID: "invalid uid with spaces",
CreatorID: 101,
Filename: "test.png",
Blob: []byte("test"),
Type: "image/png",
Size: 1000,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid uid")
ts.Close()
}

317
store/test/containers.go Normal file
View File

@@ -0,0 +1,317 @@
package test
import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/docker/docker/api/types/container"
"github.com/pkg/errors"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mysql"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/network"
"github.com/testcontainers/testcontainers-go/wait"
// Database drivers for connection verification.
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
)
const (
testUser = "root"
testPassword = "test"
// Memos container settings for migration testing.
MemosDockerImage = "neosmemo/memos"
StableMemosVersion = "stable" // Always points to the latest stable release
)
var (
mysqlContainer atomic.Pointer[mysql.MySQLContainer]
postgresContainer atomic.Pointer[postgres.PostgresContainer]
mysqlOnce sync.Once
postgresOnce sync.Once
mysqlBaseDSN atomic.Value // stores string
postgresBaseDSN atomic.Value // stores string
dbCounter atomic.Int64
dbCreationMutex sync.Mutex // Protects database creation operations
// Network for container communication.
testDockerNetwork atomic.Pointer[testcontainers.DockerNetwork]
testNetworkOnce sync.Once
)
// getTestNetwork creates or returns the shared Docker network for container communication.
func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) {
var networkErr error
testNetworkOnce.Do(func() {
nw, err := network.New(ctx, network.WithDriver("bridge"))
if err != nil {
networkErr = err
return
}
testDockerNetwork.Store(nw)
})
return testDockerNetwork.Load(), networkErr
}
// GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test.
func GetMySQLDSN(t *testing.T) string {
ctx := context.Background()
mysqlOnce.Do(func() {
nw, err := getTestNetwork(ctx)
if err != nil {
t.Fatalf("failed to create test network: %v", err)
}
container, err := mysql.Run(ctx,
"mysql:8",
mysql.WithDatabase("init_db"),
mysql.WithUsername("root"),
mysql.WithPassword(testPassword),
testcontainers.WithEnv(map[string]string{
"MYSQL_ROOT_PASSWORD": testPassword,
}),
testcontainers.WithWaitStrategy(
wait.ForAll(
wait.ForLog("ready for connections").WithOccurrence(2),
wait.ForListeningPort("3306/tcp"),
).WithDeadline(120*time.Second),
),
network.WithNetwork(nil, nw),
)
if err != nil {
t.Fatalf("failed to start MySQL container: %v", err)
}
mysqlContainer.Store(container)
dsn, err := container.ConnectionString(ctx, "multiStatements=true")
if err != nil {
t.Fatalf("failed to get MySQL connection string: %v", err)
}
if err := waitForDB("mysql", dsn, 30*time.Second); err != nil {
t.Fatalf("MySQL not ready for connections: %v", err)
}
mysqlBaseDSN.Store(dsn)
})
dsn, ok := mysqlBaseDSN.Load().(string)
if !ok || dsn == "" {
t.Fatal("MySQL container failed to start in a previous test")
}
// Serialize database creation to avoid "table already exists" race conditions
dbCreationMutex.Lock()
defer dbCreationMutex.Unlock()
// Create a fresh database for this test
dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1))
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("failed to connect to MySQL: %v", err)
}
defer db.Close()
if _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE `%s`", dbName)); err != nil {
t.Fatalf("failed to create database %s: %v", dbName, err)
}
// Return DSN pointing to the new database
return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1)
}
// waitForDB polls the database until it's ready or timeout is reached.
func waitForDB(driver, dsn string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
var lastErr error
for {
select {
case <-ctx.Done():
if lastErr != nil {
return errors.Errorf("timeout waiting for %s database: %v", driver, lastErr)
}
return errors.Errorf("timeout waiting for %s database to be ready", driver)
case <-ticker.C:
db, err := sql.Open(driver, dsn)
if err != nil {
lastErr = err
continue
}
err = db.PingContext(ctx)
db.Close()
if err == nil {
return nil
}
lastErr = err
}
}
}
// GetPostgresDSN starts a PostgreSQL container (if not already running) and creates a fresh database for this test.
func GetPostgresDSN(t *testing.T) string {
ctx := context.Background()
postgresOnce.Do(func() {
nw, err := getTestNetwork(ctx)
if err != nil {
t.Fatalf("failed to create test network: %v", err)
}
container, err := postgres.Run(ctx,
"postgres:18",
postgres.WithDatabase("init_db"),
postgres.WithUsername(testUser),
postgres.WithPassword(testPassword),
testcontainers.WithWaitStrategy(
wait.ForAll(
wait.ForLog("database system is ready to accept connections").WithOccurrence(2),
wait.ForListeningPort("5432/tcp"),
).WithDeadline(120*time.Second),
),
network.WithNetwork(nil, nw),
)
if err != nil {
t.Fatalf("failed to start PostgreSQL container: %v", err)
}
postgresContainer.Store(container)
dsn, err := container.ConnectionString(ctx, "sslmode=disable")
if err != nil {
t.Fatalf("failed to get PostgreSQL connection string: %v", err)
}
if err := waitForDB("postgres", dsn, 30*time.Second); err != nil {
t.Fatalf("PostgreSQL not ready for connections: %v", err)
}
postgresBaseDSN.Store(dsn)
})
dsn, ok := postgresBaseDSN.Load().(string)
if !ok || dsn == "" {
t.Fatal("PostgreSQL container failed to start in a previous test")
}
// Serialize database creation to avoid "table already exists" race conditions
dbCreationMutex.Lock()
defer dbCreationMutex.Unlock()
// Create a fresh database for this test
dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1))
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatalf("failed to connect to PostgreSQL: %v", err)
}
defer db.Close()
if _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", dbName)); err != nil {
t.Fatalf("failed to create database %s: %v", dbName, err)
}
// Return DSN pointing to the new database
return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1)
}
// TerminateContainers cleans up all running containers and network.
// This is typically called from TestMain.
func TerminateContainers() {
ctx := context.Background()
if container := mysqlContainer.Load(); container != nil {
_ = container.Terminate(ctx)
}
if container := postgresContainer.Load(); container != nil {
_ = container.Terminate(ctx)
}
if network := testDockerNetwork.Load(); network != nil {
_ = network.Remove(ctx)
}
}
// MemosContainerConfig holds configuration for starting a Memos container.
type MemosContainerConfig struct {
Version string // Memos version tag (e.g., "0.24.0")
Driver string // Database driver: sqlite, mysql, postgres
DSN string // Database DSN (for mysql/postgres)
DataDir string // Host directory to mount for SQLite data
}
// MemosStartupWaitStrategy defines the wait strategy for Memos container startup.
// Uses regex to match various log message formats across versions.
var MemosStartupWaitStrategy = wait.ForAll(
wait.ForLog("(started successfully|has been started on port)").AsRegexp(),
wait.ForListeningPort("5230/tcp"),
).WithDeadline(180 * time.Second)
// StartMemosContainer starts a Memos container for migration testing.
// For SQLite, it mounts the dataDir to /var/opt/memos.
func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcontainers.Container, error) {
env := map[string]string{
"MEMOS_MODE": "prod",
}
var opts []testcontainers.ContainerCustomizer
switch cfg.Driver {
case "sqlite":
env["MEMOS_DRIVER"] = "sqlite"
opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) {
hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos"))
}))
default:
return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver)
}
req := testcontainers.ContainerRequest{
Image: fmt.Sprintf("%s:%s", MemosDockerImage, cfg.Version),
Env: env,
ExposedPorts: []string{"5230/tcp"},
WaitingFor: MemosStartupWaitStrategy,
User: fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()),
}
// Use local image if specified
if cfg.Version == "local" {
if os.Getenv("MEMOS_TEST_IMAGE_BUILT") == "1" {
req.Image = "memos-test:local"
} else {
req.FromDockerfile = testcontainers.FromDockerfile{
Context: "../../",
Dockerfile: "Dockerfile",
}
}
}
genericReq := testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
}
// Apply options
for _, opt := range opts {
if err := opt.Customize(&genericReq); err != nil {
return nil, errors.Wrap(err, "failed to apply container option")
}
}
ctr, err := testcontainers.GenericContainer(ctx, genericReq)
if err != nil {
return nil, errors.Wrap(err, "failed to start memos container")
}
return ctr, nil
}

View File

@@ -0,0 +1,290 @@
package test
import (
"context"
"strconv"
"testing"
"github.com/lithammer/shortuuid/v4"
"github.com/stretchr/testify/require"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// =============================================================================
// Formatting Helpers
// =============================================================================
func formatInt64(n int64) string {
return strconv.FormatInt(n, 10)
}
func formatInt32(n int32) string {
return strconv.FormatInt(int64(n), 10)
}
func formatInt(n int) string {
return strconv.Itoa(n)
}
// =============================================================================
// Pointer Helpers
// =============================================================================
func boolPtr(b bool) *bool {
return &b
}
// =============================================================================
// Test Fixture Builders
// =============================================================================
// MemoBuilder provides a fluent API for creating test memos.
type MemoBuilder struct {
memo *store.Memo
}
// NewMemoBuilder creates a new memo builder with required fields.
func NewMemoBuilder(uid string, creatorID int32) *MemoBuilder {
return &MemoBuilder{
memo: &store.Memo{
UID: uid,
CreatorID: creatorID,
Visibility: store.Public,
},
}
}
func (b *MemoBuilder) Content(content string) *MemoBuilder {
b.memo.Content = content
return b
}
func (b *MemoBuilder) Visibility(v store.Visibility) *MemoBuilder {
b.memo.Visibility = v
return b
}
func (b *MemoBuilder) Tags(tags ...string) *MemoBuilder {
if b.memo.Payload == nil {
b.memo.Payload = &storepb.MemoPayload{}
}
b.memo.Payload.Tags = tags
return b
}
func (b *MemoBuilder) Property(fn func(*storepb.MemoPayload_Property)) *MemoBuilder {
if b.memo.Payload == nil {
b.memo.Payload = &storepb.MemoPayload{}
}
if b.memo.Payload.Property == nil {
b.memo.Payload.Property = &storepb.MemoPayload_Property{}
}
fn(b.memo.Payload.Property)
return b
}
func (b *MemoBuilder) Build() *store.Memo {
return b.memo
}
// AttachmentBuilder provides a fluent API for creating test attachments.
type AttachmentBuilder struct {
attachment *store.Attachment
}
// NewAttachmentBuilder creates a new attachment builder with required fields.
func NewAttachmentBuilder(creatorID int32) *AttachmentBuilder {
return &AttachmentBuilder{
attachment: &store.Attachment{
UID: shortuuid.New(),
CreatorID: creatorID,
Blob: []byte("test"),
Size: 1000,
},
}
}
func (b *AttachmentBuilder) Filename(filename string) *AttachmentBuilder {
b.attachment.Filename = filename
return b
}
func (b *AttachmentBuilder) MimeType(mimeType string) *AttachmentBuilder {
b.attachment.Type = mimeType
return b
}
func (b *AttachmentBuilder) MemoID(memoID *int32) *AttachmentBuilder {
b.attachment.MemoID = memoID
return b
}
func (b *AttachmentBuilder) Size(size int64) *AttachmentBuilder {
b.attachment.Size = size
return b
}
func (b *AttachmentBuilder) Build() *store.Attachment {
return b.attachment
}
// =============================================================================
// Test Context Helpers
// =============================================================================
// MemoFilterTestContext holds common test dependencies for memo filter tests.
type MemoFilterTestContext struct {
Ctx context.Context
T *testing.T
Store *store.Store
User *store.User
}
// NewMemoFilterTestContext creates a new test context with store and user.
func NewMemoFilterTestContext(t *testing.T) *MemoFilterTestContext {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
return &MemoFilterTestContext{
Ctx: ctx,
T: t,
Store: ts,
User: user,
}
}
// CreateMemo creates a memo using the builder pattern.
func (tc *MemoFilterTestContext) CreateMemo(b *MemoBuilder) *store.Memo {
memo, err := tc.Store.CreateMemo(tc.Ctx, b.Build())
require.NoError(tc.T, err)
return memo
}
// PinMemo pins a memo by ID.
func (tc *MemoFilterTestContext) PinMemo(memoID int32) {
err := tc.Store.UpdateMemo(tc.Ctx, &store.UpdateMemo{
ID: memoID,
Pinned: boolPtr(true),
})
require.NoError(tc.T, err)
}
// ListWithFilter lists memos with the given filter and returns the count.
func (tc *MemoFilterTestContext) ListWithFilter(filter string) []*store.Memo {
memos, err := tc.Store.ListMemos(tc.Ctx, &store.FindMemo{
Filters: []string{filter},
})
require.NoError(tc.T, err)
return memos
}
// ListWithFilters lists memos with multiple filters and returns the count.
func (tc *MemoFilterTestContext) ListWithFilters(filters ...string) []*store.Memo {
memos, err := tc.Store.ListMemos(tc.Ctx, &store.FindMemo{
Filters: filters,
})
require.NoError(tc.T, err)
return memos
}
// Close closes the test store.
func (tc *MemoFilterTestContext) Close() {
tc.Store.Close()
}
// AttachmentFilterTestContext holds common test dependencies for attachment filter tests.
type AttachmentFilterTestContext struct {
Ctx context.Context
T *testing.T
Store *store.Store
User *store.User
CreatorID int32
}
// NewAttachmentFilterTestContext creates a new test context for attachments.
func NewAttachmentFilterTestContext(t *testing.T) *AttachmentFilterTestContext {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
return &AttachmentFilterTestContext{
Ctx: ctx,
T: t,
Store: ts,
CreatorID: 101,
}
}
// NewAttachmentFilterTestContextWithUser creates a new test context with a user.
func NewAttachmentFilterTestContextWithUser(t *testing.T) *AttachmentFilterTestContext {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
return &AttachmentFilterTestContext{
Ctx: ctx,
T: t,
Store: ts,
User: user,
CreatorID: user.ID,
}
}
// CreateAttachment creates an attachment using the builder pattern.
func (tc *AttachmentFilterTestContext) CreateAttachment(b *AttachmentBuilder) *store.Attachment {
attachment, err := tc.Store.CreateAttachment(tc.Ctx, b.Build())
require.NoError(tc.T, err)
return attachment
}
// CreateMemo creates a memo (for attachment tests that need memos).
func (tc *AttachmentFilterTestContext) CreateMemo(uid, content string) *store.Memo {
memo, err := tc.Store.CreateMemo(tc.Ctx, &store.Memo{
UID: uid,
CreatorID: tc.CreatorID,
Content: content,
Visibility: store.Public,
})
require.NoError(tc.T, err)
return memo
}
// ListWithFilter lists attachments with the given filter.
func (tc *AttachmentFilterTestContext) ListWithFilter(filter string) []*store.Attachment {
attachments, err := tc.Store.ListAttachments(tc.Ctx, &store.FindAttachment{
CreatorID: &tc.CreatorID,
Filters: []string{filter},
})
require.NoError(tc.T, err)
return attachments
}
// ListWithFilters lists attachments with multiple filters.
func (tc *AttachmentFilterTestContext) ListWithFilters(filters ...string) []*store.Attachment {
attachments, err := tc.Store.ListAttachments(tc.Ctx, &store.FindAttachment{
CreatorID: &tc.CreatorID,
Filters: filters,
})
require.NoError(tc.T, err)
return attachments
}
// Close closes the test store.
func (tc *AttachmentFilterTestContext) Close() {
tc.Store.Close()
}
// =============================================================================
// Filter Test Case Definition
// =============================================================================
// FilterTestCase defines a single filter test case for table-driven tests.
type FilterTestCase struct {
Name string
Filter string
ExpectedCount int
}

454
store/test/idp_test.go Normal file
View File

@@ -0,0 +1,454 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestIdentityProviderStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Name: "GitHub OAuth",
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: "",
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://github.com/auth",
TokenUrl: "https://github.com/token",
UserInfoUrl: "https://github.com/user",
Scopes: []string{"login"},
FieldMapping: &storepb.FieldMapping{
Identifier: "login",
DisplayName: "name",
Email: "email",
},
},
},
},
})
require.NoError(t, err)
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &createdIDP.Id,
})
require.NoError(t, err)
require.NotNil(t, idp)
require.Equal(t, createdIDP, idp)
newName := "My GitHub OAuth"
updatedIdp, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Name: &newName,
})
require.NoError(t, err)
require.Equal(t, newName, updatedIdp.Name)
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{
ID: idp.Id,
})
require.NoError(t, err)
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
require.NoError(t, err)
require.Equal(t, 0, len(idpList))
ts.Close()
}
func TestIdentityProviderGetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create IDP
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
require.NoError(t, err)
// Get by ID
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, idp.Id, found.Id)
require.Equal(t, idp.Name, found.Name)
// Get by non-existent ID
nonExistentID := int32(99999)
notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID})
require.NoError(t, err)
require.Nil(t, notFound)
ts.Close()
}
func TestIdentityProviderListMultiple(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth"))
require.NoError(t, err)
// List all
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
require.NoError(t, err)
require.Len(t, idpList, 3)
ts.Close()
}
func TestIdentityProviderListByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
require.NoError(t, err)
// List by specific ID
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{ID: &idp1.Id})
require.NoError(t, err)
require.Len(t, idpList, 1)
require.Equal(t, "GitHub OAuth", idpList[0].Name)
ts.Close()
}
func TestIdentityProviderUpdateName(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name"))
require.NoError(t, err)
require.Equal(t, "Original Name", idp.Name)
// Update name
newName := "Updated Name"
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Type: storepb.IdentityProvider_OAUTH2,
Name: &newName,
})
require.NoError(t, err)
require.Equal(t, "Updated Name", updated.Name)
// Verify update persisted
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, "Updated Name", found.Name)
ts.Close()
}
func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
require.NoError(t, err)
require.Equal(t, "", idp.IdentifierFilter)
// Update identifier filter
newFilter := "@example.com$"
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: &newFilter,
})
require.NoError(t, err)
require.Equal(t, "@example.com$", updated.IdentifierFilter)
// Verify update persisted
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, "@example.com$", found.IdentifierFilter)
ts.Close()
}
func TestIdentityProviderUpdateConfig(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
require.NoError(t, err)
// Update config
newConfig := &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "new_client_id",
ClientSecret: "new_client_secret",
AuthUrl: "https://newprovider.com/auth",
TokenUrl: "https://newprovider.com/token",
UserInfoUrl: "https://newprovider.com/user",
Scopes: []string{"openid", "profile", "email"},
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
DisplayName: "name",
Email: "email",
},
},
},
}
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Type: storepb.IdentityProvider_OAUTH2,
Config: newConfig,
})
require.NoError(t, err)
require.Equal(t, "new_client_id", updated.Config.GetOauth2Config().ClientId)
require.Equal(t, "new_client_secret", updated.Config.GetOauth2Config().ClientSecret)
require.Contains(t, updated.Config.GetOauth2Config().Scopes, "openid")
// Verify update persisted
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, "new_client_id", found.Config.GetOauth2Config().ClientId)
ts.Close()
}
func TestIdentityProviderUpdateMultipleFields(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original"))
require.NoError(t, err)
// Update multiple fields at once
newName := "Updated IDP"
newFilter := "^admin@"
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Type: storepb.IdentityProvider_OAUTH2,
Name: &newName,
IdentifierFilter: &newFilter,
})
require.NoError(t, err)
require.Equal(t, "Updated IDP", updated.Name)
require.Equal(t, "^admin@", updated.IdentifierFilter)
ts.Close()
}
func TestIdentityProviderDelete(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
require.NoError(t, err)
// Delete
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id})
require.NoError(t, err)
// Verify deletion
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Nil(t, found)
ts.Close()
}
func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1"))
require.NoError(t, err)
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2"))
require.NoError(t, err)
// Delete first one
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp1.Id})
require.NoError(t, err)
// Verify second still exists
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp2.Id})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, "IDP 2", found.Name)
// Verify list only contains second
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
require.NoError(t, err)
require.Len(t, idpList, 1)
ts.Close()
}
func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create IDP with multiple scopes
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Name: "Multi-Scope OAuth",
Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://provider.com/auth",
TokenUrl: "https://provider.com/token",
UserInfoUrl: "https://provider.com/userinfo",
Scopes: []string{"openid", "profile", "email", "groups"},
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
DisplayName: "name",
Email: "email",
},
},
},
},
})
require.NoError(t, err)
// Verify scopes are preserved
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Len(t, found.Config.GetOauth2Config().Scopes, 4)
require.Contains(t, found.Config.GetOauth2Config().Scopes, "openid")
require.Contains(t, found.Config.GetOauth2Config().Scopes, "groups")
ts.Close()
}
func TestIdentityProviderFieldMapping(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create IDP with custom field mapping
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Name: "Custom Field Mapping",
Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://provider.com/auth",
TokenUrl: "https://provider.com/token",
UserInfoUrl: "https://provider.com/userinfo",
Scopes: []string{"login"},
FieldMapping: &storepb.FieldMapping{
Identifier: "preferred_username",
DisplayName: "full_name",
Email: "email_address",
},
},
},
},
})
require.NoError(t, err)
// Verify field mapping is preserved
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, "preferred_username", found.Config.GetOauth2Config().FieldMapping.Identifier)
require.Equal(t, "full_name", found.Config.GetOauth2Config().FieldMapping.DisplayName)
require.Equal(t, "email_address", found.Config.GetOauth2Config().FieldMapping.Email)
ts.Close()
}
func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
testCases := []struct {
name string
filter string
}{
{"Domain filter", "@company\\.com$"},
{"Prefix filter", "^admin_"},
{"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
{"Empty filter", ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Name: tc.name,
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: tc.filter,
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://provider.com/auth",
TokenUrl: "https://provider.com/token",
UserInfoUrl: "https://provider.com/userinfo",
Scopes: []string{"login"},
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
},
},
},
},
})
require.NoError(t, err)
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, tc.filter, found.IdentifierFilter)
// Cleanup
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id})
require.NoError(t, err)
})
}
ts.Close()
}
// Helper function to create a test OAuth2 IDP.
func createTestOAuth2IDP(name string) *storepb.IdentityProvider {
return &storepb.IdentityProvider{
Name: name,
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: "",
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://provider.com/auth",
TokenUrl: "https://provider.com/token",
UserInfoUrl: "https://provider.com/userinfo",
Scopes: []string{"login"},
FieldMapping: &storepb.FieldMapping{
Identifier: "login",
DisplayName: "name",
Email: "email",
},
},
},
},
}
}

592
store/test/inbox_test.go Normal file
View File

@@ -0,0 +1,592 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestInboxStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
const systemBotID int32 = 0
create := &store.Inbox{
SenderID: systemBotID,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
}
inbox, err := ts.CreateInbox(ctx, create)
require.NoError(t, err)
require.NotNil(t, inbox)
require.Equal(t, create.Message, inbox.Message)
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
})
require.NoError(t, err)
require.Equal(t, 1, len(inboxes))
require.Equal(t, inbox, inboxes[0])
updatedInbox, err := ts.UpdateInbox(ctx, &store.UpdateInbox{
ID: inbox.ID,
Status: store.ARCHIVED,
})
require.NoError(t, err)
require.NotNil(t, updatedInbox)
require.Equal(t, store.ARCHIVED, updatedInbox.Status)
err = ts.DeleteInbox(ctx, &store.DeleteInbox{
ID: inbox.ID,
})
require.NoError(t, err)
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
})
require.NoError(t, err)
require.Equal(t, 0, len(inboxes))
ts.Close()
}
func TestInboxListByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// List by ID
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, inbox.ID, inboxes[0].ID)
// List by non-existent ID
nonExistentID := int32(99999)
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ID: &nonExistentID})
require.NoError(t, err)
require.Len(t, inboxes, 0)
ts.Close()
}
func TestInboxListBySenderID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create inbox from system bot (senderID = 0)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// Create inbox from user2
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: user2.ID,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// List by sender ID = user2
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{SenderID: &user2.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, user2.ID, inboxes[0].SenderID)
// List by sender ID = 0 (system bot)
systemBotID := int32(0)
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{SenderID: &systemBotID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, int32(0), inboxes[0].SenderID)
ts.Close()
}
func TestInboxListByStatus(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create UNREAD inbox
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// Create another inbox and archive it
inbox2, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox2.ID, Status: store.ARCHIVED})
require.NoError(t, err)
// List by UNREAD status
unreadStatus := store.UNREAD
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{Status: &unreadStatus})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, store.UNREAD, inboxes[0].Status)
// List by ARCHIVED status
archivedStatus := store.ARCHIVED
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{Status: &archivedStatus})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, store.ARCHIVED, inboxes[0].Status)
ts.Close()
}
func TestInboxListByMessageType(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create MEMO_COMMENT inboxes
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// List by MEMO_COMMENT type
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{MessageType: &memoCommentType})
require.NoError(t, err)
require.Len(t, inboxes, 2)
for _, inbox := range inboxes {
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
}
ts.Close()
}
func TestInboxListPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create 5 inboxes
for i := 0; i < 5; i++ {
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
}
// Test Limit only
limit := 3
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
Limit: &limit,
})
require.NoError(t, err)
require.Len(t, inboxes, 3)
// Test Limit + Offset (offset requires limit in the implementation)
limit = 2
offset := 2
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
Limit: &limit,
Offset: &offset,
})
require.NoError(t, err)
require.Len(t, inboxes, 2)
// Test Limit + Offset skipping to end
limit = 10
offset = 3
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
Limit: &limit,
Offset: &offset,
})
require.NoError(t, err)
require.Len(t, inboxes, 2) // Only 2 remaining after offset of 3
ts.Close()
}
func TestInboxListCombinedFilters(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create various inboxes
// user2 -> user1, MEMO_COMMENT, UNREAD
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: user2.ID,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// user2 -> user1, TYPE_UNSPECIFIED, UNREAD
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: user2.ID,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED},
})
require.NoError(t, err)
// system -> user1, MEMO_COMMENT, ARCHIVED
inbox3, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox3.ID, Status: store.ARCHIVED})
require.NoError(t, err)
// Combined filter: ReceiverID + SenderID + Status
unreadStatus := store.UNREAD
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user1.ID,
SenderID: &user2.ID,
Status: &unreadStatus,
})
require.NoError(t, err)
require.Len(t, inboxes, 2)
// Combined filter: ReceiverID + MessageType + Status
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user1.ID,
MessageType: &memoCommentType,
Status: &unreadStatus,
})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, user2.ID, inboxes[0].SenderID)
ts.Close()
}
func TestInboxMessagePayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create inbox with message payload containing activity ID
activityID := int32(123)
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
ActivityId: &activityID,
},
})
require.NoError(t, err)
require.NotNil(t, inbox.Message)
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
require.Equal(t, activityID, *inbox.Message.ActivityId)
// List and verify payload is preserved
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, activityID, *inboxes[0].Message.ActivityId)
ts.Close()
}
func TestInboxUpdateStatus(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
require.Equal(t, store.UNREAD, inbox.Status)
// Update to ARCHIVED
updated, err := ts.UpdateInbox(ctx, &store.UpdateInbox{
ID: inbox.ID,
Status: store.ARCHIVED,
})
require.NoError(t, err)
require.Equal(t, store.ARCHIVED, updated.Status)
require.Equal(t, inbox.ID, updated.ID)
// Verify the update persisted
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, store.ARCHIVED, inboxes[0].Status)
ts.Close()
}
func TestInboxListByMessageTypeMultipleTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create inboxes with different message types
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED},
})
require.NoError(t, err)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// Filter by MEMO_COMMENT - should get 2
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &memoCommentType,
})
require.NoError(t, err)
require.Len(t, inboxes, 2)
for _, inbox := range inboxes {
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
}
// Filter by TYPE_UNSPECIFIED - should get 1
unspecifiedType := storepb.InboxMessage_TYPE_UNSPECIFIED
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &unspecifiedType,
})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, storepb.InboxMessage_TYPE_UNSPECIFIED, inboxes[0].Message.Type)
ts.Close()
}
func TestInboxMessageTypeFilterWithPayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create inbox with full payload
activityID := int32(456)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
ActivityId: &activityID,
},
})
require.NoError(t, err)
// Create inbox with different type but also has payload
otherActivityID := int32(789)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_TYPE_UNSPECIFIED,
ActivityId: &otherActivityID,
},
})
require.NoError(t, err)
// Filter by type should work correctly even with complex JSON payload
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &memoCommentType,
})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, activityID, *inboxes[0].Message.ActivityId)
ts.Close()
}
func TestInboxMessageTypeFilterWithStatusAndPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create multiple inboxes with various combinations
for i := 0; i < 5; i++ {
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
}
// Archive 2 of them
allInboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: allInboxes[i].ID, Status: store.ARCHIVED})
require.NoError(t, err)
}
// Filter by type + status + pagination
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
unreadStatus := store.UNREAD
limit := 2
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &memoCommentType,
Status: &unreadStatus,
Limit: &limit,
})
require.NoError(t, err)
require.Len(t, inboxes, 2)
for _, inbox := range inboxes {
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
require.Equal(t, store.UNREAD, inbox.Status)
}
// Get next page
offset := 2
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &memoCommentType,
Status: &unreadStatus,
Limit: &limit,
Offset: &offset,
})
require.NoError(t, err)
require.Len(t, inboxes, 1) // Only 1 remaining (3 unread total, got 2, now 1 left)
ts.Close()
}
func TestInboxMultipleReceivers(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create inbox for user1
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// Create inbox for user2
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user2.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// User1 should only see their inbox
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user1.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, user1.ID, inboxes[0].ReceiverID)
// User2 should only see their inbox
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user2.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, user2.ID, inboxes[0].ReceiverID)
ts.Close()
}

View File

@@ -0,0 +1,300 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestInstanceSettingV1Store(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
instanceSetting, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
AdditionalScript: "",
},
},
})
require.NoError(t, err)
setting, err := ts.GetInstanceSetting(ctx, &store.FindInstanceSetting{
Name: storepb.InstanceSettingKey_GENERAL.String(),
})
require.NoError(t, err)
require.Equal(t, instanceSetting, setting)
ts.Close()
}
func TestInstanceSettingGetNonExistent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Get non-existent setting
setting, err := ts.GetInstanceSetting(ctx, &store.FindInstanceSetting{
Name: storepb.InstanceSettingKey_STORAGE.String(),
})
require.NoError(t, err)
require.Nil(t, setting)
ts.Close()
}
func TestInstanceSettingUpsertUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create setting
_, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
AdditionalScript: "console.log('v1')",
},
},
})
require.NoError(t, err)
// Update setting
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
AdditionalScript: "console.log('v2')",
},
},
})
require.NoError(t, err)
// Verify update
setting, err := ts.GetInstanceSetting(ctx, &store.FindInstanceSetting{
Name: storepb.InstanceSettingKey_GENERAL.String(),
})
require.NoError(t, err)
require.Equal(t, "console.log('v2')", setting.GetGeneralSetting().AdditionalScript)
// Verify only one setting exists
list, err := ts.ListInstanceSettings(ctx, &store.FindInstanceSetting{
Name: storepb.InstanceSettingKey_GENERAL.String(),
})
require.NoError(t, err)
require.Equal(t, 1, len(list))
ts.Close()
}
func TestInstanceSettingBasicSetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Get default basic setting (should return empty defaults)
basicSetting, err := ts.GetInstanceBasicSetting(ctx)
require.NoError(t, err)
require.NotNil(t, basicSetting)
// Set basic setting
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_BASIC,
Value: &storepb.InstanceSetting_BasicSetting{
BasicSetting: &storepb.InstanceBasicSetting{
SecretKey: "my-secret-key",
},
},
})
require.NoError(t, err)
// Verify
basicSetting, err = ts.GetInstanceBasicSetting(ctx)
require.NoError(t, err)
require.Equal(t, "my-secret-key", basicSetting.SecretKey)
ts.Close()
}
func TestInstanceSettingGeneralSetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Get default general setting
generalSetting, err := ts.GetInstanceGeneralSetting(ctx)
require.NoError(t, err)
require.NotNil(t, generalSetting)
// Set general setting
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
AdditionalScript: "console.log('test')",
AdditionalStyle: "body { color: red; }",
},
},
})
require.NoError(t, err)
// Verify
generalSetting, err = ts.GetInstanceGeneralSetting(ctx)
require.NoError(t, err)
require.Equal(t, "console.log('test')", generalSetting.AdditionalScript)
require.Equal(t, "body { color: red; }", generalSetting.AdditionalStyle)
ts.Close()
}
func TestInstanceSettingMemoRelatedSetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Get default memo related setting (should have defaults)
memoSetting, err := ts.GetInstanceMemoRelatedSetting(ctx)
require.NoError(t, err)
require.NotNil(t, memoSetting)
require.GreaterOrEqual(t, memoSetting.ContentLengthLimit, int32(store.DefaultContentLengthLimit))
require.NotEmpty(t, memoSetting.Reactions)
// Set custom memo related setting
customReactions := []string{"👍", "👎", "🚀"}
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_MEMO_RELATED,
Value: &storepb.InstanceSetting_MemoRelatedSetting{
MemoRelatedSetting: &storepb.InstanceMemoRelatedSetting{
ContentLengthLimit: 16384,
Reactions: customReactions,
},
},
})
require.NoError(t, err)
// Verify
memoSetting, err = ts.GetInstanceMemoRelatedSetting(ctx)
require.NoError(t, err)
require.Equal(t, int32(16384), memoSetting.ContentLengthLimit)
require.Equal(t, customReactions, memoSetting.Reactions)
ts.Close()
}
func TestInstanceSettingStorageSetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Get default storage setting (should have defaults)
storageSetting, err := ts.GetInstanceStorageSetting(ctx)
require.NoError(t, err)
require.NotNil(t, storageSetting)
require.Equal(t, storepb.InstanceStorageSetting_LOCAL, storageSetting.StorageType)
require.Equal(t, int64(30), storageSetting.UploadSizeLimitMb)
require.Equal(t, "assets/{timestamp}_{filename}", storageSetting.FilepathTemplate)
// Set custom storage setting
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_STORAGE,
Value: &storepb.InstanceSetting_StorageSetting{
StorageSetting: &storepb.InstanceStorageSetting{
StorageType: storepb.InstanceStorageSetting_LOCAL,
UploadSizeLimitMb: 100,
FilepathTemplate: "uploads/{date}/{filename}",
},
},
})
require.NoError(t, err)
// Verify
storageSetting, err = ts.GetInstanceStorageSetting(ctx)
require.NoError(t, err)
require.Equal(t, storepb.InstanceStorageSetting_LOCAL, storageSetting.StorageType)
require.Equal(t, int64(100), storageSetting.UploadSizeLimitMb)
require.Equal(t, "uploads/{date}/{filename}", storageSetting.FilepathTemplate)
ts.Close()
}
func TestInstanceSettingListAll(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Count initial settings
initialList, err := ts.ListInstanceSettings(ctx, &store.FindInstanceSetting{})
require.NoError(t, err)
initialCount := len(initialList)
// Create multiple settings
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{},
},
})
require.NoError(t, err)
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_STORAGE,
Value: &storepb.InstanceSetting_StorageSetting{
StorageSetting: &storepb.InstanceStorageSetting{},
},
})
require.NoError(t, err)
// List all - should have 2 more than initial
list, err := ts.ListInstanceSettings(ctx, &store.FindInstanceSetting{})
require.NoError(t, err)
require.Equal(t, initialCount+2, len(list))
ts.Close()
}
func TestInstanceSettingEdgeCases(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Case 1: General Setting with special characters and Unicode
specialScript := `<script>alert("你好"); var x = 'test\'s';</script>`
specialStyle := `body { font-family: "Noto Sans SC", sans-serif; content: "\u2764"; }`
_, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
AdditionalScript: specialScript,
AdditionalStyle: specialStyle,
},
},
})
require.NoError(t, err)
generalSetting, err := ts.GetInstanceGeneralSetting(ctx)
require.NoError(t, err)
require.Equal(t, specialScript, generalSetting.AdditionalScript)
require.Equal(t, specialStyle, generalSetting.AdditionalStyle)
// Case 2: Memo Related Setting with Unicode reactions
unicodeReactions := []string{"🐱", "🐶", "🦊", "🦄"}
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_MEMO_RELATED,
Value: &storepb.InstanceSetting_MemoRelatedSetting{
MemoRelatedSetting: &storepb.InstanceMemoRelatedSetting{
ContentLengthLimit: 1000,
Reactions: unicodeReactions,
},
},
})
require.NoError(t, err)
memoSetting, err := ts.GetInstanceMemoRelatedSetting(ctx)
require.NoError(t, err)
require.Equal(t, unicodeReactions, memoSetting.Reactions)
ts.Close()
}

50
store/test/main_test.go Normal file
View File

@@ -0,0 +1,50 @@
package test
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"testing"
)
func TestMain(m *testing.M) {
// If DRIVER is set, run tests for that driver only
if os.Getenv("DRIVER") != "" {
defer TerminateContainers()
m.Run() //nolint:revive // Exit code is handled by test runner
return
}
// No DRIVER set - run tests for all drivers sequentially
runAllDrivers()
}
func runAllDrivers() {
drivers := []string{"sqlite", "mysql", "postgres"}
_, currentFile, _, _ := runtime.Caller(0)
projectRoot := filepath.Dir(filepath.Dir(filepath.Dir(currentFile)))
var failed []string
for _, driver := range drivers {
fmt.Printf("\n==================== %s ====================\n\n", driver)
cmd := exec.Command("go", "test", "-v", "-count=1", "./store/test/...")
cmd.Dir = projectRoot
cmd.Env = append(os.Environ(), "DRIVER="+driver)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
failed = append(failed, driver)
}
}
fmt.Println()
if len(failed) > 0 {
fmt.Printf("FAIL: %v\n", failed)
panic("some drivers failed")
}
fmt.Println("PASS: all drivers")
}

View File

@@ -0,0 +1,939 @@
package test
import (
"testing"
"time"
"github.com/stretchr/testify/require"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// =============================================================================
// Content Field Tests
// Schema: content (string, supports contains)
// =============================================================================
func TestMemoFilterContentContains(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with different content
tc.CreateMemo(NewMemoBuilder("memo-hello", tc.User.ID).Content("Hello world"))
tc.CreateMemo(NewMemoBuilder("memo-goodbye", tc.User.ID).Content("Goodbye world"))
tc.CreateMemo(NewMemoBuilder("memo-test", tc.User.ID).Content("Testing content"))
// Test: content.contains("Hello") - single match
memos := tc.ListWithFilter(`content.contains("Hello")`)
require.Len(t, memos, 1)
require.Contains(t, memos[0].Content, "Hello")
// Test: content.contains("world") - multiple matches
memos = tc.ListWithFilter(`content.contains("world")`)
require.Len(t, memos, 2)
// Test: content.contains("nonexistent") - no matches
memos = tc.ListWithFilter(`content.contains("nonexistent")`)
require.Len(t, memos, 0)
}
func TestMemoFilterContentSpecialCharacters(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-special", tc.User.ID).Content("Special chars: @#$%^&*()"))
memos := tc.ListWithFilter(`content.contains("@#$%")`)
require.Len(t, memos, 1)
}
func TestMemoFilterContentUnicode(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-unicode", tc.User.ID).Content("Unicode test: 你好世界 🌍"))
memos := tc.ListWithFilter(`content.contains("你好")`)
require.Len(t, memos, 1)
}
func TestMemoFilterContentUnicodeCaseFold(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-unicode-case", tc.User.ID).Content("Привет Мир"))
memos := tc.ListWithFilter(`content.contains("привет")`)
require.Len(t, memos, 1)
}
func TestMemoFilterContentCaseSensitivity(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-case", tc.User.ID).Content("MixedCase Content"))
// Exact match
memos := tc.ListWithFilter(`content.contains("MixedCase")`)
require.Len(t, memos, 1)
// Lowercase match (depends on DB collation, usually case-insensitive in default installs but good to verify behavior)
// SQLite default LIKE is case-insensitive for ASCII.
memosLower := tc.ListWithFilter(`content.contains("mixedcase")`)
// We just verify it doesn't crash; strict case sensitivity expectation depends on DB config.
// For standard Memos setup (SQLite), it's often case-insensitive.
// Let's check if we get a result or not to characterize current behavior.
if len(memosLower) > 0 {
require.Equal(t, "MixedCase Content", memosLower[0].Content)
}
}
// =============================================================================
// Visibility Field Tests
// Schema: visibility (string, ==, !=)
// =============================================================================
func TestMemoFilterVisibilityEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-public", tc.User.ID).Content("Public memo").Visibility(store.Public))
tc.CreateMemo(NewMemoBuilder("memo-private", tc.User.ID).Content("Private memo").Visibility(store.Private))
tc.CreateMemo(NewMemoBuilder("memo-protected", tc.User.ID).Content("Protected memo").Visibility(store.Protected))
// Test: visibility == "PUBLIC"
memos := tc.ListWithFilter(`visibility == "PUBLIC"`)
require.Len(t, memos, 1)
require.Equal(t, store.Public, memos[0].Visibility)
// Test: visibility == "PRIVATE"
memos = tc.ListWithFilter(`visibility == "PRIVATE"`)
require.Len(t, memos, 1)
require.Equal(t, store.Private, memos[0].Visibility)
}
func TestMemoFilterVisibilityNotEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-public", tc.User.ID).Content("Public memo").Visibility(store.Public))
tc.CreateMemo(NewMemoBuilder("memo-private", tc.User.ID).Content("Private memo").Visibility(store.Private))
memos := tc.ListWithFilter(`visibility != "PUBLIC"`)
require.Len(t, memos, 1)
require.Equal(t, store.Private, memos[0].Visibility)
}
func TestMemoFilterVisibilityInList(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-pub", tc.User.ID).Visibility(store.Public))
tc.CreateMemo(NewMemoBuilder("memo-priv", tc.User.ID).Visibility(store.Private))
tc.CreateMemo(NewMemoBuilder("memo-prot", tc.User.ID).Visibility(store.Protected))
memos := tc.ListWithFilter(`visibility in ["PUBLIC", "PRIVATE"]`)
require.Len(t, memos, 2)
}
// =============================================================================
// Pinned Field Tests
// Schema: pinned (bool column, ==, !=, predicate)
// =============================================================================
func TestMemoFilterPinnedEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned", tc.User.ID).Content("Pinned memo"))
tc.PinMemo(pinnedMemo.ID)
tc.CreateMemo(NewMemoBuilder("memo-unpinned", tc.User.ID).Content("Unpinned memo"))
// Test: pinned == true
memos := tc.ListWithFilter(`pinned == true`)
require.Len(t, memos, 1)
require.True(t, memos[0].Pinned)
// Test: pinned == false
memos = tc.ListWithFilter(`pinned == false`)
require.Len(t, memos, 1)
require.False(t, memos[0].Pinned)
}
func TestMemoFilterPinnedPredicate(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned", tc.User.ID).Content("Pinned memo"))
tc.PinMemo(pinnedMemo.ID)
tc.CreateMemo(NewMemoBuilder("memo-unpinned", tc.User.ID).Content("Unpinned memo"))
memos := tc.ListWithFilter(`pinned`)
require.Len(t, memos, 1)
require.True(t, memos[0].Pinned)
}
// =============================================================================
// Creator ID Field Tests
// Schema: creator_id (int, ==, !=)
// =============================================================================
func TestMemoFilterCreatorIdEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
user2, err := tc.Store.CreateUser(tc.Ctx, &store.User{
Username: "user2",
Role: store.RoleUser,
Email: "user2@example.com",
Nickname: "User 2",
})
require.NoError(t, err)
tc.CreateMemo(NewMemoBuilder("memo-user1", tc.User.ID).Content("User 1 memo"))
tc.CreateMemo(NewMemoBuilder("memo-user2", user2.ID).Content("User 2 memo"))
memos := tc.ListWithFilter(`creator_id == ` + formatInt(int(tc.User.ID)))
require.Len(t, memos, 1)
require.Equal(t, tc.User.ID, memos[0].CreatorID)
}
func TestMemoFilterCreatorIdNotEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
user2, err := tc.Store.CreateUser(tc.Ctx, &store.User{
Username: "user2",
Role: store.RoleUser,
Email: "user2@example.com",
Nickname: "User 2",
})
require.NoError(t, err)
tc.CreateMemo(NewMemoBuilder("memo-user1", tc.User.ID).Content("User 1 memo"))
tc.CreateMemo(NewMemoBuilder("memo-user2", user2.ID).Content("User 2 memo"))
memos := tc.ListWithFilter(`creator_id != ` + formatInt(int(tc.User.ID)))
require.Len(t, memos, 1)
require.Equal(t, user2.ID, memos[0].CreatorID)
}
// =============================================================================
// Tags Field Tests
// Schema: tags (JSON list), tag (virtual alias)
// Operators: tag in [...], "value" in tags
// =============================================================================
func TestMemoFilterTagInList(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-work", tc.User.ID).Content("Work memo").Tags("work", "important"))
tc.CreateMemo(NewMemoBuilder("memo-personal", tc.User.ID).Content("Personal memo").Tags("personal", "fun"))
tc.CreateMemo(NewMemoBuilder("memo-no-tags", tc.User.ID).Content("No tags"))
// Test: tag in ["work"]
memos := tc.ListWithFilter(`tag in ["work"]`)
require.Len(t, memos, 1)
require.Contains(t, memos[0].Payload.Tags, "work")
// Test: tag in ["work", "personal"]
memos = tc.ListWithFilter(`tag in ["work", "personal"]`)
require.Len(t, memos, 2)
}
func TestMemoFilterElementInTags(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-tagged", tc.User.ID).Content("Tagged memo").Tags("project", "todo"))
tc.CreateMemo(NewMemoBuilder("memo-untagged", tc.User.ID).Content("Untagged memo"))
// Test: "project" in tags
memos := tc.ListWithFilter(`"project" in tags`)
require.Len(t, memos, 1)
// Test: "nonexistent" in tags
memos = tc.ListWithFilter(`"nonexistent" in tags`)
require.Len(t, memos, 0)
}
func TestMemoFilterHierarchicalTags(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-book", tc.User.ID).Content("Book memo").Tags("book"))
tc.CreateMemo(NewMemoBuilder("memo-book-fiction", tc.User.ID).Content("Fiction book memo").Tags("book/fiction"))
// Test: tag in ["book"] should match both (hierarchical matching)
memos := tc.ListWithFilter(`tag in ["book"]`)
require.Len(t, memos, 2)
}
func TestMemoFilterEmptyTags(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-empty-tags", tc.User.ID).Content("Empty tags").Tags())
memos := tc.ListWithFilter(`tag in ["anything"]`)
require.Len(t, memos, 0)
}
// =============================================================================
// JSON Bool Field Tests
// Schema: has_task_list, has_link, has_code, has_incomplete_tasks
// Operators: ==, !=, predicate
// =============================================================================
func TestMemoFilterHasTaskList(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-with-tasks", tc.User.ID).
Content("- [ ] Task 1\n- [x] Task 2").
Property(func(p *storepb.MemoPayload_Property) { p.HasTaskList = true }))
tc.CreateMemo(NewMemoBuilder("memo-no-tasks", tc.User.ID).Content("No tasks here"))
// Test: has_task_list (predicate)
memos := tc.ListWithFilter(`has_task_list`)
require.Len(t, memos, 1)
require.True(t, memos[0].Payload.Property.HasTaskList)
// Test: has_task_list == true
memos = tc.ListWithFilter(`has_task_list == true`)
require.Len(t, memos, 1)
// Note: has_task_list == false is not tested because JSON boolean fields
// with false value may not be queryable when the field is not present in JSON
}
func TestMemoFilterHasLink(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-with-link", tc.User.ID).
Content("Check out https://example.com").
Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true }))
tc.CreateMemo(NewMemoBuilder("memo-no-link", tc.User.ID).Content("No links"))
memos := tc.ListWithFilter(`has_link`)
require.Len(t, memos, 1)
require.True(t, memos[0].Payload.Property.HasLink)
}
func TestMemoFilterHasCode(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-with-code", tc.User.ID).
Content("```go\nfmt.Println(\"Hello\")\n```").
Property(func(p *storepb.MemoPayload_Property) { p.HasCode = true }))
tc.CreateMemo(NewMemoBuilder("memo-no-code", tc.User.ID).Content("No code"))
memos := tc.ListWithFilter(`has_code`)
require.Len(t, memos, 1)
require.True(t, memos[0].Payload.Property.HasCode)
}
func TestMemoFilterHasIncompleteTasks(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-incomplete", tc.User.ID).
Content("- [ ] Incomplete task").
Property(func(p *storepb.MemoPayload_Property) {
p.HasTaskList = true
p.HasIncompleteTasks = true
}))
tc.CreateMemo(NewMemoBuilder("memo-complete", tc.User.ID).
Content("- [x] Complete task").
Property(func(p *storepb.MemoPayload_Property) {
p.HasTaskList = true
p.HasIncompleteTasks = false
}))
memos := tc.ListWithFilter(`has_incomplete_tasks`)
require.Len(t, memos, 1)
require.True(t, memos[0].Payload.Property.HasIncompleteTasks)
}
func TestMemoFilterCombinedJSONBool(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Memo with all properties
tc.CreateMemo(NewMemoBuilder("memo-all-props", tc.User.ID).
Content("All properties").
Property(func(p *storepb.MemoPayload_Property) {
p.HasLink = true
p.HasTaskList = true
p.HasCode = true
p.HasIncompleteTasks = true
}))
// Memo with only link
tc.CreateMemo(NewMemoBuilder("memo-only-link", tc.User.ID).
Content("Only link").
Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true }))
// Test: has_link && has_code
memos := tc.ListWithFilter(`has_link && has_code`)
require.Len(t, memos, 1)
// Test: has_task_list && has_incomplete_tasks
memos = tc.ListWithFilter(`has_task_list && has_incomplete_tasks`)
require.Len(t, memos, 1)
// Test: has_link || has_code
memos = tc.ListWithFilter(`has_link || has_code`)
require.Len(t, memos, 2)
}
// =============================================================================
// Timestamp Field Tests
// Schema: created_ts, updated_ts (timestamp, all comparison operators)
// Functions: now(), arithmetic (+, -, *)
// =============================================================================
func TestMemoFilterCreatedTsComparison(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
now := time.Now().Unix()
tc.CreateMemo(NewMemoBuilder("memo-ts", tc.User.ID).Content("Timestamp test"))
// Test: created_ts < future (should match)
memos := tc.ListWithFilter(`created_ts < ` + formatInt64(now+3600))
require.Len(t, memos, 1)
// Test: created_ts > past (should match)
memos = tc.ListWithFilter(`created_ts > ` + formatInt64(now-3600))
require.Len(t, memos, 1)
// Test: created_ts > future (should not match)
memos = tc.ListWithFilter(`created_ts > ` + formatInt64(now+3600))
require.Len(t, memos, 0)
}
func TestMemoFilterCreatedTsWithNow(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-ts-test", tc.User.ID).Content("Timestamp test"))
// Test: created_ts < now() + 5 (buffer for container clock drift)
memos := tc.ListWithFilter(`created_ts < now() + 5`)
require.Len(t, memos, 1)
// Test: created_ts > now() + 5 (should not match)
memos = tc.ListWithFilter(`created_ts > now() + 5`)
require.Len(t, memos, 0)
}
func TestMemoFilterCreatedTsArithmetic(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-ts-arith", tc.User.ID).Content("Timestamp arithmetic test"))
// Test: created_ts >= now() - 3600 (memos created in last hour)
memos := tc.ListWithFilter(`created_ts >= now() - 3600`)
require.Len(t, memos, 1)
// Test: created_ts < now() - 86400 (memos older than 1 day - should be empty)
memos = tc.ListWithFilter(`created_ts < now() - 86400`)
require.Len(t, memos, 0)
// Test: Multiplication - created_ts >= now() - 60 * 60
memos = tc.ListWithFilter(`created_ts >= now() - 60 * 60`)
require.Len(t, memos, 1)
}
func TestMemoFilterUpdatedTs(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
memo := tc.CreateMemo(NewMemoBuilder("memo-updated", tc.User.ID).Content("Will be updated"))
// Update the memo
newContent := "Updated content"
err := tc.Store.UpdateMemo(tc.Ctx, &store.UpdateMemo{
ID: memo.ID,
Content: &newContent,
})
require.NoError(t, err)
// Test: updated_ts >= now() - 60 (updated in last minute)
memos := tc.ListWithFilter(`updated_ts >= now() - 60`)
require.Len(t, memos, 1)
// Test: updated_ts > now() + 3600 (should be empty)
memos = tc.ListWithFilter(`updated_ts > now() + 3600`)
require.Len(t, memos, 0)
}
func TestMemoFilterAllComparisonOperators(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-ops", tc.User.ID).Content("Comparison operators test"))
// Test: < (less than)
memos := tc.ListWithFilter(`created_ts < now() + 3600`)
require.Len(t, memos, 1)
// Test: <= (less than or equal) with buffer for clock drift
memos = tc.ListWithFilter(`created_ts < now() + 5`)
require.Len(t, memos, 1)
// Test: > (greater than)
memos = tc.ListWithFilter(`created_ts > now() - 3600`)
require.Len(t, memos, 1)
// Test: >= (greater than or equal)
memos = tc.ListWithFilter(`created_ts >= now() - 60`)
require.Len(t, memos, 1)
}
// =============================================================================
// Logical Operator Tests
// Operators: && (AND), || (OR), ! (NOT)
// =============================================================================
func TestMemoFilterLogicalAnd(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned-public", tc.User.ID).Content("Pinned public"))
tc.PinMemo(pinnedMemo.ID)
tc.CreateMemo(NewMemoBuilder("memo-unpinned-public", tc.User.ID).Content("Unpinned public"))
memos := tc.ListWithFilter(`pinned && visibility == "PUBLIC"`)
require.Len(t, memos, 1)
require.True(t, memos[0].Pinned)
}
func TestMemoFilterLogicalOr(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-public", tc.User.ID).Visibility(store.Public))
tc.CreateMemo(NewMemoBuilder("memo-private", tc.User.ID).Visibility(store.Private))
tc.CreateMemo(NewMemoBuilder("memo-protected", tc.User.ID).Visibility(store.Protected))
memos := tc.ListWithFilter(`visibility == "PUBLIC" || visibility == "PRIVATE"`)
require.Len(t, memos, 2)
}
func TestMemoFilterLogicalNot(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned", tc.User.ID).Content("Pinned"))
tc.PinMemo(pinnedMemo.ID)
tc.CreateMemo(NewMemoBuilder("memo-unpinned", tc.User.ID).Content("Unpinned"))
memos := tc.ListWithFilter(`!pinned`)
require.Len(t, memos, 1)
require.False(t, memos[0].Pinned)
}
func TestMemoFilterNegatedComparison(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-public", tc.User.ID).Visibility(store.Public))
tc.CreateMemo(NewMemoBuilder("memo-private", tc.User.ID).Visibility(store.Private))
memos := tc.ListWithFilter(`!(visibility == "PUBLIC")`)
require.Len(t, memos, 1)
require.Equal(t, store.Private, memos[0].Visibility)
}
func TestMemoFilterComplexLogical(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create pinned public memo with tags
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned-tagged", tc.User.ID).
Content("Pinned and tagged").Tags("important"))
tc.PinMemo(pinnedMemo.ID)
// Create unpinned memo with same tag
tc.CreateMemo(NewMemoBuilder("memo-unpinned-tagged", tc.User.ID).
Content("Unpinned but tagged").Tags("important"))
// Create pinned memo without tag
pinned2 := tc.CreateMemo(NewMemoBuilder("memo-pinned-untagged", tc.User.ID).Content("Pinned but untagged"))
tc.PinMemo(pinned2.ID)
// Test: pinned && tag in ["important"]
memos := tc.ListWithFilter(`pinned && tag in ["important"]`)
require.Len(t, memos, 1)
// Test: (pinned || tag in ["important"]) && visibility == "PUBLIC"
memos = tc.ListWithFilter(`(pinned || tag in ["important"]) && visibility == "PUBLIC"`)
require.Len(t, memos, 3)
// Test: De Morgan's Law ! (A || B) == !A && !B
// ! (pinned || has_task_list)
tc.CreateMemo(NewMemoBuilder("memo-no-props", tc.User.ID).Content("Nothing special"))
memos = tc.ListWithFilter(`!(pinned || has_task_list)`)
require.Len(t, memos, 2) // Unpinned-tagged + Nothing special (pinned-untagged is pinned)
}
// =============================================================================
// Tag Comprehension Tests (exists macro)
// Schema: tags (list of strings, supports exists/all macros with predicates)
// =============================================================================
func TestMemoFilterTagsExistsStartsWith(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with different tags
tc.CreateMemo(NewMemoBuilder("memo-archive1", tc.User.ID).
Content("Archived project memo").
Tags("archive/project", "done"))
tc.CreateMemo(NewMemoBuilder("memo-archive2", tc.User.ID).
Content("Archived work memo").
Tags("archive/work", "old"))
tc.CreateMemo(NewMemoBuilder("memo-active", tc.User.ID).
Content("Active project memo").
Tags("project/active", "todo"))
tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID).
Content("Homelab memo").
Tags("homelab/memos", "tech"))
// Test: tags.exists(t, t.startsWith("archive")) - should match archived memos
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`)
require.Len(t, memos, 2, "Should find 2 archived memos")
for _, memo := range memos {
hasArchiveTag := false
for _, tag := range memo.Payload.Tags {
if len(tag) >= 7 && tag[:7] == "archive" {
hasArchiveTag = true
break
}
}
require.True(t, hasArchiveTag, "Memo should have tag starting with 'archive'")
}
// Test: !tags.exists(t, t.startsWith("archive")) - should match non-archived memos
memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`)
require.Len(t, memos, 2, "Should find 2 non-archived memos")
// Test: tags.exists(t, t.startsWith("project")) - should match project memos
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("project"))`)
require.Len(t, memos, 1, "Should find 1 project memo")
// Test: tags.exists(t, t.startsWith("homelab")) - should match homelab memos
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("homelab"))`)
require.Len(t, memos, 1, "Should find 1 homelab memo")
// Test: tags.exists(t, t.startsWith("nonexistent")) - should match nothing
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("nonexistent"))`)
require.Len(t, memos, 0, "Should find no memos")
}
func TestMemoFilterTagsExistsContains(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with different tags
tc.CreateMemo(NewMemoBuilder("memo-todo1", tc.User.ID).
Content("Todo task 1").
Tags("project/todo", "urgent"))
tc.CreateMemo(NewMemoBuilder("memo-todo2", tc.User.ID).
Content("Todo task 2").
Tags("work/todo-list", "pending"))
tc.CreateMemo(NewMemoBuilder("memo-done", tc.User.ID).
Content("Done task").
Tags("project/completed", "done"))
// Test: tags.exists(t, t.contains("todo")) - should match todos
memos := tc.ListWithFilter(`tags.exists(t, t.contains("todo"))`)
require.Len(t, memos, 2, "Should find 2 todo memos")
// Test: tags.exists(t, t.contains("done")) - should match done
memos = tc.ListWithFilter(`tags.exists(t, t.contains("done"))`)
require.Len(t, memos, 1, "Should find 1 done memo")
// Test: !tags.exists(t, t.contains("todo")) - should exclude todos
memos = tc.ListWithFilter(`!tags.exists(t, t.contains("todo"))`)
require.Len(t, memos, 1, "Should find 1 non-todo memo")
}
func TestMemoFilterTagsExistsEndsWith(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with different tag endings
tc.CreateMemo(NewMemoBuilder("memo-bug", tc.User.ID).
Content("Bug report").
Tags("project/bug", "critical"))
tc.CreateMemo(NewMemoBuilder("memo-debug", tc.User.ID).
Content("Debug session").
Tags("work/debug", "dev"))
tc.CreateMemo(NewMemoBuilder("memo-feature", tc.User.ID).
Content("New feature").
Tags("project/feature", "new"))
// Test: tags.exists(t, t.endsWith("bug")) - should match bug-related tags
memos := tc.ListWithFilter(`tags.exists(t, t.endsWith("bug"))`)
require.Len(t, memos, 2, "Should find 2 bug-related memos")
// Test: tags.exists(t, t.endsWith("feature")) - should match feature
memos = tc.ListWithFilter(`tags.exists(t, t.endsWith("feature"))`)
require.Len(t, memos, 1, "Should find 1 feature memo")
// Test: !tags.exists(t, t.endsWith("bug")) - should exclude bug-related
memos = tc.ListWithFilter(`!tags.exists(t, t.endsWith("bug"))`)
require.Len(t, memos, 1, "Should find 1 non-bug memo")
}
func TestMemoFilterTagsExistsCombinedWithOtherFilters(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with tags and other properties
tc.CreateMemo(NewMemoBuilder("memo-archived-old", tc.User.ID).
Content("Old archived memo").
Tags("archive/old", "done"))
tc.CreateMemo(NewMemoBuilder("memo-archived-recent", tc.User.ID).
Content("Recent archived memo with TODO").
Tags("archive/recent", "done"))
tc.CreateMemo(NewMemoBuilder("memo-active-todo", tc.User.ID).
Content("Active TODO").
Tags("project/active", "todo"))
// Test: Combine tag filter with content filter
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && content.contains("TODO")`)
require.Len(t, memos, 1, "Should find 1 archived memo with TODO in content")
// Test: OR condition with tag filters
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) || tags.exists(t, t.contains("todo"))`)
require.Len(t, memos, 3, "Should find all memos (archived or with todo tag)")
// Test: Complex filter - archived but not containing "Recent"
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && !content.contains("Recent")`)
require.Len(t, memos, 1, "Should find 1 old archived memo")
}
func TestMemoFilterTagsExistsEmptyAndNullCases(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memo with no tags
tc.CreateMemo(NewMemoBuilder("memo-no-tags", tc.User.ID).
Content("Memo without tags"))
// Create memo with tags
tc.CreateMemo(NewMemoBuilder("memo-with-tags", tc.User.ID).
Content("Memo with tags").
Tags("tag1", "tag2"))
// Test: tags.exists should not match memos without tags
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("tag"))`)
require.Len(t, memos, 1, "Should only find memo with tags")
// Test: Negation should match memos without matching tags
memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("tag"))`)
require.Len(t, memos, 1, "Should find memo without matching tags")
}
func TestMemoFilterIssue5480_ArchiveWorkflow(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create a realistic scenario as described in issue #5480
// User has hierarchical tags and archives memos by prefixing with "archive"
// Active memos
tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID).
Content("Setting up Memos").
Tags("homelab/memos", "tech"))
tc.CreateMemo(NewMemoBuilder("memo-project-alpha", tc.User.ID).
Content("Project Alpha notes").
Tags("work/project-alpha", "active"))
// Archived memos (user prefixed tags with "archive")
tc.CreateMemo(NewMemoBuilder("memo-old-homelab", tc.User.ID).
Content("Old homelab setup").
Tags("archive/homelab/old-server", "done"))
tc.CreateMemo(NewMemoBuilder("memo-old-project", tc.User.ID).
Content("Old project beta").
Tags("archive/work/project-beta", "completed"))
tc.CreateMemo(NewMemoBuilder("memo-archived-personal", tc.User.ID).
Content("Archived personal note").
Tags("archive/personal/2024", "old"))
// Test: Filter out ALL archived memos using startsWith
memos := tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`)
require.Len(t, memos, 2, "Should only show active memos (not archived)")
for _, memo := range memos {
for _, tag := range memo.Payload.Tags {
require.NotContains(t, tag, "archive", "Active memos should not have archive prefix")
}
}
// Test: Show ONLY archived memos
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`)
require.Len(t, memos, 3, "Should find all archived memos")
for _, memo := range memos {
hasArchiveTag := false
for _, tag := range memo.Payload.Tags {
if len(tag) >= 7 && tag[:7] == "archive" {
hasArchiveTag = true
break
}
}
require.True(t, hasArchiveTag, "All returned memos should have archive prefix")
}
// Test: Filter archived homelab memos specifically
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive/homelab"))`)
require.Len(t, memos, 1, "Should find only archived homelab memos")
}
// =============================================================================
// Multiple Filters Tests
// =============================================================================
func TestMemoFilterMultipleFilters(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-public-hello", tc.User.ID).Content("Hello world").Visibility(store.Public))
tc.CreateMemo(NewMemoBuilder("memo-private-hello", tc.User.ID).Content("Hello private").Visibility(store.Private))
// Test: Multiple filters (applied as AND)
memos := tc.ListWithFilters(`content.contains("Hello")`, `visibility == "PUBLIC"`)
require.Len(t, memos, 1)
require.Contains(t, memos[0].Content, "Hello")
require.Equal(t, store.Public, memos[0].Visibility)
}
// =============================================================================
// Edge Cases
// =============================================================================
func TestMemoFilterNullPayload(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-null-payload", tc.User.ID).Content("Null payload"))
// Test: has_link should not crash and return no results
memos := tc.ListWithFilter(`has_link`)
require.Len(t, memos, 0)
}
func TestMemoFilterNoMatches(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-test", tc.User.ID).Content("Test content"))
memos := tc.ListWithFilter(`content.contains("nonexistent12345")`)
require.Len(t, memos, 0)
}
func TestMemoFilterJSONBooleanLogic(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// 1. Memo with task list (true) and NO link (null)
tc.CreateMemo(NewMemoBuilder("memo-task-only", tc.User.ID).
Content("Task only").
Property(func(p *storepb.MemoPayload_Property) { p.HasTaskList = true }))
// 2. Memo with link (true) and NO task list (null)
tc.CreateMemo(NewMemoBuilder("memo-link-only", tc.User.ID).
Content("Link only").
Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true }))
// 3. Memo with both (true)
tc.CreateMemo(NewMemoBuilder("memo-both", tc.User.ID).
Content("Both").
Property(func(p *storepb.MemoPayload_Property) {
p.HasTaskList = true
p.HasLink = true
}))
// 4. Memo with neither (null)
tc.CreateMemo(NewMemoBuilder("memo-neither", tc.User.ID).Content("Neither"))
// Test A: has_task_list || has_link
// Expected: 3 memos (task-only, link-only, both). Neither should be excluded.
// This specifically tests the NULL handling in OR logic (NULL || TRUE should be TRUE)
memos := tc.ListWithFilter(`has_task_list || has_link`)
require.Len(t, memos, 3, "Should find 3 memos with OR logic")
// Test B: !has_task_list
// Expected: 2 memos (link-only, neither). Memos where has_task_list is NULL or FALSE.
// Note: If NULL is not handled, !NULL is still NULL (false-y in WHERE), so "neither" might be missed depending on logic.
// In our implementation, we want missing fields to behave as false.
memos = tc.ListWithFilter(`!has_task_list`)
require.Len(t, memos, 2, "Should find 2 memos where task list is false or missing")
// Test C: has_task_list && !has_link
// Expected: 1 memo (task-only).
memos = tc.ListWithFilter(`has_task_list && !has_link`)
require.Len(t, memos, 1, "Should find 1 memo (task only)")
}

View File

@@ -0,0 +1,682 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestMemoRelationStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoCreate := &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
}
memo, err := ts.CreateMemo(ctx, memoCreate)
require.NoError(t, err)
require.Equal(t, memoCreate.Content, memo.Content)
relatedMemoCreate := &store.Memo{
UID: "related-memo",
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
}
relatedMemo, err := ts.CreateMemo(ctx, relatedMemoCreate)
require.NoError(t, err)
require.Equal(t, relatedMemoCreate.Content, relatedMemo.Content)
commentMemoCreate := &store.Memo{
UID: "comment-memo",
CreatorID: user.ID,
Content: "comment memo content",
Visibility: store.Public,
}
commentMemo, err := ts.CreateMemo(ctx, commentMemoCreate)
require.NoError(t, err)
require.Equal(t, commentMemoCreate.Content, commentMemo.Content)
// Reference relation.
referenceRelation := &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
}
_, err = ts.UpsertMemoRelation(ctx, referenceRelation)
require.NoError(t, err)
// Comment relation.
commentRelation := &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: commentMemo.ID,
Type: store.MemoRelationComment,
}
_, err = ts.UpsertMemoRelation(ctx, commentRelation)
require.NoError(t, err)
ts.Close()
}
func TestMemoRelationListByMemoID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create main memo
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create related memos
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-1",
CreatorID: user.ID,
Content: "related memo 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-2",
CreatorID: user.ID,
Content: "related memo 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relations
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo1.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo2.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
// List by memo ID
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Equal(t, 2, len(relations))
// List by type
refType := store.MemoRelationReference
refRelations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
Type: &refType,
})
require.NoError(t, err)
require.Equal(t, 1, len(refRelations))
require.Equal(t, store.MemoRelationReference, refRelations[0].Type)
// List by related memo ID
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &relatedMemo1.ID,
})
require.NoError(t, err)
require.Equal(t, 1, len(relations))
ts.Close()
}
func TestMemoRelationDelete(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create memos
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo",
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relation
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Verify relation exists
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Equal(t, 1, len(relations))
// Delete relation by memo ID
relType := store.MemoRelationReference
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &mainMemo.ID,
RelatedMemoID: &relatedMemo.ID,
Type: &relType,
})
require.NoError(t, err)
// Verify relation is deleted
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Equal(t, 0, len(relations))
ts.Close()
}
func TestMemoRelationDifferentTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo",
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create reference relation
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Create comment relation (same memos, different type - should be allowed)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
// Verify both relations exist
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Equal(t, 2, len(relations))
ts.Close()
}
func TestMemoRelationUpsertSameRelation(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo",
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relation
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Upsert the same relation again (should not create duplicate)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Verify only one relation exists
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 1)
ts.Close()
}
func TestMemoRelationDeleteByType(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-1",
CreatorID: user.ID,
Content: "related memo 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-2",
CreatorID: user.ID,
Content: "related memo 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create reference relations
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo1.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Create comment relation
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo2.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
// Delete only reference type relations
refType := store.MemoRelationReference
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &mainMemo.ID,
Type: &refType,
})
require.NoError(t, err)
// Verify only comment relation remains
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 1)
require.Equal(t, store.MemoRelationComment, relations[0].Type)
ts.Close()
}
func TestMemoRelationDeleteByMemoID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memo1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-1",
CreatorID: user.ID,
Content: "memo 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
memo2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-2",
CreatorID: user.ID,
Content: "memo 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo",
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relations for both memos
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo1.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo2.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Delete all relations for memo1
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &memo1.ID,
})
require.NoError(t, err)
// Verify memo1's relations are gone
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo1.ID,
})
require.NoError(t, err)
require.Len(t, relations, 0)
// Verify memo2's relations still exist
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo2.ID,
})
require.NoError(t, err)
require.Len(t, relations, 1)
ts.Close()
}
func TestMemoRelationListByRelatedMemoID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a memo that will be referenced by others
targetMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "target-memo",
CreatorID: user.ID,
Content: "target memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create memos that reference the target
referrer1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "referrer-1",
CreatorID: user.ID,
Content: "referrer 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
referrer2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "referrer-2",
CreatorID: user.ID,
Content: "referrer 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relations pointing to target
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: referrer1.ID,
RelatedMemoID: targetMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: referrer2.ID,
RelatedMemoID: targetMemo.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
// List by related memo ID (find all memos that reference the target)
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &targetMemo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 2)
ts.Close()
}
func TestMemoRelationListCombinedFilters(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-1",
CreatorID: user.ID,
Content: "related memo 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-2",
CreatorID: user.ID,
Content: "related memo 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create multiple relations
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo1.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo2.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
// List with MemoID and Type filter
refType := store.MemoRelationReference
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
Type: &refType,
})
require.NoError(t, err)
require.Len(t, relations, 1)
require.Equal(t, relatedMemo1.ID, relations[0].RelatedMemoID)
// List with MemoID, RelatedMemoID, and Type filter
commentType := store.MemoRelationComment
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
RelatedMemoID: &relatedMemo2.ID,
Type: &commentType,
})
require.NoError(t, err)
require.Len(t, relations, 1)
ts.Close()
}
func TestMemoRelationListEmpty(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-no-relations",
CreatorID: user.ID,
Content: "memo with no relations",
Visibility: store.Public,
})
require.NoError(t, err)
// List relations for memo with none
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 0)
ts.Close()
}
func TestMemoRelationBidirectional(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoA, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-a",
CreatorID: user.ID,
Content: "memo A content",
Visibility: store.Public,
})
require.NoError(t, err)
memoB, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-b",
CreatorID: user.ID,
Content: "memo B content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relation A -> B
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memoA.ID,
RelatedMemoID: memoB.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Create relation B -> A (reverse direction)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memoB.ID,
RelatedMemoID: memoA.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Verify A -> B exists
relationsFromA, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memoA.ID,
})
require.NoError(t, err)
require.Len(t, relationsFromA, 1)
require.Equal(t, memoB.ID, relationsFromA[0].RelatedMemoID)
// Verify B -> A exists
relationsFromB, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memoB.ID,
})
require.NoError(t, err)
require.Len(t, relationsFromB, 1)
require.Equal(t, memoA.ID, relationsFromB[0].RelatedMemoID)
ts.Close()
}
func TestMemoRelationMultipleRelationsToSameMemo(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create multiple memos that all relate to the main memo
for i := 1; i <= 5; i++ {
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-" + string(rune('0'+i)),
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
}
// Verify all 5 relations exist
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 5)
ts.Close()
}

478
store/test/memo_test.go Normal file
View File

@@ -0,0 +1,478 @@
package test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestMemoStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoCreate := &store.Memo{
UID: "test-resource-name",
CreatorID: user.ID,
Content: "test_content",
Visibility: store.Public,
}
memo, err := ts.CreateMemo(ctx, memoCreate)
require.NoError(t, err)
require.Equal(t, memoCreate.Content, memo.Content)
memoPatchContent := "test_content_2"
memoPatch := &store.UpdateMemo{
ID: memo.ID,
Content: &memoPatchContent,
}
err = ts.UpdateMemo(ctx, memoPatch)
require.NoError(t, err)
memo, err = ts.GetMemo(ctx, &store.FindMemo{
ID: &memo.ID,
})
require.NoError(t, err)
require.NotNil(t, memo)
memoList, err := ts.ListMemos(ctx, &store.FindMemo{
CreatorID: &user.ID,
})
require.NoError(t, err)
require.Equal(t, 1, len(memoList))
require.Equal(t, memo, memoList[0])
err = ts.DeleteMemo(ctx, &store.DeleteMemo{
ID: memo.ID,
})
require.NoError(t, err)
memoList, err = ts.ListMemos(ctx, &store.FindMemo{
CreatorID: &user.ID,
})
require.NoError(t, err)
require.Equal(t, 0, len(memoList))
memoList, err = ts.ListMemos(ctx, &store.FindMemo{
CreatorID: &user.ID,
VisibilityList: []store.Visibility{store.Public},
})
require.NoError(t, err)
require.Equal(t, 0, len(memoList))
ts.Close()
}
func TestMemoListByTags(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoCreate := &store.Memo{
UID: "test-resource-name",
CreatorID: user.ID,
Content: "test_content",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"test_tag"},
},
}
memo, err := ts.CreateMemo(ctx, memoCreate)
require.NoError(t, err)
require.Equal(t, memoCreate.Content, memo.Content)
memo, err = ts.GetMemo(ctx, &store.FindMemo{
ID: &memo.ID,
})
require.NoError(t, err)
require.NotNil(t, memo)
memoList, err := ts.ListMemos(ctx, &store.FindMemo{
Filters: []string{"tag in [\"test_tag\"]"},
})
require.NoError(t, err)
require.Equal(t, 1, len(memoList))
require.Equal(t, memo, memoList[0])
ts.Close()
}
func TestDeleteMemoStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoCreate := &store.Memo{
UID: "test-resource-name",
CreatorID: user.ID,
Content: "test_content",
Visibility: store.Public,
}
memo, err := ts.CreateMemo(ctx, memoCreate)
require.NoError(t, err)
require.Equal(t, memoCreate.Content, memo.Content)
err = ts.DeleteMemo(ctx, &store.DeleteMemo{
ID: memo.ID,
})
require.NoError(t, err)
ts.Close()
}
func TestMemoGetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "test-memo-1",
CreatorID: user.ID,
Content: "test content",
Visibility: store.Public,
})
require.NoError(t, err)
// Get by ID
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, memo.ID, found.ID)
require.Equal(t, memo.Content, found.Content)
// Get non-existent
nonExistentID := int32(99999)
notFound, err := ts.GetMemo(ctx, &store.FindMemo{ID: &nonExistentID})
require.NoError(t, err)
require.Nil(t, notFound)
ts.Close()
}
func TestMemoGetByUID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
uid := "unique-memo-uid"
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: uid,
CreatorID: user.ID,
Content: "test content",
Visibility: store.Public,
})
require.NoError(t, err)
// Get by UID
found, err := ts.GetMemo(ctx, &store.FindMemo{UID: &uid})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, memo.UID, found.UID)
// Get non-existent UID
nonExistentUID := "non-existent-uid"
notFound, err := ts.GetMemo(ctx, &store.FindMemo{UID: &nonExistentUID})
require.NoError(t, err)
require.Nil(t, notFound)
ts.Close()
}
func TestMemoListByVisibility(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create memos with different visibilities
_, err = ts.CreateMemo(ctx, &store.Memo{
UID: "public-memo",
CreatorID: user.ID,
Content: "public content",
Visibility: store.Public,
})
require.NoError(t, err)
_, err = ts.CreateMemo(ctx, &store.Memo{
UID: "protected-memo",
CreatorID: user.ID,
Content: "protected content",
Visibility: store.Protected,
})
require.NoError(t, err)
_, err = ts.CreateMemo(ctx, &store.Memo{
UID: "private-memo",
CreatorID: user.ID,
Content: "private content",
Visibility: store.Private,
})
require.NoError(t, err)
// List public memos only
publicMemos, err := ts.ListMemos(ctx, &store.FindMemo{
VisibilityList: []store.Visibility{store.Public},
})
require.NoError(t, err)
require.Equal(t, 1, len(publicMemos))
require.Equal(t, store.Public, publicMemos[0].Visibility)
// List protected memos only
protectedMemos, err := ts.ListMemos(ctx, &store.FindMemo{
VisibilityList: []store.Visibility{store.Protected},
})
require.NoError(t, err)
require.Equal(t, 1, len(protectedMemos))
require.Equal(t, store.Protected, protectedMemos[0].Visibility)
// List public and protected (multiple visibility)
publicAndProtected, err := ts.ListMemos(ctx, &store.FindMemo{
VisibilityList: []store.Visibility{store.Public, store.Protected},
})
require.NoError(t, err)
require.Equal(t, 2, len(publicAndProtected))
// List all
allMemos, err := ts.ListMemos(ctx, &store.FindMemo{})
require.NoError(t, err)
require.Equal(t, 3, len(allMemos))
ts.Close()
}
func TestMemoListWithPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create 10 memos
for i := 0; i < 10; i++ {
_, err := ts.CreateMemo(ctx, &store.Memo{
UID: fmt.Sprintf("memo-%d", i),
CreatorID: user.ID,
Content: fmt.Sprintf("content %d", i),
Visibility: store.Public,
})
require.NoError(t, err)
}
// Test limit
limit := 5
limitedMemos, err := ts.ListMemos(ctx, &store.FindMemo{Limit: &limit})
require.NoError(t, err)
require.Equal(t, 5, len(limitedMemos))
// Test offset
offset := 3
offsetMemos, err := ts.ListMemos(ctx, &store.FindMemo{Limit: &limit, Offset: &offset})
require.NoError(t, err)
require.Equal(t, 5, len(offsetMemos))
// Verify offset works correctly (different memos)
require.NotEqual(t, limitedMemos[0].ID, offsetMemos[0].ID)
ts.Close()
}
func TestMemoUpdatePinned(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "pinnable-memo",
CreatorID: user.ID,
Content: "content",
Visibility: store.Public,
})
require.NoError(t, err)
require.False(t, memo.Pinned)
// Pin the memo
pinned := true
err = ts.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Pinned: &pinned,
})
require.NoError(t, err)
// Verify pinned
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
require.NoError(t, err)
require.True(t, found.Pinned)
// Unpin
unpinned := false
err = ts.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Pinned: &unpinned,
})
require.NoError(t, err)
found, err = ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
require.NoError(t, err)
require.False(t, found.Pinned)
ts.Close()
}
func TestMemoUpdateVisibility(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "visibility-memo",
CreatorID: user.ID,
Content: "content",
Visibility: store.Public,
})
require.NoError(t, err)
require.Equal(t, store.Public, memo.Visibility)
// Change to private
privateVisibility := store.Private
err = ts.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Visibility: &privateVisibility,
})
require.NoError(t, err)
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
require.NoError(t, err)
require.Equal(t, store.Private, found.Visibility)
// Change to protected
protectedVisibility := store.Protected
err = ts.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Visibility: &protectedVisibility,
})
require.NoError(t, err)
found, err = ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
require.NoError(t, err)
require.Equal(t, store.Protected, found.Visibility)
ts.Close()
}
func TestMemoInvalidUID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create memo with invalid UID (contains special characters)
_, err = ts.CreateMemo(ctx, &store.Memo{
UID: "invalid uid with spaces",
CreatorID: user.ID,
Content: "content",
Visibility: store.Public,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid uid")
ts.Close()
}
func TestMemoCreateWithCustomTimestamps(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
customCreatedTs := int64(1700000000) // 2023-11-14 22:13:20 UTC
customUpdatedTs := int64(1700000001)
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "custom-timestamp-memo",
CreatorID: user.ID,
Content: "content with custom timestamps",
Visibility: store.Public,
CreatedTs: customCreatedTs,
UpdatedTs: customUpdatedTs,
})
require.NoError(t, err)
require.Equal(t, customCreatedTs, memo.CreatedTs)
require.Equal(t, customUpdatedTs, memo.UpdatedTs)
// Fetch and verify timestamps are preserved
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, customCreatedTs, found.CreatedTs)
require.Equal(t, customUpdatedTs, found.UpdatedTs)
ts.Close()
}
func TestMemoCreateWithOnlyCreatedTs(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
customCreatedTs := int64(1609459200) // 2021-01-01 00:00:00 UTC
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "custom-created-ts-only",
CreatorID: user.ID,
Content: "content with custom created_ts only",
Visibility: store.Public,
CreatedTs: customCreatedTs,
})
require.NoError(t, err)
require.Equal(t, customCreatedTs, memo.CreatedTs)
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, customCreatedTs, found.CreatedTs)
ts.Close()
}
func TestMemoWithPayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create memo with tags in payload
tags := []string{"tag1", "tag2", "tag3"}
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-with-payload",
CreatorID: user.ID,
Content: "content with tags",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: tags,
},
})
require.NoError(t, err)
require.NotNil(t, memo.Payload)
require.Equal(t, tags, memo.Payload.Tags)
// Fetch and verify
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
require.NoError(t, err)
require.NotNil(t, found.Payload)
require.Equal(t, tags, found.Payload.Tags)
ts.Close()
}

200
store/test/migrator_test.go Normal file
View File

@@ -0,0 +1,200 @@
package test
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
// TestFreshInstall verifies that LATEST.sql applies correctly on a fresh database.
// This is essentially what NewTestingStore already does, but we make it explicit.
func TestFreshInstall(t *testing.T) {
t.Parallel()
ctx := context.Background()
// NewTestingStore creates a fresh database and runs Migrate()
// which applies LATEST.sql for uninitialized databases
ts := NewTestingStore(ctx, t)
// Verify migration completed successfully
currentSchemaVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
require.NotEmpty(t, currentSchemaVersion, "schema version should be set after fresh install")
// Verify we can read instance settings (basic sanity check)
instanceSetting, err := ts.GetInstanceBasicSetting(ctx)
require.NoError(t, err)
require.Equal(t, currentSchemaVersion, instanceSetting.SchemaVersion)
}
// TestMigrationReRun verifies that re-running the migration on an already
// migrated database does not fail or cause issues. This simulates a
// scenario where the server is restarted.
func TestMigrationReRun(t *testing.T) {
t.Parallel()
ctx := context.Background()
// Use the shared testing store which already runs migrations on init
ts := NewTestingStore(ctx, t)
// Get current version
initialVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
// Manually trigger migration again
err = ts.Migrate(ctx)
require.NoError(t, err, "re-running migration should not fail")
// Verify version hasn't changed (or at least is valid)
finalVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
require.Equal(t, initialVersion, finalVersion, "version should match after re-run")
}
// TestMigrationWithData verifies that migration preserves data integrity.
// Creates data, then re-runs migration and verifies data is still accessible.
func TestMigrationWithData(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create a user and memo before re-running migration
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err, "should create user")
originalMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "migration-data-test",
CreatorID: user.ID,
Content: "Data before migration re-run",
Visibility: store.Public,
})
require.NoError(t, err, "should create memo")
// Re-run migration
err = ts.Migrate(ctx)
require.NoError(t, err, "re-running migration should not fail")
// Verify data is still accessible
memo, err := ts.GetMemo(ctx, &store.FindMemo{UID: &originalMemo.UID})
require.NoError(t, err, "should retrieve memo after migration")
require.Equal(t, "Data before migration re-run", memo.Content, "memo content should be preserved")
}
// TestMigrationMultipleReRuns verifies that migration is idempotent
// even when run multiple times in succession.
func TestMigrationMultipleReRuns(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Get initial version
initialVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
// Run migration multiple times
for i := 0; i < 3; i++ {
err = ts.Migrate(ctx)
require.NoError(t, err, "migration run %d should not fail", i+1)
}
// Verify version is still correct
finalVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
require.Equal(t, initialVersion, finalVersion, "version should remain unchanged after multiple re-runs")
}
// TestMigrationFromStableVersion verifies that upgrading from a stable Memos version
// to the current version works correctly. This is the critical upgrade path test.
//
// Test flow:
// 1. Start a stable Memos container to create a database with the old schema
// 2. Stop the container and wait for cleanup
// 3. Use the store directly to run migration with current code
// 4. Verify the migration succeeded and data can be written
//
// Note: This test is skipped when running with -race flag because testcontainers
// has known race conditions in its reaper code that are outside our control.
func TestMigrationFromStableVersion(t *testing.T) {
// Skip for non-SQLite drivers (simplifies the test)
if getDriverFromEnv() != "sqlite" {
t.Skip("skipping upgrade test for non-sqlite driver")
}
// Skip if explicitly disabled (e.g., in environments without Docker)
if os.Getenv("SKIP_CONTAINER_TESTS") == "1" {
t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)")
}
ctx := context.Background()
dataDir := t.TempDir()
// 1. Start stable Memos container to create database with old schema
cfg := MemosContainerConfig{
Driver: "sqlite",
DataDir: dataDir,
Version: StableMemosVersion,
}
t.Logf("Starting Memos %s container to create old-schema database...", cfg.Version)
container, err := StartMemosContainer(ctx, cfg)
require.NoError(t, err, "failed to start stable memos container")
// Wait for the container to fully initialize the database
time.Sleep(10 * time.Second)
// Stop the container gracefully
t.Log("Stopping stable Memos container...")
err = container.Terminate(ctx)
require.NoError(t, err, "failed to stop memos container")
// Wait for file handles to be released
time.Sleep(2 * time.Second)
// 2. Connect to the database directly and run migration with current code
dsn := fmt.Sprintf("%s/memos_prod.db", dataDir)
t.Logf("Connecting to database at %s...", dsn)
ts := NewTestingStoreWithDSN(ctx, t, "sqlite", dsn)
// Get the schema version before migration
oldSetting, err := ts.GetInstanceBasicSetting(ctx)
require.NoError(t, err)
t.Logf("Old schema version: %s", oldSetting.SchemaVersion)
// 3. Run migration with current code
t.Log("Running migration with current code...")
err = ts.Migrate(ctx)
require.NoError(t, err, "migration from stable version should succeed")
// 4. Verify migration succeeded
newVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
t.Logf("New schema version: %s", newVersion)
newSetting, err := ts.GetInstanceBasicSetting(ctx)
require.NoError(t, err)
require.Equal(t, newVersion, newSetting.SchemaVersion, "schema version should be updated")
// Verify we can write data to the migrated database
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err, "should create user after migration")
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "post-upgrade-memo",
CreatorID: user.ID,
Content: "Content after upgrade from stable",
Visibility: store.Public,
})
require.NoError(t, err, "should create memo after migration")
require.Equal(t, "Content after upgrade from stable", memo.Content)
t.Logf("Migration successful: %s -> %s", oldSetting.SchemaVersion, newVersion)
}

189
store/test/reaction_test.go Normal file
View File

@@ -0,0 +1,189 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestReactionStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
contentID := "test_content_id"
reaction, err := ts.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: contentID,
ReactionType: "💗",
})
require.NoError(t, err)
require.NotNil(t, reaction)
require.NotEmpty(t, reaction.ID)
reactions, err := ts.ListReactions(ctx, &store.FindReaction{
ContentID: &contentID,
})
require.NoError(t, err)
require.Len(t, reactions, 1)
require.Equal(t, reaction, reactions[0])
// Test GetReaction.
gotReaction, err := ts.GetReaction(ctx, &store.FindReaction{
ID: &reaction.ID,
})
require.NoError(t, err)
require.NotNil(t, gotReaction)
require.Equal(t, reaction.ID, gotReaction.ID)
require.Equal(t, reaction.CreatorID, gotReaction.CreatorID)
require.Equal(t, reaction.ContentID, gotReaction.ContentID)
require.Equal(t, reaction.ReactionType, gotReaction.ReactionType)
// Test GetReaction with non-existent ID.
nonExistentID := int32(99999)
notFoundReaction, err := ts.GetReaction(ctx, &store.FindReaction{
ID: &nonExistentID,
})
require.NoError(t, err)
require.Nil(t, notFoundReaction)
err = ts.DeleteReaction(ctx, &store.DeleteReaction{
ID: reaction.ID,
})
require.NoError(t, err)
reactions, err = ts.ListReactions(ctx, &store.FindReaction{
ContentID: &contentID,
})
require.NoError(t, err)
require.Len(t, reactions, 0)
ts.Close()
}
func TestReactionListByCreatorID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
contentID := "shared_content"
// User 1 creates reaction
_, err = ts.UpsertReaction(ctx, &store.Reaction{
CreatorID: user1.ID,
ContentID: contentID,
ReactionType: "👍",
})
require.NoError(t, err)
// User 2 creates reaction
_, err = ts.UpsertReaction(ctx, &store.Reaction{
CreatorID: user2.ID,
ContentID: contentID,
ReactionType: "❤️",
})
require.NoError(t, err)
// List all reactions for content
reactions, err := ts.ListReactions(ctx, &store.FindReaction{
ContentID: &contentID,
})
require.NoError(t, err)
require.Len(t, reactions, 2)
// List by creator ID
user1Reactions, err := ts.ListReactions(ctx, &store.FindReaction{
CreatorID: &user1.ID,
})
require.NoError(t, err)
require.Len(t, user1Reactions, 1)
require.Equal(t, "👍", user1Reactions[0].ReactionType)
ts.Close()
}
func TestReactionMultipleContentIDs(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
contentID1 := "content_1"
contentID2 := "content_2"
// Create reactions for different contents
_, err = ts.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: contentID1,
ReactionType: "👍",
})
require.NoError(t, err)
_, err = ts.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: contentID2,
ReactionType: "❤️",
})
require.NoError(t, err)
// List by content ID list
reactions, err := ts.ListReactions(ctx, &store.FindReaction{
ContentIDList: []string{contentID1, contentID2},
})
require.NoError(t, err)
require.Len(t, reactions, 2)
ts.Close()
}
func TestReactionUpsertDifferentTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
contentID := "test_content"
// Create first reaction
reaction1, err := ts.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: contentID,
ReactionType: "👍",
})
require.NoError(t, err)
// Create second reaction with different type (should create new, not update)
reaction2, err := ts.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: contentID,
ReactionType: "❤️",
})
require.NoError(t, err)
// Both reactions should exist
require.NotEqual(t, reaction1.ID, reaction2.ID)
reactions, err := ts.ListReactions(ctx, &store.FindReaction{
ContentID: &contentID,
})
require.NoError(t, err)
require.Len(t, reactions, 2)
ts.Close()
}

111
store/test/store.go Normal file
View File

@@ -0,0 +1,111 @@
package test
import (
"context"
"fmt"
"net"
"os"
"testing"
// sqlite driver.
_ "modernc.org/sqlite"
"github.com/joho/godotenv"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/version"
"github.com/usememos/memos/store"
"github.com/usememos/memos/store/db"
)
// NewTestingStore creates a new testing store with a fresh database.
// Each test gets its own isolated database:
// - SQLite: new temp file per test
// - MySQL/PostgreSQL: new database per test in shared container
func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
driver := getDriverFromEnv()
profile := getTestingProfileForDriver(t, driver)
dbDriver, err := db.NewDBDriver(profile)
if err != nil {
t.Fatalf("failed to create db driver: %v", err)
}
store := store.New(dbDriver, profile)
if err := store.Migrate(ctx); err != nil {
t.Fatalf("failed to migrate db: %v", err)
}
return store
}
// NewTestingStoreWithDSN creates a testing store connected to a specific DSN.
// This is useful for testing migrations on existing data.
func NewTestingStoreWithDSN(_ context.Context, t *testing.T, driver, dsn string) *store.Store {
profile := &profile.Profile{
Port: getUnusedPort(),
Data: t.TempDir(), // Dummy dir, DSN matters
DSN: dsn,
Driver: driver,
Version: version.GetCurrentVersion(),
}
dbDriver, err := db.NewDBDriver(profile)
if err != nil {
t.Fatalf("failed to create db driver: %v", err)
}
store := store.New(dbDriver, profile)
// Do not run Migrate() automatically, as we might be testing pre-migration state
// or want to run it manually.
return store
}
func getUnusedPort() int {
// Get a random unused port
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
panic(err)
}
defer listener.Close()
// Get the port number
port := listener.Addr().(*net.TCPAddr).Port
return port
}
// getTestingProfileForDriver creates a testing profile for a specific driver.
func getTestingProfileForDriver(t *testing.T, driver string) *profile.Profile {
// Attempt to load .env file if present (optional, for local development)
_ = godotenv.Load(".env")
// Get a temporary directory for the test data.
dir := t.TempDir()
mode := "prod"
port := getUnusedPort()
var dsn string
switch driver {
case "sqlite":
dsn = fmt.Sprintf("%s/memos_%s.db", dir, mode)
case "mysql":
dsn = GetMySQLDSN(t)
case "postgres":
dsn = GetPostgresDSN(t)
default:
t.Fatalf("unsupported driver: %s", driver)
}
return &profile.Profile{
Port: port,
Data: dir,
DSN: dsn,
Driver: driver,
Version: version.GetCurrentVersion(),
}
}
func getDriverFromEnv() string {
driver := os.Getenv("DRIVER")
if driver == "" {
driver = "sqlite"
}
return driver
}

View File

@@ -0,0 +1,993 @@
package test
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestUserSettingStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_GENERAL,
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "en"}},
})
require.NoError(t, err)
list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{})
require.NoError(t, err)
require.Equal(t, 1, len(list))
ts.Close()
}
func TestUserSettingGetByUserID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create setting
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_GENERAL,
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "zh"}},
})
require.NoError(t, err)
// Get by user ID
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_GENERAL,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Equal(t, "zh", setting.GetGeneral().Locale)
// Get non-existent key
nonExistentSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.Nil(t, nonExistentSetting)
ts.Close()
}
func TestUserSettingUpsertUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create initial setting
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_GENERAL,
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "en"}},
})
require.NoError(t, err)
// Update setting
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_GENERAL,
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "fr"}},
})
require.NoError(t, err)
// Verify update
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_GENERAL,
})
require.NoError(t, err)
require.Equal(t, "fr", setting.GetGeneral().Locale)
// Verify only one setting exists
list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{UserID: &user.ID})
require.NoError(t, err)
require.Equal(t, 1, len(list))
ts.Close()
}
func TestUserSettingRefreshTokens(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Initially no tokens
tokens, err := ts.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
require.Empty(t, tokens)
// Add a refresh token
token1 := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: "token-1",
Description: "Chrome browser session",
}
err = ts.AddUserRefreshToken(ctx, user.ID, token1)
require.NoError(t, err)
// Verify token was added
tokens, err = ts.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, tokens, 1)
require.Equal(t, "token-1", tokens[0].TokenId)
// Add another token
token2 := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: "token-2",
Description: "Firefox browser session",
}
err = ts.AddUserRefreshToken(ctx, user.ID, token2)
require.NoError(t, err)
tokens, err = ts.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, tokens, 2)
// Get specific token by ID
foundToken, err := ts.GetUserRefreshTokenByID(ctx, user.ID, "token-1")
require.NoError(t, err)
require.NotNil(t, foundToken)
require.Equal(t, "Chrome browser session", foundToken.Description)
// Get non-existent token
notFound, err := ts.GetUserRefreshTokenByID(ctx, user.ID, "non-existent")
require.NoError(t, err)
require.Nil(t, notFound)
// Remove token
err = ts.RemoveUserRefreshToken(ctx, user.ID, "token-1")
require.NoError(t, err)
tokens, err = ts.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, tokens, 1)
require.Equal(t, "token-2", tokens[0].TokenId)
ts.Close()
}
func TestUserSettingPersonalAccessTokens(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Initially no PATs
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
require.Empty(t, pats)
// Add a PAT
pat1 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-1",
TokenHash: "pat-hash-1",
Description: "API Token for external access",
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat1)
require.NoError(t, err)
// Verify PAT was added
pats, err = ts.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, pats, 1)
require.Equal(t, "API Token for external access", pats[0].Description)
// Add another PAT
pat2 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-2",
TokenHash: "pat-hash-2",
Description: "CI Token",
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat2)
require.NoError(t, err)
pats, err = ts.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, pats, 2)
// Remove PAT
err = ts.RemoveUserPersonalAccessToken(ctx, user.ID, "pat-1")
require.NoError(t, err)
pats, err = ts.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, pats, 1)
require.Equal(t, "pat-2", pats[0].TokenId)
ts.Close()
}
func TestUserSettingWebhooks(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Initially no webhooks
webhooks, err := ts.GetUserWebhooks(ctx, user.ID)
require.NoError(t, err)
require.Empty(t, webhooks)
// Add a webhook
webhook1 := &storepb.WebhooksUserSetting_Webhook{
Id: "webhook-1",
Title: "Deploy Hook",
Url: "https://example.com/webhook",
}
err = ts.AddUserWebhook(ctx, user.ID, webhook1)
require.NoError(t, err)
// Verify webhook was added
webhooks, err = ts.GetUserWebhooks(ctx, user.ID)
require.NoError(t, err)
require.Len(t, webhooks, 1)
require.Equal(t, "Deploy Hook", webhooks[0].Title)
// Update webhook
webhook1Updated := &storepb.WebhooksUserSetting_Webhook{
Id: "webhook-1",
Title: "Updated Deploy Hook",
Url: "https://example.com/webhook/v2",
}
err = ts.UpdateUserWebhook(ctx, user.ID, webhook1Updated)
require.NoError(t, err)
webhooks, err = ts.GetUserWebhooks(ctx, user.ID)
require.NoError(t, err)
require.Len(t, webhooks, 1)
require.Equal(t, "Updated Deploy Hook", webhooks[0].Title)
require.Equal(t, "https://example.com/webhook/v2", webhooks[0].Url)
// Add another webhook
webhook2 := &storepb.WebhooksUserSetting_Webhook{
Id: "webhook-2",
Title: "Notification Hook",
Url: "https://slack.example.com/webhook",
}
err = ts.AddUserWebhook(ctx, user.ID, webhook2)
require.NoError(t, err)
webhooks, err = ts.GetUserWebhooks(ctx, user.ID)
require.NoError(t, err)
require.Len(t, webhooks, 2)
// Remove webhook
err = ts.RemoveUserWebhook(ctx, user.ID, "webhook-1")
require.NoError(t, err)
webhooks, err = ts.GetUserWebhooks(ctx, user.ID)
require.NoError(t, err)
require.Len(t, webhooks, 1)
require.Equal(t, "webhook-2", webhooks[0].Id)
ts.Close()
}
func TestUserSettingShortcuts(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create shortcuts setting
shortcuts := &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "shortcut-1", Title: "Work Notes", Filter: "tag:work"},
{Id: "shortcut-2", Title: "Personal", Filter: "tag:personal"},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
// Retrieve and verify
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Len(t, setting.GetShortcuts().Shortcuts, 2)
require.Equal(t, "Work Notes", setting.GetShortcuts().Shortcuts[0].Title)
ts.Close()
}
func TestUserSettingGetUserByPATHash(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a PAT with a known hash
patHash := "test-pat-hash-12345"
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-test-1",
TokenHash: patHash,
Description: "Test PAT for lookup",
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Lookup user by PAT hash
result, err := ts.GetUserByPATHash(ctx, patHash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
require.NotNil(t, result.User)
require.Equal(t, user.Username, result.User.Username)
require.NotNil(t, result.PAT)
require.Equal(t, "pat-test-1", result.PAT.TokenId)
require.Equal(t, "Test PAT for lookup", result.PAT.Description)
ts.Close()
}
func TestUserSettingGetUserByPATHashNotFound(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
_, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Lookup non-existent PAT hash
result, err := ts.GetUserByPATHash(ctx, "non-existent-hash")
require.Error(t, err)
require.Nil(t, result)
ts.Close()
}
func TestUserSettingGetUserByPATHashNoTokensKey(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// User exists but has no PERSONAL_ACCESS_TOKENS key at all
// This simulates fresh users or users upgraded from v0.25.3
result, err := ts.GetUserByPATHash(ctx, "any-hash")
require.Error(t, err)
require.Nil(t, result)
// Error could be "PAT not found" (Postgres) or "sql: no rows in result set" (SQLite/MySQL)
require.True(t,
strings.Contains(err.Error(), "PAT not found") || strings.Contains(err.Error(), "no rows"),
"expected PAT not found or no rows error, got: %v", err)
// Now add a PAT for the user
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-new",
TokenHash: "hash-new",
Description: "New PAT",
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Now the lookup should succeed
result, err = ts.GetUserByPATHash(ctx, "hash-new")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
ts.Close()
}
func TestUserSettingGetUserByPATHashEmptyTokensArray(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Add a PAT setting with empty tokens array
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS,
Value: &storepb.UserSetting_PersonalAccessTokens{
PersonalAccessTokens: &storepb.PersonalAccessTokensUserSetting{
Tokens: []*storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{},
},
},
})
require.NoError(t, err)
// Lookup should fail gracefully, not crash
result, err := ts.GetUserByPATHash(ctx, "any-hash")
require.Error(t, err)
require.Nil(t, result)
// Error could be "PAT not found" (Postgres) or "sql: no rows in result set" (SQLite/MySQL)
require.True(t,
strings.Contains(err.Error(), "PAT not found") || strings.Contains(err.Error(), "no rows"),
"expected PAT not found or no rows error, got: %v", err)
ts.Close()
}
func TestUserSettingGetUserByPATHashWithOtherUsers(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create multiple users - some with PATs, some without
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
_, err = createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
user3, err := createTestingUserWithRole(ctx, ts, "user3", store.RoleUser)
require.NoError(t, err)
// User1: Has PAT
pat1Hash := "user1-pat-hash-unique"
err = ts.AddUserPersonalAccessToken(ctx, user1.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-user1",
TokenHash: pat1Hash,
Description: "User 1 PAT",
})
require.NoError(t, err)
// User2: Has no PERSONAL_ACCESS_TOKENS key (fresh user)
// User3: Has empty tokens array
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user3.ID,
Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS,
Value: &storepb.UserSetting_PersonalAccessTokens{
PersonalAccessTokens: &storepb.PersonalAccessTokensUserSetting{
Tokens: []*storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{},
},
},
})
require.NoError(t, err)
// Should find user1's PAT despite user2 having no key and user3 having empty array
result, err := ts.GetUserByPATHash(ctx, pat1Hash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user1.ID, result.UserID)
require.Equal(t, "pat-user1", result.PAT.TokenId)
// Should not find non-existent hash even with mixed user states
result, err = ts.GetUserByPATHash(ctx, "non-existent")
require.Error(t, err)
require.Nil(t, result)
ts.Close()
}
func TestUserSettingGetUserByPATHashMultipleUsers(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create PATs for both users
pat1Hash := "user1-pat-hash"
err = ts.AddUserPersonalAccessToken(ctx, user1.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-user1",
TokenHash: pat1Hash,
Description: "User 1 PAT",
})
require.NoError(t, err)
pat2Hash := "user2-pat-hash"
err = ts.AddUserPersonalAccessToken(ctx, user2.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-user2",
TokenHash: pat2Hash,
Description: "User 2 PAT",
})
require.NoError(t, err)
// Lookup user1's PAT
result1, err := ts.GetUserByPATHash(ctx, pat1Hash)
require.NoError(t, err)
require.Equal(t, user1.ID, result1.UserID)
require.Equal(t, user1.Username, result1.User.Username)
// Lookup user2's PAT
result2, err := ts.GetUserByPATHash(ctx, pat2Hash)
require.NoError(t, err)
require.Equal(t, user2.ID, result2.UserID)
require.Equal(t, user2.Username, result2.User.Username)
ts.Close()
}
func TestUserSettingGetUserByPATHashMultiplePATsSameUser(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create multiple PATs for the same user
pat1Hash := "first-pat-hash"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-1",
TokenHash: pat1Hash,
Description: "First PAT",
})
require.NoError(t, err)
pat2Hash := "second-pat-hash"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-2",
TokenHash: pat2Hash,
Description: "Second PAT",
})
require.NoError(t, err)
// Both PATs should resolve to the same user
result1, err := ts.GetUserByPATHash(ctx, pat1Hash)
require.NoError(t, err)
require.Equal(t, user.ID, result1.UserID)
require.Equal(t, "pat-1", result1.PAT.TokenId)
result2, err := ts.GetUserByPATHash(ctx, pat2Hash)
require.NoError(t, err)
require.Equal(t, user.ID, result2.UserID)
require.Equal(t, "pat-2", result2.PAT.TokenId)
ts.Close()
}
func TestUserSettingUpdatePATLastUsed(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a PAT
patHash := "pat-hash-for-update"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-update-test",
TokenHash: patHash,
Description: "PAT for update test",
})
require.NoError(t, err)
// Update last used timestamp
now := timestamppb.Now()
err = ts.UpdatePATLastUsed(ctx, user.ID, "pat-update-test", now)
require.NoError(t, err)
// Verify the update
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, pats, 1)
require.NotNil(t, pats[0].LastUsedAt)
ts.Close()
}
func TestUserSettingGetUserByPATHashWithExpiredToken(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a PAT with expiration info
patHash := "pat-hash-with-expiry"
expiresAt := timestamppb.Now()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-expiry-test",
TokenHash: patHash,
Description: "PAT with expiry",
ExpiresAt: expiresAt,
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Should still be able to look up by hash (expiry check is done at auth level)
result, err := ts.GetUserByPATHash(ctx, patHash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
require.NotNil(t, result.PAT.ExpiresAt)
ts.Close()
}
func TestUserSettingGetUserByPATHashAfterRemoval(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a PAT
patHash := "pat-hash-to-remove"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-remove-test",
TokenHash: patHash,
Description: "PAT to be removed",
})
require.NoError(t, err)
// Verify it exists
result, err := ts.GetUserByPATHash(ctx, patHash)
require.NoError(t, err)
require.NotNil(t, result)
// Remove the PAT
err = ts.RemoveUserPersonalAccessToken(ctx, user.ID, "pat-remove-test")
require.NoError(t, err)
// Should no longer be found
result, err = ts.GetUserByPATHash(ctx, patHash)
require.Error(t, err)
require.Nil(t, result)
ts.Close()
}
func TestUserSettingGetUserByPATHashSpecialCharacters(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create PATs with special characters in hash (simulating real hash values)
testCases := []struct {
tokenID string
tokenHash string
}{
{"pat-special-1", "abc123+/=XYZ"},
{"pat-special-2", "sha256:abcdef1234567890"},
{"pat-special-3", "$2a$10$N9qo8uLOickgx2ZMRZoMy"},
}
for _, tc := range testCases {
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tc.tokenID,
TokenHash: tc.tokenHash,
Description: "PAT with special chars",
})
require.NoError(t, err)
// Verify lookup works with special characters
result, err := ts.GetUserByPATHash(ctx, tc.tokenHash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, tc.tokenID, result.PAT.TokenId)
}
ts.Close()
}
func TestUserSettingGetUserByPATHashLargeTokenCount(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create many PATs for the same user
tokenCount := 10
hashes := make([]string, tokenCount)
for i := 0; i < tokenCount; i++ {
hashes[i] = "pat-hash-" + string(rune('A'+i)) + "-large-test"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-large-" + string(rune('A'+i)),
TokenHash: hashes[i],
Description: "PAT for large count test",
})
require.NoError(t, err)
}
// Verify each hash can be looked up
for i, hash := range hashes {
result, err := ts.GetUserByPATHash(ctx, hash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
require.Equal(t, "pat-large-"+string(rune('A'+i)), result.PAT.TokenId)
}
ts.Close()
}
func TestUserSettingMultipleSettingTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create GENERAL setting
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_GENERAL,
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "ja"}},
})
require.NoError(t, err)
// Create SHORTCUTS setting
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s1", Title: "Shortcut 1"},
},
}},
})
require.NoError(t, err)
// Add a PAT
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-multi",
TokenHash: "hash-multi",
})
require.NoError(t, err)
// List all settings for user
settings, err := ts.ListUserSettings(ctx, &store.FindUserSetting{UserID: &user.ID})
require.NoError(t, err)
require.Len(t, settings, 3)
// Verify each setting type
generalSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_GENERAL})
require.NoError(t, err)
require.Equal(t, "ja", generalSetting.GetGeneral().Locale)
shortcutsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_SHORTCUTS})
require.NoError(t, err)
require.Len(t, shortcutsSetting.GetShortcuts().Shortcuts, 1)
patsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS})
require.NoError(t, err)
require.Len(t, patsSetting.GetPersonalAccessTokens().Tokens, 1)
ts.Close()
}
func TestUserSettingShortcutsEdgeCases(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Case 1: Special characters in Filter and Title
// Includes quotes, backslashes, newlines, and other JSON-sensitive characters
specialCharsFilter := `tag in ["work", "project"] && content.contains("urgent")`
specialCharsTitle := `Work "Urgent" \ Notes`
shortcuts := &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s1", Title: specialCharsTitle, Filter: specialCharsFilter},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Len(t, setting.GetShortcuts().Shortcuts, 1)
require.Equal(t, specialCharsTitle, setting.GetShortcuts().Shortcuts[0].Title)
require.Equal(t, specialCharsFilter, setting.GetShortcuts().Shortcuts[0].Filter)
// Case 2: Unicode characters
unicodeFilter := `tag in ["你好", "世界"]`
unicodeTitle := `My 🚀 Shortcuts`
shortcuts = &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s2", Title: unicodeTitle, Filter: unicodeFilter},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Len(t, setting.GetShortcuts().Shortcuts, 1)
require.Equal(t, unicodeTitle, setting.GetShortcuts().Shortcuts[0].Title)
require.Equal(t, unicodeFilter, setting.GetShortcuts().Shortcuts[0].Filter)
// Case 3: Empty shortcuts list
// Should allow saving an empty list (clearing shortcuts)
shortcuts = &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.NotNil(t, setting.GetShortcuts())
require.Len(t, setting.GetShortcuts().Shortcuts, 0)
// Case 4: Large filter string
// Test reasonable large string handling (e.g. 4KB)
largeFilter := strings.Repeat("tag:long_tag_name ", 200)
shortcuts = &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s3", Title: "Large Filter", Filter: largeFilter},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Equal(t, largeFilter, setting.GetShortcuts().Shortcuts[0].Filter)
ts.Close()
}
func TestUserSettingShortcutsPartialUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Initial set
shortcuts := &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s1", Title: "Note 1", Filter: "tag:1"},
{Id: "s2", Title: "Note 2", Filter: "tag:2"},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
// Update by replacing the whole list (Store Upsert replaces the value for the key)
// We want to verify that we can "update" a single item by sending the modified list
updatedShortcuts := &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s1", Title: "Note 1 Updated", Filter: "tag:1_updated"},
{Id: "s2", Title: "Note 2", Filter: "tag:2"},
{Id: "s3", Title: "Note 3", Filter: "tag:3"}, // Add new one
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: updatedShortcuts},
})
require.NoError(t, err)
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Len(t, setting.GetShortcuts().Shortcuts, 3)
// Verify updates
for _, s := range setting.GetShortcuts().Shortcuts {
if s.Id == "s1" {
require.Equal(t, "Note 1 Updated", s.Title)
require.Equal(t, "tag:1_updated", s.Filter)
} else if s.Id == "s2" {
require.Equal(t, "Note 2", s.Title)
} else if s.Id == "s3" {
require.Equal(t, "Note 3", s.Title)
}
}
ts.Close()
}
func TestUserSettingJSONFieldsEdgeCases(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Case 1: Webhook with special characters and Unicode in Title and URL
specialWebhook := &storepb.WebhooksUserSetting_Webhook{
Id: "wh-special",
Title: `My "Special" & <Webhook> 🚀`,
Url: "https://example.com/hook?query=你好&param=\"value\"",
}
err = ts.AddUserWebhook(ctx, user.ID, specialWebhook)
require.NoError(t, err)
webhooks, err := ts.GetUserWebhooks(ctx, user.ID)
require.NoError(t, err)
require.Len(t, webhooks, 1)
require.Equal(t, specialWebhook.Title, webhooks[0].Title)
require.Equal(t, specialWebhook.Url, webhooks[0].Url)
// Case 2: PAT with special description
specialPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-special",
TokenHash: "hash-special",
Description: "Token for 'CLI' \n & \"API\" \t with unicode 🔑",
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, specialPAT)
require.NoError(t, err)
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, pats, 1)
require.Equal(t, specialPAT.Description, pats[0].Description)
// Case 3: Refresh Token with special description
specialRefreshToken := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: "rt-special",
Description: "Browser: Firefox (Nightly) / OS: Linux 🐧",
}
err = ts.AddUserRefreshToken(ctx, user.ID, specialRefreshToken)
require.NoError(t, err)
tokens, err := ts.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, tokens, 1)
require.Equal(t, specialRefreshToken.Description, tokens[0].Description)
ts.Close()
}

256
store/test/user_test.go Normal file
View File

@@ -0,0 +1,256 @@
package test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
"github.com/usememos/memos/store"
)
func TestUserStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
users, err := ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err)
require.Equal(t, 1, len(users))
require.Equal(t, store.RoleAdmin, users[0].Role)
require.Equal(t, user, users[0])
userPatchNickname := "test_nickname_2"
userPatch := &store.UpdateUser{
ID: user.ID,
Nickname: &userPatchNickname,
}
user, err = ts.UpdateUser(ctx, userPatch)
require.NoError(t, err)
require.Equal(t, userPatchNickname, user.Nickname)
err = ts.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID,
})
require.NoError(t, err)
users, err = ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err)
require.Equal(t, 0, len(users))
ts.Close()
}
func TestUserGetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Get user by ID
found, err := ts.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, user.ID, found.ID)
require.Equal(t, user.Username, found.Username)
// Get non-existent user
nonExistentID := int32(99999)
notFound, err := ts.GetUser(ctx, &store.FindUser{ID: &nonExistentID})
require.NoError(t, err)
require.Nil(t, notFound)
// Get system bot
systemBotID := store.SystemBotID
systemBot, err := ts.GetUser(ctx, &store.FindUser{ID: &systemBotID})
require.NoError(t, err)
require.NotNil(t, systemBot)
require.Equal(t, store.SystemBotID, systemBot.ID)
require.Equal(t, "system_bot", systemBot.Username)
ts.Close()
}
func TestUserGetByUsername(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Get user by username
found, err := ts.GetUser(ctx, &store.FindUser{Username: &user.Username})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, user.Username, found.Username)
// Get non-existent username
nonExistent := "nonexistent"
notFound, err := ts.GetUser(ctx, &store.FindUser{Username: &nonExistent})
require.NoError(t, err)
require.Nil(t, notFound)
ts.Close()
}
func TestUserListByRole(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create users with different roles
_, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
_, err = createTestingUserWithRole(ctx, ts, "admin_user", store.RoleAdmin)
require.NoError(t, err)
regularUser, err := createTestingUserWithRole(ctx, ts, "regular_user", store.RoleUser)
require.NoError(t, err)
// List all users
allUsers, err := ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err)
require.Equal(t, 3, len(allUsers))
// List only ADMIN users
adminRole := store.RoleAdmin
adminOnlyUsers, err := ts.ListUsers(ctx, &store.FindUser{Role: &adminRole})
require.NoError(t, err)
require.Equal(t, 2, len(adminOnlyUsers))
// List only USER role users
userRole := store.RoleUser
regularUsers, err := ts.ListUsers(ctx, &store.FindUser{Role: &userRole})
require.NoError(t, err)
require.Equal(t, 1, len(regularUsers))
require.Equal(t, regularUser.ID, regularUsers[0].ID)
ts.Close()
}
func TestUserUpdateRowStatus(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
require.Equal(t, store.Normal, user.RowStatus)
// Archive user
archivedStatus := store.Archived
updated, err := ts.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
RowStatus: &archivedStatus,
})
require.NoError(t, err)
require.Equal(t, store.Archived, updated.RowStatus)
// Verify by fetching
fetched, err := ts.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.Equal(t, store.Archived, fetched.RowStatus)
// Restore to normal
normalStatus := store.Normal
restored, err := ts.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
RowStatus: &normalStatus,
})
require.NoError(t, err)
require.Equal(t, store.Normal, restored.RowStatus)
ts.Close()
}
func TestUserUpdateAllFields(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Update all fields
newUsername := "updated_username"
newEmail := "updated@test.com"
newNickname := "Updated Nickname"
newAvatarURL := "https://example.com/avatar.png"
newDescription := "Updated description"
newRole := store.RoleAdmin
newPasswordHash := "new_password_hash"
updated, err := ts.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
Username: &newUsername,
Email: &newEmail,
Nickname: &newNickname,
AvatarURL: &newAvatarURL,
Description: &newDescription,
Role: &newRole,
PasswordHash: &newPasswordHash,
})
require.NoError(t, err)
require.Equal(t, newUsername, updated.Username)
require.Equal(t, newEmail, updated.Email)
require.Equal(t, newNickname, updated.Nickname)
require.Equal(t, newAvatarURL, updated.AvatarURL)
require.Equal(t, newDescription, updated.Description)
require.Equal(t, newRole, updated.Role)
require.Equal(t, newPasswordHash, updated.PasswordHash)
// Verify by fetching again
fetched, err := ts.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.Equal(t, newUsername, fetched.Username)
ts.Close()
}
func TestUserListWithLimit(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create 5 users
for i := 0; i < 5; i++ {
role := store.RoleUser
if i == 0 {
role = store.RoleAdmin
}
_, err := createTestingUserWithRole(ctx, ts, fmt.Sprintf("user%d", i), role)
require.NoError(t, err)
}
// List with limit
limit := 3
users, err := ts.ListUsers(ctx, &store.FindUser{Limit: &limit})
require.NoError(t, err)
require.Equal(t, 3, len(users))
ts.Close()
}
func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) {
return createTestingUserWithRole(ctx, ts, "test", store.RoleAdmin)
}
func createTestingUserWithRole(ctx context.Context, ts *store.Store, username string, role store.Role) (*store.User, error) {
userCreate := &store.User{
Username: username,
Role: role,
Email: username + "@test.com",
Nickname: username + "_nickname",
Description: username + "_description",
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte("test_password"), bcrypt.DefaultCost)
if err != nil {
return nil, err
}
userCreate.PasswordHash = string(passwordHash)
user, err := ts.CreateUser(ctx, userCreate)
return user, err
}