first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
13
store/test/README.md
Normal file
13
store/test/README.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# Store tests
|
||||
|
||||
## How to test store with MySQL?
|
||||
|
||||
1. Create a database in your MySQL server.
|
||||
2. Run the following command with two environment variables set:
|
||||
|
||||
```go
|
||||
DRIVER=mysql DSN=root@/memos_test go test -v ./test/store/...
|
||||
```
|
||||
|
||||
- `DRIVER` should be set to `mysql`.
|
||||
- `DSN` should be set to the DSN of your MySQL server.
|
||||
380
store/test/activity_test.go
Normal file
380
store/test/activity_test.go
Normal file
@@ -0,0 +1,380 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestActivityStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
create := &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
}
|
||||
activity, err := ts.CreateActivity(ctx, create)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, activity)
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{
|
||||
ID: &activity.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(activities))
|
||||
require.Equal(t, activity, activities[0])
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get activity by ID
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, activity.ID, found.ID)
|
||||
|
||||
// Get non-existent activity
|
||||
nonExistentID := int32(99999)
|
||||
notFound, err := ts.GetActivity(ctx, &store.FindActivity{ID: &nonExistentID})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityListMultiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple activities
|
||||
_, err = ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all activities
|
||||
allActivities, err := ts.ListActivities(ctx, &store.FindActivity{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(allActivities))
|
||||
|
||||
// List by type
|
||||
commentType := store.ActivityTypeMemoComment
|
||||
commentActivities, err := ts.ListActivities(ctx, &store.FindActivity{Type: &commentType})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(commentActivities))
|
||||
require.Equal(t, store.ActivityTypeMemoComment, commentActivities[0].Type)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityListByType(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activities with MEMO_COMMENT type
|
||||
_, err = ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by type
|
||||
activityType := store.ActivityTypeMemoComment
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{Type: &activityType})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activities, 2)
|
||||
for _, activity := range activities {
|
||||
require.Equal(t, store.ActivityTypeMemoComment, activity.Type)
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityPayloadMemoComment(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity with MemoComment payload
|
||||
memoID := int32(123)
|
||||
relatedMemoID := int32(456)
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{
|
||||
MemoComment: &storepb.ActivityMemoCommentPayload{
|
||||
MemoId: memoID,
|
||||
RelatedMemoId: relatedMemoID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, activity.Payload)
|
||||
require.NotNil(t, activity.Payload.MemoComment)
|
||||
require.Equal(t, memoID, activity.Payload.MemoComment.MemoId)
|
||||
require.Equal(t, relatedMemoID, activity.Payload.MemoComment.RelatedMemoId)
|
||||
|
||||
// Verify payload is preserved when listing
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found.Payload.MemoComment)
|
||||
require.Equal(t, memoID, found.Payload.MemoComment.MemoId)
|
||||
require.Equal(t, relatedMemoID, found.Payload.MemoComment.RelatedMemoId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityEmptyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity with empty payload
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, activity.Payload)
|
||||
|
||||
// Verify empty payload is handled correctly
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found.Payload)
|
||||
require.Nil(t, found.Payload.MemoComment)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityLevel(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity with INFO level
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.ActivityLevelInfo, activity.Level)
|
||||
|
||||
// Verify level is preserved when listing
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.ActivityLevelInfo, found.Level)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityCreatorID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity for user1
|
||||
activity1, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user1.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user1.ID, activity1.CreatorID)
|
||||
|
||||
// Create activity for user2
|
||||
activity2, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user2.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user2.ID, activity2.CreatorID)
|
||||
|
||||
// List all and verify creator IDs
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activities, 2)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityCreatedTs(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, activity.CreatedTs)
|
||||
|
||||
// Verify timestamp is preserved when listing
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, activity.CreatedTs, found.CreatedTs)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityListEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// List activities when none exist
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activities, 0)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityListWithIDAndType(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List with both ID and Type filters
|
||||
activityType := store.ActivityTypeMemoComment
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{
|
||||
ID: &activity.ID,
|
||||
Type: &activityType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activities, 1)
|
||||
require.Equal(t, activity.ID, activities[0].ID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityPayloadComplexMemoComment(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a memo first to use its ID
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-for-activity",
|
||||
CreatorID: user.ID,
|
||||
Content: "Test memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create comment memo
|
||||
commentMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "comment-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "This is a comment",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity with real memo IDs
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{
|
||||
MemoComment: &storepb.ActivityMemoCommentPayload{
|
||||
MemoId: memo.ID,
|
||||
RelatedMemoId: commentMemo.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memo.ID, activity.Payload.MemoComment.MemoId)
|
||||
require.Equal(t, commentMemo.ID, activity.Payload.MemoComment.RelatedMemoId)
|
||||
|
||||
// Verify payload is preserved
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memo.ID, found.Payload.MemoComment.MemoId)
|
||||
require.Equal(t, commentMemo.ID, found.Payload.MemoComment.RelatedMemoId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
375
store/test/attachment_filter_test.go
Normal file
375
store/test/attachment_filter_test.go
Normal file
@@ -0,0 +1,375 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Filename Field Tests
|
||||
// Schema: filename (string, supports contains)
|
||||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterFilenameContains(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.pdf").MimeType("application/pdf"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
|
||||
|
||||
// Test: filename.contains("report") - single match
|
||||
attachments := tc.ListWithFilter(`filename.contains("report")`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Contains(t, attachments[0].Filename, "report")
|
||||
|
||||
// Test: filename.contains(".pdf") - multiple matches
|
||||
attachments = tc.ListWithFilter(`filename.contains(".pdf")`)
|
||||
require.Len(t, attachments, 2)
|
||||
|
||||
// Test: filename.contains("nonexistent") - no matches
|
||||
attachments = tc.ListWithFilter(`filename.contains("nonexistent")`)
|
||||
require.Len(t, attachments, 0)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterFilenameSpecialCharacters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).
|
||||
Filename("file_with-special.chars@2024.pdf").MimeType("application/pdf"))
|
||||
|
||||
// Test: filename.contains with underscore
|
||||
attachments := tc.ListWithFilter(`filename.contains("_with")`)
|
||||
require.Len(t, attachments, 1)
|
||||
|
||||
// Test: filename.contains with @
|
||||
attachments = tc.ListWithFilter(`filename.contains("@2024")`)
|
||||
require.Len(t, attachments, 1)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterFilenameUnicode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).
|
||||
Filename("document_报告.pdf").MimeType("application/pdf"))
|
||||
|
||||
attachments := tc.ListWithFilter(`filename.contains("报告")`)
|
||||
require.Len(t, attachments, 1)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Mime Type Field Tests
|
||||
// Schema: mime_type (string, ==, !=)
|
||||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterMimeTypeEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("photo.jpeg").MimeType("image/jpeg"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
|
||||
|
||||
// Test: mime_type == "image/png"
|
||||
attachments := tc.ListWithFilter(`mime_type == "image/png"`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, "image/png", attachments[0].Type)
|
||||
|
||||
// Test: mime_type == "application/pdf"
|
||||
attachments = tc.ListWithFilter(`mime_type == "application/pdf"`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, "application/pdf", attachments[0].Type)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterMimeTypeNotEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
|
||||
|
||||
attachments := tc.ListWithFilter(`mime_type != "image/png"`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, "application/pdf", attachments[0].Type)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterMimeTypeInList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("photo.jpeg").MimeType("image/jpeg"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
|
||||
|
||||
// Test: mime_type in ["image/png", "image/jpeg"] - matches images
|
||||
attachments := tc.ListWithFilter(`mime_type in ["image/png", "image/jpeg"]`)
|
||||
require.Len(t, attachments, 2)
|
||||
|
||||
// Test: mime_type in ["video/mp4"] - no matches
|
||||
attachments = tc.ListWithFilter(`mime_type in ["video/mp4"]`)
|
||||
require.Len(t, attachments, 0)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Create Time Field Tests
|
||||
// Schema: create_time (timestamp, all comparison operators)
|
||||
// Functions: now(), arithmetic (+, -, *)
|
||||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterCreateTimeComparison(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
now := time.Now().Unix()
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
|
||||
|
||||
// Test: create_time < future (should match)
|
||||
attachments := tc.ListWithFilter(`create_time < ` + formatInt64(now+3600))
|
||||
require.Len(t, attachments, 1)
|
||||
|
||||
// Test: create_time > past (should match)
|
||||
attachments = tc.ListWithFilter(`create_time > ` + formatInt64(now-3600))
|
||||
require.Len(t, attachments, 1)
|
||||
|
||||
// Test: create_time > future (should not match)
|
||||
attachments = tc.ListWithFilter(`create_time > ` + formatInt64(now+3600))
|
||||
require.Len(t, attachments, 0)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterCreateTimeWithNow(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
|
||||
|
||||
// Test: create_time < now() + 5 (buffer for container clock drift)
|
||||
attachments := tc.ListWithFilter(`create_time < now() + 5`)
|
||||
require.Len(t, attachments, 1)
|
||||
|
||||
// Test: create_time > now() + 5 (should not match)
|
||||
attachments = tc.ListWithFilter(`create_time > now() + 5`)
|
||||
require.Len(t, attachments, 0)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterCreateTimeArithmetic(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
|
||||
|
||||
// Test: create_time >= now() - 3600 (attachments created in last hour)
|
||||
attachments := tc.ListWithFilter(`create_time >= now() - 3600`)
|
||||
require.Len(t, attachments, 1)
|
||||
|
||||
// Test: create_time < now() - 86400 (attachments older than 1 day - should be empty)
|
||||
attachments = tc.ListWithFilter(`create_time < now() - 86400`)
|
||||
require.Len(t, attachments, 0)
|
||||
|
||||
// Test: Multiplication - create_time >= now() - 60 * 60
|
||||
attachments = tc.ListWithFilter(`create_time >= now() - 60 * 60`)
|
||||
require.Len(t, attachments, 1)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterAllComparisonOperators(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
|
||||
|
||||
// Test: < (less than)
|
||||
attachments := tc.ListWithFilter(`create_time < now() + 3600`)
|
||||
require.Len(t, attachments, 1)
|
||||
|
||||
// Test: <= (less than or equal) with buffer for clock drift
|
||||
attachments = tc.ListWithFilter(`create_time < now() + 5`)
|
||||
require.Len(t, attachments, 1)
|
||||
|
||||
// Test: > (greater than)
|
||||
attachments = tc.ListWithFilter(`create_time > now() - 3600`)
|
||||
require.Len(t, attachments, 1)
|
||||
|
||||
// Test: >= (greater than or equal)
|
||||
attachments = tc.ListWithFilter(`create_time >= now() - 60`)
|
||||
require.Len(t, attachments, 1)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Memo ID Field Tests
|
||||
// Schema: memo_id (int, ==, !=)
|
||||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterMemoIdEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContextWithUser(t)
|
||||
defer tc.Close()
|
||||
|
||||
memo1 := tc.CreateMemo("memo-1", "Memo 1")
|
||||
memo2 := tc.CreateMemo("memo-2", "Memo 2")
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("memo1_attachment.png").MimeType("image/png").MemoID(&memo1.ID))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("memo2_attachment.png").MimeType("image/png").MemoID(&memo2.ID))
|
||||
|
||||
attachments := tc.ListWithFilter(`memo_id == ` + formatInt32(memo1.ID))
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, &memo1.ID, attachments[0].MemoID)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterMemoIdNotEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContextWithUser(t)
|
||||
defer tc.Close()
|
||||
|
||||
memo1 := tc.CreateMemo("memo-1", "Memo 1")
|
||||
memo2 := tc.CreateMemo("memo-2", "Memo 2")
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("memo1_attachment.png").MimeType("image/png").MemoID(&memo1.ID))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("memo2_attachment.png").MimeType("image/png").MemoID(&memo2.ID))
|
||||
|
||||
attachments := tc.ListWithFilter(`memo_id != ` + formatInt32(memo1.ID))
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, &memo2.ID, attachments[0].MemoID)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Logical Operator Tests
|
||||
// Operators: && (AND), || (OR), ! (NOT)
|
||||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterLogicalAnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("photo.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.pdf").MimeType("application/pdf"))
|
||||
|
||||
attachments := tc.ListWithFilter(`mime_type == "image/png" && filename.contains("image")`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, "image.png", attachments[0].Filename)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterLogicalOr(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("video.mp4").MimeType("video/mp4"))
|
||||
|
||||
attachments := tc.ListWithFilter(`mime_type == "image/png" || mime_type == "application/pdf"`)
|
||||
require.Len(t, attachments, 2)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterLogicalNot(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("image.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("document.pdf").MimeType("application/pdf"))
|
||||
|
||||
attachments := tc.ListWithFilter(`!(mime_type == "image/png")`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, "application/pdf", attachments[0].Type)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterComplexLogical(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.pdf").MimeType("application/pdf"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("other.png").MimeType("image/png"))
|
||||
|
||||
attachments := tc.ListWithFilter(`(mime_type == "image/png" || mime_type == "application/pdf") && filename.contains("report")`)
|
||||
require.Len(t, attachments, 2)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Multiple Filters Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterMultipleFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("other.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("report.pdf").MimeType("application/pdf"))
|
||||
|
||||
// Test: Multiple filters (applied as AND)
|
||||
attachments := tc.ListWithFilters(`filename.contains("report")`, `mime_type == "image/png"`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Contains(t, attachments[0].Filename, "report")
|
||||
require.Equal(t, "image/png", attachments[0].Type)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Edge Cases
|
||||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterNoMatches(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
|
||||
|
||||
attachments := tc.ListWithFilter(`filename.contains("nonexistent12345")`)
|
||||
require.Len(t, attachments, 0)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterNullMemoId(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContextWithUser(t)
|
||||
defer tc.Close()
|
||||
|
||||
memo := tc.CreateMemo("memo-1", "Memo 1")
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("with_memo.png").MimeType("image/png").MemoID(&memo.ID))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("no_memo.png").MimeType("image/png"))
|
||||
|
||||
// Test: memo_id == null
|
||||
attachments := tc.ListWithFilter(`memo_id == null`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, "no_memo.png", attachments[0].Filename)
|
||||
require.Nil(t, attachments[0].MemoID)
|
||||
|
||||
// Test: memo_id != null
|
||||
attachments = tc.ListWithFilter(`memo_id != null`)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, "with_memo.png", attachments[0].Filename)
|
||||
require.NotNil(t, attachments[0].MemoID)
|
||||
require.Equal(t, memo.ID, *attachments[0].MemoID)
|
||||
}
|
||||
|
||||
func TestAttachmentFilterEmptyFilename(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("test.png").MimeType("image/png"))
|
||||
tc.CreateAttachment(NewAttachmentBuilder(tc.CreatorID).Filename("other.pdf").MimeType("application/pdf"))
|
||||
|
||||
// Test: filename.contains("") - should match all
|
||||
attachments := tc.ListWithFilter(`filename.contains("")`)
|
||||
require.Len(t, attachments, 2)
|
||||
}
|
||||
247
store/test/attachment_test.go
Normal file
247
store/test/attachment_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestAttachmentStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
_, err := ts.CreateAttachment(ctx, &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: 101,
|
||||
Filename: "test.epub",
|
||||
Blob: []byte("test"),
|
||||
Type: "application/epub+zip",
|
||||
Size: 637607,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
correctFilename := "test.epub"
|
||||
incorrectFilename := "test.png"
|
||||
attachment, err := ts.GetAttachment(ctx, &store.FindAttachment{
|
||||
Filename: &correctFilename,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, correctFilename, attachment.Filename)
|
||||
require.Equal(t, int32(1), attachment.ID)
|
||||
|
||||
notFoundAttachment, err := ts.GetAttachment(ctx, &store.FindAttachment{
|
||||
Filename: &incorrectFilename,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFoundAttachment)
|
||||
|
||||
var correctCreatorID int32 = 101
|
||||
var incorrectCreatorID int32 = 102
|
||||
_, err = ts.GetAttachment(ctx, &store.FindAttachment{
|
||||
CreatorID: &correctCreatorID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
notFoundAttachment, err = ts.GetAttachment(ctx, &store.FindAttachment{
|
||||
CreatorID: &incorrectCreatorID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFoundAttachment)
|
||||
|
||||
err = ts.DeleteAttachment(ctx, &store.DeleteAttachment{
|
||||
ID: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = ts.DeleteAttachment(ctx, &store.DeleteAttachment{
|
||||
ID: 2,
|
||||
})
|
||||
require.ErrorContains(t, err, "attachment not found")
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestAttachmentStoreWithFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
_, err := ts.CreateAttachment(ctx, &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: 101,
|
||||
Filename: "test.png",
|
||||
Blob: []byte("test"),
|
||||
Type: "image/png",
|
||||
Size: 1000,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateAttachment(ctx, &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: 101,
|
||||
Filename: "test.jpg",
|
||||
Blob: []byte("test"),
|
||||
Type: "image/jpeg",
|
||||
Size: 2000,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateAttachment(ctx, &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: 101,
|
||||
Filename: "test.pdf",
|
||||
Blob: []byte("test"),
|
||||
Type: "application/pdf",
|
||||
Size: 3000,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
attachments, err := ts.ListAttachments(ctx, &store.FindAttachment{
|
||||
CreatorID: &[]int32{101}[0],
|
||||
Filters: []string{`mime_type == "image/png"`},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, attachments, 1)
|
||||
require.Equal(t, "image/png", attachments[0].Type)
|
||||
|
||||
attachments, err = ts.ListAttachments(ctx, &store.FindAttachment{
|
||||
CreatorID: &[]int32{101}[0],
|
||||
Filters: []string{`mime_type in ["image/png", "image/jpeg"]`},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, attachments, 2)
|
||||
|
||||
attachments, err = ts.ListAttachments(ctx, &store.FindAttachment{
|
||||
CreatorID: &[]int32{101}[0],
|
||||
Filters: []string{`filename.contains("test")`},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, attachments, 3)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestAttachmentUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
attachment, err := ts.CreateAttachment(ctx, &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: 101,
|
||||
Filename: "original.png",
|
||||
Blob: []byte("test"),
|
||||
Type: "image/png",
|
||||
Size: 1000,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update filename
|
||||
newFilename := "updated.png"
|
||||
err = ts.UpdateAttachment(ctx, &store.UpdateAttachment{
|
||||
ID: attachment.ID,
|
||||
Filename: &newFilename,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify update
|
||||
found, err := ts.GetAttachment(ctx, &store.FindAttachment{ID: &attachment.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newFilename, found.Filename)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestAttachmentGetByUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
uid := shortuuid.New()
|
||||
_, err := ts.CreateAttachment(ctx, &store.Attachment{
|
||||
UID: uid,
|
||||
CreatorID: 101,
|
||||
Filename: "test.png",
|
||||
Blob: []byte("test"),
|
||||
Type: "image/png",
|
||||
Size: 1000,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by UID
|
||||
found, err := ts.GetAttachment(ctx, &store.FindAttachment{UID: &uid})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, uid, found.UID)
|
||||
|
||||
// Get non-existent UID
|
||||
nonExistentUID := "non-existent-uid"
|
||||
notFound, err := ts.GetAttachment(ctx, &store.FindAttachment{UID: &nonExistentUID})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestAttachmentListWithPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create 5 attachments
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := ts.CreateAttachment(ctx, &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: 101,
|
||||
Filename: fmt.Sprintf("test%d.png", i),
|
||||
Blob: []byte("test"),
|
||||
Type: "image/png",
|
||||
Size: int64(1000 + i),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test limit
|
||||
limit := 3
|
||||
attachments, err := ts.ListAttachments(ctx, &store.FindAttachment{
|
||||
CreatorID: &[]int32{101}[0],
|
||||
Limit: &limit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(attachments))
|
||||
|
||||
// Test offset
|
||||
offset := 2
|
||||
offsetAttachments, err := ts.ListAttachments(ctx, &store.FindAttachment{
|
||||
CreatorID: &[]int32{101}[0],
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(offsetAttachments))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestAttachmentInvalidUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create with invalid UID (contains spaces)
|
||||
_, err := ts.CreateAttachment(ctx, &store.Attachment{
|
||||
UID: "invalid uid with spaces",
|
||||
CreatorID: 101,
|
||||
Filename: "test.png",
|
||||
Blob: []byte("test"),
|
||||
Type: "image/png",
|
||||
Size: 1000,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid uid")
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
317
store/test/containers.go
Normal file
317
store/test/containers.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mysql"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/network"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
// Database drivers for connection verification.
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
testUser = "root"
|
||||
testPassword = "test"
|
||||
|
||||
// Memos container settings for migration testing.
|
||||
MemosDockerImage = "neosmemo/memos"
|
||||
StableMemosVersion = "stable" // Always points to the latest stable release
|
||||
)
|
||||
|
||||
var (
|
||||
mysqlContainer atomic.Pointer[mysql.MySQLContainer]
|
||||
postgresContainer atomic.Pointer[postgres.PostgresContainer]
|
||||
mysqlOnce sync.Once
|
||||
postgresOnce sync.Once
|
||||
mysqlBaseDSN atomic.Value // stores string
|
||||
postgresBaseDSN atomic.Value // stores string
|
||||
dbCounter atomic.Int64
|
||||
dbCreationMutex sync.Mutex // Protects database creation operations
|
||||
|
||||
// Network for container communication.
|
||||
testDockerNetwork atomic.Pointer[testcontainers.DockerNetwork]
|
||||
testNetworkOnce sync.Once
|
||||
)
|
||||
|
||||
// getTestNetwork creates or returns the shared Docker network for container communication.
|
||||
func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) {
|
||||
var networkErr error
|
||||
testNetworkOnce.Do(func() {
|
||||
nw, err := network.New(ctx, network.WithDriver("bridge"))
|
||||
if err != nil {
|
||||
networkErr = err
|
||||
return
|
||||
}
|
||||
testDockerNetwork.Store(nw)
|
||||
})
|
||||
return testDockerNetwork.Load(), networkErr
|
||||
}
|
||||
|
||||
// GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test.
|
||||
func GetMySQLDSN(t *testing.T) string {
|
||||
ctx := context.Background()
|
||||
|
||||
mysqlOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
||||
container, err := mysql.Run(ctx,
|
||||
"mysql:8",
|
||||
mysql.WithDatabase("init_db"),
|
||||
mysql.WithUsername("root"),
|
||||
mysql.WithPassword(testPassword),
|
||||
testcontainers.WithEnv(map[string]string{
|
||||
"MYSQL_ROOT_PASSWORD": testPassword,
|
||||
}),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForAll(
|
||||
wait.ForLog("ready for connections").WithOccurrence(2),
|
||||
wait.ForListeningPort("3306/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start MySQL container: %v", err)
|
||||
}
|
||||
mysqlContainer.Store(container)
|
||||
|
||||
dsn, err := container.ConnectionString(ctx, "multiStatements=true")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MySQL connection string: %v", err)
|
||||
}
|
||||
|
||||
if err := waitForDB("mysql", dsn, 30*time.Second); err != nil {
|
||||
t.Fatalf("MySQL not ready for connections: %v", err)
|
||||
}
|
||||
|
||||
mysqlBaseDSN.Store(dsn)
|
||||
})
|
||||
|
||||
dsn, ok := mysqlBaseDSN.Load().(string)
|
||||
if !ok || dsn == "" {
|
||||
t.Fatal("MySQL container failed to start in a previous test")
|
||||
}
|
||||
|
||||
// Serialize database creation to avoid "table already exists" race conditions
|
||||
dbCreationMutex.Lock()
|
||||
defer dbCreationMutex.Unlock()
|
||||
|
||||
// Create a fresh database for this test
|
||||
dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1))
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to MySQL: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE `%s`", dbName)); err != nil {
|
||||
t.Fatalf("failed to create database %s: %v", dbName, err)
|
||||
}
|
||||
|
||||
// Return DSN pointing to the new database
|
||||
return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1)
|
||||
}
|
||||
|
||||
// waitForDB polls the database until it's ready or timeout is reached.
|
||||
func waitForDB(driver, dsn string, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastErr error
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if lastErr != nil {
|
||||
return errors.Errorf("timeout waiting for %s database: %v", driver, lastErr)
|
||||
}
|
||||
return errors.Errorf("timeout waiting for %s database to be ready", driver)
|
||||
case <-ticker.C:
|
||||
db, err := sql.Open(driver, dsn)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
err = db.PingContext(ctx)
|
||||
db.Close()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetPostgresDSN starts a PostgreSQL container (if not already running) and creates a fresh database for this test.
|
||||
func GetPostgresDSN(t *testing.T) string {
|
||||
ctx := context.Background()
|
||||
|
||||
postgresOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
||||
container, err := postgres.Run(ctx,
|
||||
"postgres:18",
|
||||
postgres.WithDatabase("init_db"),
|
||||
postgres.WithUsername(testUser),
|
||||
postgres.WithPassword(testPassword),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForAll(
|
||||
wait.ForLog("database system is ready to accept connections").WithOccurrence(2),
|
||||
wait.ForListeningPort("5432/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start PostgreSQL container: %v", err)
|
||||
}
|
||||
postgresContainer.Store(container)
|
||||
|
||||
dsn, err := container.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get PostgreSQL connection string: %v", err)
|
||||
}
|
||||
|
||||
if err := waitForDB("postgres", dsn, 30*time.Second); err != nil {
|
||||
t.Fatalf("PostgreSQL not ready for connections: %v", err)
|
||||
}
|
||||
|
||||
postgresBaseDSN.Store(dsn)
|
||||
})
|
||||
|
||||
dsn, ok := postgresBaseDSN.Load().(string)
|
||||
if !ok || dsn == "" {
|
||||
t.Fatal("PostgreSQL container failed to start in a previous test")
|
||||
}
|
||||
|
||||
// Serialize database creation to avoid "table already exists" race conditions
|
||||
dbCreationMutex.Lock()
|
||||
defer dbCreationMutex.Unlock()
|
||||
|
||||
// Create a fresh database for this test
|
||||
dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1))
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to PostgreSQL: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", dbName)); err != nil {
|
||||
t.Fatalf("failed to create database %s: %v", dbName, err)
|
||||
}
|
||||
|
||||
// Return DSN pointing to the new database
|
||||
return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1)
|
||||
}
|
||||
|
||||
// TerminateContainers cleans up all running containers and network.
|
||||
// This is typically called from TestMain.
|
||||
func TerminateContainers() {
|
||||
ctx := context.Background()
|
||||
if container := mysqlContainer.Load(); container != nil {
|
||||
_ = container.Terminate(ctx)
|
||||
}
|
||||
if container := postgresContainer.Load(); container != nil {
|
||||
_ = container.Terminate(ctx)
|
||||
}
|
||||
if network := testDockerNetwork.Load(); network != nil {
|
||||
_ = network.Remove(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// MemosContainerConfig holds configuration for starting a Memos container.
|
||||
type MemosContainerConfig struct {
|
||||
Version string // Memos version tag (e.g., "0.24.0")
|
||||
Driver string // Database driver: sqlite, mysql, postgres
|
||||
DSN string // Database DSN (for mysql/postgres)
|
||||
DataDir string // Host directory to mount for SQLite data
|
||||
}
|
||||
|
||||
// MemosStartupWaitStrategy defines the wait strategy for Memos container startup.
|
||||
// Uses regex to match various log message formats across versions.
|
||||
var MemosStartupWaitStrategy = wait.ForAll(
|
||||
wait.ForLog("(started successfully|has been started on port)").AsRegexp(),
|
||||
wait.ForListeningPort("5230/tcp"),
|
||||
).WithDeadline(180 * time.Second)
|
||||
|
||||
// StartMemosContainer starts a Memos container for migration testing.
|
||||
// For SQLite, it mounts the dataDir to /var/opt/memos.
|
||||
func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcontainers.Container, error) {
|
||||
env := map[string]string{
|
||||
"MEMOS_MODE": "prod",
|
||||
}
|
||||
|
||||
var opts []testcontainers.ContainerCustomizer
|
||||
|
||||
switch cfg.Driver {
|
||||
case "sqlite":
|
||||
env["MEMOS_DRIVER"] = "sqlite"
|
||||
opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) {
|
||||
hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos"))
|
||||
}))
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver)
|
||||
}
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: fmt.Sprintf("%s:%s", MemosDockerImage, cfg.Version),
|
||||
Env: env,
|
||||
ExposedPorts: []string{"5230/tcp"},
|
||||
WaitingFor: MemosStartupWaitStrategy,
|
||||
User: fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()),
|
||||
}
|
||||
|
||||
// Use local image if specified
|
||||
if cfg.Version == "local" {
|
||||
if os.Getenv("MEMOS_TEST_IMAGE_BUILT") == "1" {
|
||||
req.Image = "memos-test:local"
|
||||
} else {
|
||||
req.FromDockerfile = testcontainers.FromDockerfile{
|
||||
Context: "../../",
|
||||
Dockerfile: "Dockerfile",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
genericReq := testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
if err := opt.Customize(&genericReq); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to apply container option")
|
||||
}
|
||||
}
|
||||
|
||||
ctr, err := testcontainers.GenericContainer(ctx, genericReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to start memos container")
|
||||
}
|
||||
|
||||
return ctr, nil
|
||||
}
|
||||
290
store/test/filter_helpers_test.go
Normal file
290
store/test/filter_helpers_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Formatting Helpers
|
||||
// =============================================================================
|
||||
|
||||
func formatInt64(n int64) string {
|
||||
return strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
func formatInt32(n int32) string {
|
||||
return strconv.FormatInt(int64(n), 10)
|
||||
}
|
||||
|
||||
func formatInt(n int) string {
|
||||
return strconv.Itoa(n)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Pointer Helpers
|
||||
// =============================================================================
|
||||
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Test Fixture Builders
|
||||
// =============================================================================
|
||||
|
||||
// MemoBuilder provides a fluent API for creating test memos.
|
||||
type MemoBuilder struct {
|
||||
memo *store.Memo
|
||||
}
|
||||
|
||||
// NewMemoBuilder creates a new memo builder with required fields.
|
||||
func NewMemoBuilder(uid string, creatorID int32) *MemoBuilder {
|
||||
return &MemoBuilder{
|
||||
memo: &store.Memo{
|
||||
UID: uid,
|
||||
CreatorID: creatorID,
|
||||
Visibility: store.Public,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *MemoBuilder) Content(content string) *MemoBuilder {
|
||||
b.memo.Content = content
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *MemoBuilder) Visibility(v store.Visibility) *MemoBuilder {
|
||||
b.memo.Visibility = v
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *MemoBuilder) Tags(tags ...string) *MemoBuilder {
|
||||
if b.memo.Payload == nil {
|
||||
b.memo.Payload = &storepb.MemoPayload{}
|
||||
}
|
||||
b.memo.Payload.Tags = tags
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *MemoBuilder) Property(fn func(*storepb.MemoPayload_Property)) *MemoBuilder {
|
||||
if b.memo.Payload == nil {
|
||||
b.memo.Payload = &storepb.MemoPayload{}
|
||||
}
|
||||
if b.memo.Payload.Property == nil {
|
||||
b.memo.Payload.Property = &storepb.MemoPayload_Property{}
|
||||
}
|
||||
fn(b.memo.Payload.Property)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *MemoBuilder) Build() *store.Memo {
|
||||
return b.memo
|
||||
}
|
||||
|
||||
// AttachmentBuilder provides a fluent API for creating test attachments.
|
||||
type AttachmentBuilder struct {
|
||||
attachment *store.Attachment
|
||||
}
|
||||
|
||||
// NewAttachmentBuilder creates a new attachment builder with required fields.
|
||||
func NewAttachmentBuilder(creatorID int32) *AttachmentBuilder {
|
||||
return &AttachmentBuilder{
|
||||
attachment: &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: creatorID,
|
||||
Blob: []byte("test"),
|
||||
Size: 1000,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *AttachmentBuilder) Filename(filename string) *AttachmentBuilder {
|
||||
b.attachment.Filename = filename
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *AttachmentBuilder) MimeType(mimeType string) *AttachmentBuilder {
|
||||
b.attachment.Type = mimeType
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *AttachmentBuilder) MemoID(memoID *int32) *AttachmentBuilder {
|
||||
b.attachment.MemoID = memoID
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *AttachmentBuilder) Size(size int64) *AttachmentBuilder {
|
||||
b.attachment.Size = size
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *AttachmentBuilder) Build() *store.Attachment {
|
||||
return b.attachment
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Test Context Helpers
|
||||
// =============================================================================
|
||||
|
||||
// MemoFilterTestContext holds common test dependencies for memo filter tests.
|
||||
type MemoFilterTestContext struct {
|
||||
Ctx context.Context
|
||||
T *testing.T
|
||||
Store *store.Store
|
||||
User *store.User
|
||||
}
|
||||
|
||||
// NewMemoFilterTestContext creates a new test context with store and user.
|
||||
func NewMemoFilterTestContext(t *testing.T) *MemoFilterTestContext {
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &MemoFilterTestContext{
|
||||
Ctx: ctx,
|
||||
T: t,
|
||||
Store: ts,
|
||||
User: user,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateMemo creates a memo using the builder pattern.
|
||||
func (tc *MemoFilterTestContext) CreateMemo(b *MemoBuilder) *store.Memo {
|
||||
memo, err := tc.Store.CreateMemo(tc.Ctx, b.Build())
|
||||
require.NoError(tc.T, err)
|
||||
return memo
|
||||
}
|
||||
|
||||
// PinMemo pins a memo by ID.
|
||||
func (tc *MemoFilterTestContext) PinMemo(memoID int32) {
|
||||
err := tc.Store.UpdateMemo(tc.Ctx, &store.UpdateMemo{
|
||||
ID: memoID,
|
||||
Pinned: boolPtr(true),
|
||||
})
|
||||
require.NoError(tc.T, err)
|
||||
}
|
||||
|
||||
// ListWithFilter lists memos with the given filter and returns the count.
|
||||
func (tc *MemoFilterTestContext) ListWithFilter(filter string) []*store.Memo {
|
||||
memos, err := tc.Store.ListMemos(tc.Ctx, &store.FindMemo{
|
||||
Filters: []string{filter},
|
||||
})
|
||||
require.NoError(tc.T, err)
|
||||
return memos
|
||||
}
|
||||
|
||||
// ListWithFilters lists memos with multiple filters and returns the count.
|
||||
func (tc *MemoFilterTestContext) ListWithFilters(filters ...string) []*store.Memo {
|
||||
memos, err := tc.Store.ListMemos(tc.Ctx, &store.FindMemo{
|
||||
Filters: filters,
|
||||
})
|
||||
require.NoError(tc.T, err)
|
||||
return memos
|
||||
}
|
||||
|
||||
// Close closes the test store.
|
||||
func (tc *MemoFilterTestContext) Close() {
|
||||
tc.Store.Close()
|
||||
}
|
||||
|
||||
// AttachmentFilterTestContext holds common test dependencies for attachment filter tests.
|
||||
type AttachmentFilterTestContext struct {
|
||||
Ctx context.Context
|
||||
T *testing.T
|
||||
Store *store.Store
|
||||
User *store.User
|
||||
CreatorID int32
|
||||
}
|
||||
|
||||
// NewAttachmentFilterTestContext creates a new test context for attachments.
|
||||
func NewAttachmentFilterTestContext(t *testing.T) *AttachmentFilterTestContext {
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
return &AttachmentFilterTestContext{
|
||||
Ctx: ctx,
|
||||
T: t,
|
||||
Store: ts,
|
||||
CreatorID: 101,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAttachmentFilterTestContextWithUser creates a new test context with a user.
|
||||
func NewAttachmentFilterTestContextWithUser(t *testing.T) *AttachmentFilterTestContext {
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &AttachmentFilterTestContext{
|
||||
Ctx: ctx,
|
||||
T: t,
|
||||
Store: ts,
|
||||
User: user,
|
||||
CreatorID: user.ID,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAttachment creates an attachment using the builder pattern.
|
||||
func (tc *AttachmentFilterTestContext) CreateAttachment(b *AttachmentBuilder) *store.Attachment {
|
||||
attachment, err := tc.Store.CreateAttachment(tc.Ctx, b.Build())
|
||||
require.NoError(tc.T, err)
|
||||
return attachment
|
||||
}
|
||||
|
||||
// CreateMemo creates a memo (for attachment tests that need memos).
|
||||
func (tc *AttachmentFilterTestContext) CreateMemo(uid, content string) *store.Memo {
|
||||
memo, err := tc.Store.CreateMemo(tc.Ctx, &store.Memo{
|
||||
UID: uid,
|
||||
CreatorID: tc.CreatorID,
|
||||
Content: content,
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(tc.T, err)
|
||||
return memo
|
||||
}
|
||||
|
||||
// ListWithFilter lists attachments with the given filter.
|
||||
func (tc *AttachmentFilterTestContext) ListWithFilter(filter string) []*store.Attachment {
|
||||
attachments, err := tc.Store.ListAttachments(tc.Ctx, &store.FindAttachment{
|
||||
CreatorID: &tc.CreatorID,
|
||||
Filters: []string{filter},
|
||||
})
|
||||
require.NoError(tc.T, err)
|
||||
return attachments
|
||||
}
|
||||
|
||||
// ListWithFilters lists attachments with multiple filters.
|
||||
func (tc *AttachmentFilterTestContext) ListWithFilters(filters ...string) []*store.Attachment {
|
||||
attachments, err := tc.Store.ListAttachments(tc.Ctx, &store.FindAttachment{
|
||||
CreatorID: &tc.CreatorID,
|
||||
Filters: filters,
|
||||
})
|
||||
require.NoError(tc.T, err)
|
||||
return attachments
|
||||
}
|
||||
|
||||
// Close closes the test store.
|
||||
func (tc *AttachmentFilterTestContext) Close() {
|
||||
tc.Store.Close()
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Filter Test Case Definition
|
||||
// =============================================================================
|
||||
|
||||
// FilterTestCase defines a single filter test case for table-driven tests.
|
||||
type FilterTestCase struct {
|
||||
Name string
|
||||
Filter string
|
||||
ExpectedCount int
|
||||
}
|
||||
454
store/test/idp_test.go
Normal file
454
store/test/idp_test.go
Normal file
@@ -0,0 +1,454 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestIdentityProviderStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Name: "GitHub OAuth",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: "",
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://github.com/auth",
|
||||
TokenUrl: "https://github.com/token",
|
||||
UserInfoUrl: "https://github.com/user",
|
||||
Scopes: []string{"login"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "login",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &createdIDP.Id,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, idp)
|
||||
require.Equal(t, createdIDP, idp)
|
||||
newName := "My GitHub OAuth"
|
||||
updatedIdp, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Name: &newName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newName, updatedIdp.Name)
|
||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{
|
||||
ID: idp.Id,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(idpList))
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create IDP
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by ID
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, idp.Id, found.Id)
|
||||
require.Equal(t, idp.Name, found.Name)
|
||||
|
||||
// Get by non-existent ID
|
||||
nonExistentID := int32(99999)
|
||||
notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderListMultiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all
|
||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idpList, 3)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderListByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by specific ID
|
||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{ID: &idp1.Id})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idpList, 1)
|
||||
require.Equal(t, "GitHub OAuth", idpList[0].Name)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderUpdateName(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Original Name", idp.Name)
|
||||
|
||||
// Update name
|
||||
newName := "Updated Name"
|
||||
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Name: &newName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated Name", updated.Name)
|
||||
|
||||
// Verify update persisted
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated Name", found.Name)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", idp.IdentifierFilter)
|
||||
|
||||
// Update identifier filter
|
||||
newFilter := "@example.com$"
|
||||
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: &newFilter,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "@example.com$", updated.IdentifierFilter)
|
||||
|
||||
// Verify update persisted
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "@example.com$", found.IdentifierFilter)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderUpdateConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update config
|
||||
newConfig := &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "new_client_id",
|
||||
ClientSecret: "new_client_secret",
|
||||
AuthUrl: "https://newprovider.com/auth",
|
||||
TokenUrl: "https://newprovider.com/token",
|
||||
UserInfoUrl: "https://newprovider.com/user",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: newConfig,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new_client_id", updated.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "new_client_secret", updated.Config.GetOauth2Config().ClientSecret)
|
||||
require.Contains(t, updated.Config.GetOauth2Config().Scopes, "openid")
|
||||
|
||||
// Verify update persisted
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new_client_id", found.Config.GetOauth2Config().ClientId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderUpdateMultipleFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update multiple fields at once
|
||||
newName := "Updated IDP"
|
||||
newFilter := "^admin@"
|
||||
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Name: &newName,
|
||||
IdentifierFilter: &newFilter,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated IDP", updated.Name)
|
||||
require.Equal(t, "^admin@", updated.IdentifierFilter)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete
|
||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, found)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1"))
|
||||
require.NoError(t, err)
|
||||
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete first one
|
||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp1.Id})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify second still exists
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp2.Id})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, "IDP 2", found.Name)
|
||||
|
||||
// Verify list only contains second
|
||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idpList, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create IDP with multiple scopes
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Name: "Multi-Scope OAuth",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://provider.com/auth",
|
||||
TokenUrl: "https://provider.com/token",
|
||||
UserInfoUrl: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid", "profile", "email", "groups"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify scopes are preserved
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, found.Config.GetOauth2Config().Scopes, 4)
|
||||
require.Contains(t, found.Config.GetOauth2Config().Scopes, "openid")
|
||||
require.Contains(t, found.Config.GetOauth2Config().Scopes, "groups")
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderFieldMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create IDP with custom field mapping
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Name: "Custom Field Mapping",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://provider.com/auth",
|
||||
TokenUrl: "https://provider.com/token",
|
||||
UserInfoUrl: "https://provider.com/userinfo",
|
||||
Scopes: []string{"login"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "preferred_username",
|
||||
DisplayName: "full_name",
|
||||
Email: "email_address",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify field mapping is preserved
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "preferred_username", found.Config.GetOauth2Config().FieldMapping.Identifier)
|
||||
require.Equal(t, "full_name", found.Config.GetOauth2Config().FieldMapping.DisplayName)
|
||||
require.Equal(t, "email_address", found.Config.GetOauth2Config().FieldMapping.Email)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
filter string
|
||||
}{
|
||||
{"Domain filter", "@company\\.com$"},
|
||||
{"Prefix filter", "^admin_"},
|
||||
{"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
|
||||
{"Empty filter", ""},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Name: tc.name,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: tc.filter,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://provider.com/auth",
|
||||
TokenUrl: "https://provider.com/token",
|
||||
UserInfoUrl: "https://provider.com/userinfo",
|
||||
Scopes: []string{"login"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.filter, found.IdentifierFilter)
|
||||
|
||||
// Cleanup
|
||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
// Helper function to create a test OAuth2 IDP.
|
||||
func createTestOAuth2IDP(name string) *storepb.IdentityProvider {
|
||||
return &storepb.IdentityProvider{
|
||||
Name: name,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: "",
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://provider.com/auth",
|
||||
TokenUrl: "https://provider.com/token",
|
||||
UserInfoUrl: "https://provider.com/userinfo",
|
||||
Scopes: []string{"login"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "login",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
592
store/test/inbox_test.go
Normal file
592
store/test/inbox_test.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestInboxStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
const systemBotID int32 = 0
|
||||
create := &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
}
|
||||
inbox, err := ts.CreateInbox(ctx, create)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, inbox)
|
||||
require.Equal(t, create.Message, inbox.Message)
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(inboxes))
|
||||
require.Equal(t, inbox, inboxes[0])
|
||||
updatedInbox, err := ts.UpdateInbox(ctx, &store.UpdateInbox{
|
||||
ID: inbox.ID,
|
||||
Status: store.ARCHIVED,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedInbox)
|
||||
require.Equal(t, store.ARCHIVED, updatedInbox.Status)
|
||||
err = ts.DeleteInbox(ctx, &store.DeleteInbox{
|
||||
ID: inbox.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(inboxes))
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by ID
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, inbox.ID, inboxes[0].ID)
|
||||
|
||||
// List by non-existent ID
|
||||
nonExistentID := int32(99999)
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ID: &nonExistentID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 0)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListBySenderID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox from system bot (senderID = 0)
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox from user2
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: user2.ID,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by sender ID = user2
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{SenderID: &user2.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, user2.ID, inboxes[0].SenderID)
|
||||
|
||||
// List by sender ID = 0 (system bot)
|
||||
systemBotID := int32(0)
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{SenderID: &systemBotID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, int32(0), inboxes[0].SenderID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListByStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create UNREAD inbox
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create another inbox and archive it
|
||||
inbox2, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox2.ID, Status: store.ARCHIVED})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by UNREAD status
|
||||
unreadStatus := store.UNREAD
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{Status: &unreadStatus})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, store.UNREAD, inboxes[0].Status)
|
||||
|
||||
// List by ARCHIVED status
|
||||
archivedStatus := store.ARCHIVED
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{Status: &archivedStatus})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, store.ARCHIVED, inboxes[0].Status)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListByMessageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create MEMO_COMMENT inboxes
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by MEMO_COMMENT type
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{MessageType: &memoCommentType})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
for _, inbox := range inboxes {
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create 5 inboxes
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test Limit only
|
||||
limit := 3
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
Limit: &limit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 3)
|
||||
|
||||
// Test Limit + Offset (offset requires limit in the implementation)
|
||||
limit = 2
|
||||
offset := 2
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
|
||||
// Test Limit + Offset skipping to end
|
||||
limit = 10
|
||||
offset = 3
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2) // Only 2 remaining after offset of 3
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListCombinedFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create various inboxes
|
||||
// user2 -> user1, MEMO_COMMENT, UNREAD
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: user2.ID,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// user2 -> user1, TYPE_UNSPECIFIED, UNREAD
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: user2.ID,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// system -> user1, MEMO_COMMENT, ARCHIVED
|
||||
inbox3, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox3.ID, Status: store.ARCHIVED})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Combined filter: ReceiverID + SenderID + Status
|
||||
unreadStatus := store.UNREAD
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user1.ID,
|
||||
SenderID: &user2.ID,
|
||||
Status: &unreadStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
|
||||
// Combined filter: ReceiverID + MessageType + Status
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user1.ID,
|
||||
MessageType: &memoCommentType,
|
||||
Status: &unreadStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, user2.ID, inboxes[0].SenderID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxMessagePayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox with message payload containing activity ID
|
||||
activityID := int32(123)
|
||||
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
ActivityId: &activityID,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, inbox.Message)
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
|
||||
require.Equal(t, activityID, *inbox.Message.ActivityId)
|
||||
|
||||
// List and verify payload is preserved
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, activityID, *inboxes[0].Message.ActivityId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxUpdateStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.UNREAD, inbox.Status)
|
||||
|
||||
// Update to ARCHIVED
|
||||
updated, err := ts.UpdateInbox(ctx, &store.UpdateInbox{
|
||||
ID: inbox.ID,
|
||||
Status: store.ARCHIVED,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.ARCHIVED, updated.Status)
|
||||
require.Equal(t, inbox.ID, updated.ID)
|
||||
|
||||
// Verify the update persisted
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, store.ARCHIVED, inboxes[0].Status)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListByMessageTypeMultipleTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inboxes with different message types
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Filter by MEMO_COMMENT - should get 2
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &memoCommentType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
for _, inbox := range inboxes {
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
|
||||
}
|
||||
|
||||
// Filter by TYPE_UNSPECIFIED - should get 1
|
||||
unspecifiedType := storepb.InboxMessage_TYPE_UNSPECIFIED
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &unspecifiedType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, storepb.InboxMessage_TYPE_UNSPECIFIED, inboxes[0].Message.Type)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxMessageTypeFilterWithPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox with full payload
|
||||
activityID := int32(456)
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
ActivityId: &activityID,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox with different type but also has payload
|
||||
otherActivityID := int32(789)
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_TYPE_UNSPECIFIED,
|
||||
ActivityId: &otherActivityID,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Filter by type should work correctly even with complex JSON payload
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &memoCommentType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, activityID, *inboxes[0].Message.ActivityId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxMessageTypeFilterWithStatusAndPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple inboxes with various combinations
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Archive 2 of them
|
||||
allInboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: allInboxes[i].ID, Status: store.ARCHIVED})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Filter by type + status + pagination
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
unreadStatus := store.UNREAD
|
||||
limit := 2
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &memoCommentType,
|
||||
Status: &unreadStatus,
|
||||
Limit: &limit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
for _, inbox := range inboxes {
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
|
||||
require.Equal(t, store.UNREAD, inbox.Status)
|
||||
}
|
||||
|
||||
// Get next page
|
||||
offset := 2
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &memoCommentType,
|
||||
Status: &unreadStatus,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1) // Only 1 remaining (3 unread total, got 2, now 1 left)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxMultipleReceivers(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox for user1
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox for user2
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user2.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// User1 should only see their inbox
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user1.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, user1.ID, inboxes[0].ReceiverID)
|
||||
|
||||
// User2 should only see their inbox
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user2.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, user2.ID, inboxes[0].ReceiverID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
300
store/test/instance_setting_test.go
Normal file
300
store/test/instance_setting_test.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestInstanceSettingV1Store(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
instanceSetting, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
AdditionalScript: "",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
setting, err := ts.GetInstanceSetting(ctx, &store.FindInstanceSetting{
|
||||
Name: storepb.InstanceSettingKey_GENERAL.String(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, instanceSetting, setting)
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingGetNonExistent(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get non-existent setting
|
||||
setting, err := ts.GetInstanceSetting(ctx, &store.FindInstanceSetting{
|
||||
Name: storepb.InstanceSettingKey_STORAGE.String(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, setting)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingUpsertUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create setting
|
||||
_, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
AdditionalScript: "console.log('v1')",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update setting
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
AdditionalScript: "console.log('v2')",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify update
|
||||
setting, err := ts.GetInstanceSetting(ctx, &store.FindInstanceSetting{
|
||||
Name: storepb.InstanceSettingKey_GENERAL.String(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "console.log('v2')", setting.GetGeneralSetting().AdditionalScript)
|
||||
|
||||
// Verify only one setting exists
|
||||
list, err := ts.ListInstanceSettings(ctx, &store.FindInstanceSetting{
|
||||
Name: storepb.InstanceSettingKey_GENERAL.String(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(list))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingBasicSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get default basic setting (should return empty defaults)
|
||||
basicSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, basicSetting)
|
||||
|
||||
// Set basic setting
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_BASIC,
|
||||
Value: &storepb.InstanceSetting_BasicSetting{
|
||||
BasicSetting: &storepb.InstanceBasicSetting{
|
||||
SecretKey: "my-secret-key",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
basicSetting, err = ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "my-secret-key", basicSetting.SecretKey)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingGeneralSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get default general setting
|
||||
generalSetting, err := ts.GetInstanceGeneralSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, generalSetting)
|
||||
|
||||
// Set general setting
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
AdditionalScript: "console.log('test')",
|
||||
AdditionalStyle: "body { color: red; }",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
generalSetting, err = ts.GetInstanceGeneralSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "console.log('test')", generalSetting.AdditionalScript)
|
||||
require.Equal(t, "body { color: red; }", generalSetting.AdditionalStyle)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingMemoRelatedSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get default memo related setting (should have defaults)
|
||||
memoSetting, err := ts.GetInstanceMemoRelatedSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoSetting)
|
||||
require.GreaterOrEqual(t, memoSetting.ContentLengthLimit, int32(store.DefaultContentLengthLimit))
|
||||
require.NotEmpty(t, memoSetting.Reactions)
|
||||
|
||||
// Set custom memo related setting
|
||||
customReactions := []string{"👍", "👎", "🚀"}
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_MEMO_RELATED,
|
||||
Value: &storepb.InstanceSetting_MemoRelatedSetting{
|
||||
MemoRelatedSetting: &storepb.InstanceMemoRelatedSetting{
|
||||
ContentLengthLimit: 16384,
|
||||
Reactions: customReactions,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
memoSetting, err = ts.GetInstanceMemoRelatedSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(16384), memoSetting.ContentLengthLimit)
|
||||
require.Equal(t, customReactions, memoSetting.Reactions)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingStorageSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get default storage setting (should have defaults)
|
||||
storageSetting, err := ts.GetInstanceStorageSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storageSetting)
|
||||
require.Equal(t, storepb.InstanceStorageSetting_LOCAL, storageSetting.StorageType)
|
||||
require.Equal(t, int64(30), storageSetting.UploadSizeLimitMb)
|
||||
require.Equal(t, "assets/{timestamp}_{filename}", storageSetting.FilepathTemplate)
|
||||
|
||||
// Set custom storage setting
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_STORAGE,
|
||||
Value: &storepb.InstanceSetting_StorageSetting{
|
||||
StorageSetting: &storepb.InstanceStorageSetting{
|
||||
StorageType: storepb.InstanceStorageSetting_LOCAL,
|
||||
UploadSizeLimitMb: 100,
|
||||
FilepathTemplate: "uploads/{date}/{filename}",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
storageSetting, err = ts.GetInstanceStorageSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, storepb.InstanceStorageSetting_LOCAL, storageSetting.StorageType)
|
||||
require.Equal(t, int64(100), storageSetting.UploadSizeLimitMb)
|
||||
require.Equal(t, "uploads/{date}/{filename}", storageSetting.FilepathTemplate)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingListAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Count initial settings
|
||||
initialList, err := ts.ListInstanceSettings(ctx, &store.FindInstanceSetting{})
|
||||
require.NoError(t, err)
|
||||
initialCount := len(initialList)
|
||||
|
||||
// Create multiple settings
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_STORAGE,
|
||||
Value: &storepb.InstanceSetting_StorageSetting{
|
||||
StorageSetting: &storepb.InstanceStorageSetting{},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all - should have 2 more than initial
|
||||
list, err := ts.ListInstanceSettings(ctx, &store.FindInstanceSetting{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialCount+2, len(list))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Case 1: General Setting with special characters and Unicode
|
||||
specialScript := `<script>alert("你好"); var x = 'test\'s';</script>`
|
||||
specialStyle := `body { font-family: "Noto Sans SC", sans-serif; content: "\u2764"; }`
|
||||
_, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
AdditionalScript: specialScript,
|
||||
AdditionalStyle: specialStyle,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
generalSetting, err := ts.GetInstanceGeneralSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, specialScript, generalSetting.AdditionalScript)
|
||||
require.Equal(t, specialStyle, generalSetting.AdditionalStyle)
|
||||
|
||||
// Case 2: Memo Related Setting with Unicode reactions
|
||||
unicodeReactions := []string{"🐱", "🐶", "🦊", "🦄"}
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_MEMO_RELATED,
|
||||
Value: &storepb.InstanceSetting_MemoRelatedSetting{
|
||||
MemoRelatedSetting: &storepb.InstanceMemoRelatedSetting{
|
||||
ContentLengthLimit: 1000,
|
||||
Reactions: unicodeReactions,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoSetting, err := ts.GetInstanceMemoRelatedSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, unicodeReactions, memoSetting.Reactions)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
50
store/test/main_test.go
Normal file
50
store/test/main_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// If DRIVER is set, run tests for that driver only
|
||||
if os.Getenv("DRIVER") != "" {
|
||||
defer TerminateContainers()
|
||||
m.Run() //nolint:revive // Exit code is handled by test runner
|
||||
return
|
||||
}
|
||||
|
||||
// No DRIVER set - run tests for all drivers sequentially
|
||||
runAllDrivers()
|
||||
}
|
||||
|
||||
func runAllDrivers() {
|
||||
drivers := []string{"sqlite", "mysql", "postgres"}
|
||||
_, currentFile, _, _ := runtime.Caller(0)
|
||||
projectRoot := filepath.Dir(filepath.Dir(filepath.Dir(currentFile)))
|
||||
|
||||
var failed []string
|
||||
for _, driver := range drivers {
|
||||
fmt.Printf("\n==================== %s ====================\n\n", driver)
|
||||
|
||||
cmd := exec.Command("go", "test", "-v", "-count=1", "./store/test/...")
|
||||
cmd.Dir = projectRoot
|
||||
cmd.Env = append(os.Environ(), "DRIVER="+driver)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
failed = append(failed, driver)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
if len(failed) > 0 {
|
||||
fmt.Printf("FAIL: %v\n", failed)
|
||||
panic("some drivers failed")
|
||||
}
|
||||
fmt.Println("PASS: all drivers")
|
||||
}
|
||||
939
store/test/memo_filter_test.go
Normal file
939
store/test/memo_filter_test.go
Normal file
@@ -0,0 +1,939 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Content Field Tests
|
||||
// Schema: content (string, supports contains)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterContentContains(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with different content
|
||||
tc.CreateMemo(NewMemoBuilder("memo-hello", tc.User.ID).Content("Hello world"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-goodbye", tc.User.ID).Content("Goodbye world"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-test", tc.User.ID).Content("Testing content"))
|
||||
|
||||
// Test: content.contains("Hello") - single match
|
||||
memos := tc.ListWithFilter(`content.contains("Hello")`)
|
||||
require.Len(t, memos, 1)
|
||||
require.Contains(t, memos[0].Content, "Hello")
|
||||
|
||||
// Test: content.contains("world") - multiple matches
|
||||
memos = tc.ListWithFilter(`content.contains("world")`)
|
||||
require.Len(t, memos, 2)
|
||||
|
||||
// Test: content.contains("nonexistent") - no matches
|
||||
memos = tc.ListWithFilter(`content.contains("nonexistent")`)
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
func TestMemoFilterContentSpecialCharacters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-special", tc.User.ID).Content("Special chars: @#$%^&*()"))
|
||||
|
||||
memos := tc.ListWithFilter(`content.contains("@#$%")`)
|
||||
require.Len(t, memos, 1)
|
||||
}
|
||||
|
||||
func TestMemoFilterContentUnicode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-unicode", tc.User.ID).Content("Unicode test: 你好世界 🌍"))
|
||||
|
||||
memos := tc.ListWithFilter(`content.contains("你好")`)
|
||||
require.Len(t, memos, 1)
|
||||
}
|
||||
|
||||
func TestMemoFilterContentUnicodeCaseFold(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-unicode-case", tc.User.ID).Content("Привет Мир"))
|
||||
|
||||
memos := tc.ListWithFilter(`content.contains("привет")`)
|
||||
require.Len(t, memos, 1)
|
||||
}
|
||||
|
||||
func TestMemoFilterContentCaseSensitivity(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-case", tc.User.ID).Content("MixedCase Content"))
|
||||
|
||||
// Exact match
|
||||
memos := tc.ListWithFilter(`content.contains("MixedCase")`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Lowercase match (depends on DB collation, usually case-insensitive in default installs but good to verify behavior)
|
||||
// SQLite default LIKE is case-insensitive for ASCII.
|
||||
memosLower := tc.ListWithFilter(`content.contains("mixedcase")`)
|
||||
// We just verify it doesn't crash; strict case sensitivity expectation depends on DB config.
|
||||
// For standard Memos setup (SQLite), it's often case-insensitive.
|
||||
// Let's check if we get a result or not to characterize current behavior.
|
||||
if len(memosLower) > 0 {
|
||||
require.Equal(t, "MixedCase Content", memosLower[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Visibility Field Tests
|
||||
// Schema: visibility (string, ==, !=)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterVisibilityEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-public", tc.User.ID).Content("Public memo").Visibility(store.Public))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-private", tc.User.ID).Content("Private memo").Visibility(store.Private))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-protected", tc.User.ID).Content("Protected memo").Visibility(store.Protected))
|
||||
|
||||
// Test: visibility == "PUBLIC"
|
||||
memos := tc.ListWithFilter(`visibility == "PUBLIC"`)
|
||||
require.Len(t, memos, 1)
|
||||
require.Equal(t, store.Public, memos[0].Visibility)
|
||||
|
||||
// Test: visibility == "PRIVATE"
|
||||
memos = tc.ListWithFilter(`visibility == "PRIVATE"`)
|
||||
require.Len(t, memos, 1)
|
||||
require.Equal(t, store.Private, memos[0].Visibility)
|
||||
}
|
||||
|
||||
func TestMemoFilterVisibilityNotEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-public", tc.User.ID).Content("Public memo").Visibility(store.Public))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-private", tc.User.ID).Content("Private memo").Visibility(store.Private))
|
||||
|
||||
memos := tc.ListWithFilter(`visibility != "PUBLIC"`)
|
||||
require.Len(t, memos, 1)
|
||||
require.Equal(t, store.Private, memos[0].Visibility)
|
||||
}
|
||||
|
||||
func TestMemoFilterVisibilityInList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-pub", tc.User.ID).Visibility(store.Public))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-priv", tc.User.ID).Visibility(store.Private))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-prot", tc.User.ID).Visibility(store.Protected))
|
||||
|
||||
memos := tc.ListWithFilter(`visibility in ["PUBLIC", "PRIVATE"]`)
|
||||
require.Len(t, memos, 2)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Pinned Field Tests
|
||||
// Schema: pinned (bool column, ==, !=, predicate)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterPinnedEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned", tc.User.ID).Content("Pinned memo"))
|
||||
tc.PinMemo(pinnedMemo.ID)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-unpinned", tc.User.ID).Content("Unpinned memo"))
|
||||
|
||||
// Test: pinned == true
|
||||
memos := tc.ListWithFilter(`pinned == true`)
|
||||
require.Len(t, memos, 1)
|
||||
require.True(t, memos[0].Pinned)
|
||||
|
||||
// Test: pinned == false
|
||||
memos = tc.ListWithFilter(`pinned == false`)
|
||||
require.Len(t, memos, 1)
|
||||
require.False(t, memos[0].Pinned)
|
||||
}
|
||||
|
||||
func TestMemoFilterPinnedPredicate(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned", tc.User.ID).Content("Pinned memo"))
|
||||
tc.PinMemo(pinnedMemo.ID)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-unpinned", tc.User.ID).Content("Unpinned memo"))
|
||||
|
||||
memos := tc.ListWithFilter(`pinned`)
|
||||
require.Len(t, memos, 1)
|
||||
require.True(t, memos[0].Pinned)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Creator ID Field Tests
|
||||
// Schema: creator_id (int, ==, !=)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterCreatorIdEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
user2, err := tc.Store.CreateUser(tc.Ctx, &store.User{
|
||||
Username: "user2",
|
||||
Role: store.RoleUser,
|
||||
Email: "user2@example.com",
|
||||
Nickname: "User 2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-user1", tc.User.ID).Content("User 1 memo"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-user2", user2.ID).Content("User 2 memo"))
|
||||
|
||||
memos := tc.ListWithFilter(`creator_id == ` + formatInt(int(tc.User.ID)))
|
||||
require.Len(t, memos, 1)
|
||||
require.Equal(t, tc.User.ID, memos[0].CreatorID)
|
||||
}
|
||||
|
||||
func TestMemoFilterCreatorIdNotEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
user2, err := tc.Store.CreateUser(tc.Ctx, &store.User{
|
||||
Username: "user2",
|
||||
Role: store.RoleUser,
|
||||
Email: "user2@example.com",
|
||||
Nickname: "User 2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-user1", tc.User.ID).Content("User 1 memo"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-user2", user2.ID).Content("User 2 memo"))
|
||||
|
||||
memos := tc.ListWithFilter(`creator_id != ` + formatInt(int(tc.User.ID)))
|
||||
require.Len(t, memos, 1)
|
||||
require.Equal(t, user2.ID, memos[0].CreatorID)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tags Field Tests
|
||||
// Schema: tags (JSON list), tag (virtual alias)
|
||||
// Operators: tag in [...], "value" in tags
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterTagInList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-work", tc.User.ID).Content("Work memo").Tags("work", "important"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-personal", tc.User.ID).Content("Personal memo").Tags("personal", "fun"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-no-tags", tc.User.ID).Content("No tags"))
|
||||
|
||||
// Test: tag in ["work"]
|
||||
memos := tc.ListWithFilter(`tag in ["work"]`)
|
||||
require.Len(t, memos, 1)
|
||||
require.Contains(t, memos[0].Payload.Tags, "work")
|
||||
|
||||
// Test: tag in ["work", "personal"]
|
||||
memos = tc.ListWithFilter(`tag in ["work", "personal"]`)
|
||||
require.Len(t, memos, 2)
|
||||
}
|
||||
|
||||
func TestMemoFilterElementInTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-tagged", tc.User.ID).Content("Tagged memo").Tags("project", "todo"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-untagged", tc.User.ID).Content("Untagged memo"))
|
||||
|
||||
// Test: "project" in tags
|
||||
memos := tc.ListWithFilter(`"project" in tags`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: "nonexistent" in tags
|
||||
memos = tc.ListWithFilter(`"nonexistent" in tags`)
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
func TestMemoFilterHierarchicalTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-book", tc.User.ID).Content("Book memo").Tags("book"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-book-fiction", tc.User.ID).Content("Fiction book memo").Tags("book/fiction"))
|
||||
|
||||
// Test: tag in ["book"] should match both (hierarchical matching)
|
||||
memos := tc.ListWithFilter(`tag in ["book"]`)
|
||||
require.Len(t, memos, 2)
|
||||
}
|
||||
|
||||
func TestMemoFilterEmptyTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-empty-tags", tc.User.ID).Content("Empty tags").Tags())
|
||||
|
||||
memos := tc.ListWithFilter(`tag in ["anything"]`)
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// JSON Bool Field Tests
|
||||
// Schema: has_task_list, has_link, has_code, has_incomplete_tasks
|
||||
// Operators: ==, !=, predicate
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterHasTaskList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-with-tasks", tc.User.ID).
|
||||
Content("- [ ] Task 1\n- [x] Task 2").
|
||||
Property(func(p *storepb.MemoPayload_Property) { p.HasTaskList = true }))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-no-tasks", tc.User.ID).Content("No tasks here"))
|
||||
|
||||
// Test: has_task_list (predicate)
|
||||
memos := tc.ListWithFilter(`has_task_list`)
|
||||
require.Len(t, memos, 1)
|
||||
require.True(t, memos[0].Payload.Property.HasTaskList)
|
||||
|
||||
// Test: has_task_list == true
|
||||
memos = tc.ListWithFilter(`has_task_list == true`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Note: has_task_list == false is not tested because JSON boolean fields
|
||||
// with false value may not be queryable when the field is not present in JSON
|
||||
}
|
||||
|
||||
func TestMemoFilterHasLink(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-with-link", tc.User.ID).
|
||||
Content("Check out https://example.com").
|
||||
Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true }))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-no-link", tc.User.ID).Content("No links"))
|
||||
|
||||
memos := tc.ListWithFilter(`has_link`)
|
||||
require.Len(t, memos, 1)
|
||||
require.True(t, memos[0].Payload.Property.HasLink)
|
||||
}
|
||||
|
||||
func TestMemoFilterHasCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-with-code", tc.User.ID).
|
||||
Content("```go\nfmt.Println(\"Hello\")\n```").
|
||||
Property(func(p *storepb.MemoPayload_Property) { p.HasCode = true }))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-no-code", tc.User.ID).Content("No code"))
|
||||
|
||||
memos := tc.ListWithFilter(`has_code`)
|
||||
require.Len(t, memos, 1)
|
||||
require.True(t, memos[0].Payload.Property.HasCode)
|
||||
}
|
||||
|
||||
func TestMemoFilterHasIncompleteTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-incomplete", tc.User.ID).
|
||||
Content("- [ ] Incomplete task").
|
||||
Property(func(p *storepb.MemoPayload_Property) {
|
||||
p.HasTaskList = true
|
||||
p.HasIncompleteTasks = true
|
||||
}))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-complete", tc.User.ID).
|
||||
Content("- [x] Complete task").
|
||||
Property(func(p *storepb.MemoPayload_Property) {
|
||||
p.HasTaskList = true
|
||||
p.HasIncompleteTasks = false
|
||||
}))
|
||||
|
||||
memos := tc.ListWithFilter(`has_incomplete_tasks`)
|
||||
require.Len(t, memos, 1)
|
||||
require.True(t, memos[0].Payload.Property.HasIncompleteTasks)
|
||||
}
|
||||
|
||||
func TestMemoFilterCombinedJSONBool(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Memo with all properties
|
||||
tc.CreateMemo(NewMemoBuilder("memo-all-props", tc.User.ID).
|
||||
Content("All properties").
|
||||
Property(func(p *storepb.MemoPayload_Property) {
|
||||
p.HasLink = true
|
||||
p.HasTaskList = true
|
||||
p.HasCode = true
|
||||
p.HasIncompleteTasks = true
|
||||
}))
|
||||
|
||||
// Memo with only link
|
||||
tc.CreateMemo(NewMemoBuilder("memo-only-link", tc.User.ID).
|
||||
Content("Only link").
|
||||
Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true }))
|
||||
|
||||
// Test: has_link && has_code
|
||||
memos := tc.ListWithFilter(`has_link && has_code`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: has_task_list && has_incomplete_tasks
|
||||
memos = tc.ListWithFilter(`has_task_list && has_incomplete_tasks`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: has_link || has_code
|
||||
memos = tc.ListWithFilter(`has_link || has_code`)
|
||||
require.Len(t, memos, 2)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Timestamp Field Tests
|
||||
// Schema: created_ts, updated_ts (timestamp, all comparison operators)
|
||||
// Functions: now(), arithmetic (+, -, *)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterCreatedTsComparison(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
now := time.Now().Unix()
|
||||
tc.CreateMemo(NewMemoBuilder("memo-ts", tc.User.ID).Content("Timestamp test"))
|
||||
|
||||
// Test: created_ts < future (should match)
|
||||
memos := tc.ListWithFilter(`created_ts < ` + formatInt64(now+3600))
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: created_ts > past (should match)
|
||||
memos = tc.ListWithFilter(`created_ts > ` + formatInt64(now-3600))
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: created_ts > future (should not match)
|
||||
memos = tc.ListWithFilter(`created_ts > ` + formatInt64(now+3600))
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
func TestMemoFilterCreatedTsWithNow(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-ts-test", tc.User.ID).Content("Timestamp test"))
|
||||
|
||||
// Test: created_ts < now() + 5 (buffer for container clock drift)
|
||||
memos := tc.ListWithFilter(`created_ts < now() + 5`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: created_ts > now() + 5 (should not match)
|
||||
memos = tc.ListWithFilter(`created_ts > now() + 5`)
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
func TestMemoFilterCreatedTsArithmetic(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-ts-arith", tc.User.ID).Content("Timestamp arithmetic test"))
|
||||
|
||||
// Test: created_ts >= now() - 3600 (memos created in last hour)
|
||||
memos := tc.ListWithFilter(`created_ts >= now() - 3600`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: created_ts < now() - 86400 (memos older than 1 day - should be empty)
|
||||
memos = tc.ListWithFilter(`created_ts < now() - 86400`)
|
||||
require.Len(t, memos, 0)
|
||||
|
||||
// Test: Multiplication - created_ts >= now() - 60 * 60
|
||||
memos = tc.ListWithFilter(`created_ts >= now() - 60 * 60`)
|
||||
require.Len(t, memos, 1)
|
||||
}
|
||||
|
||||
func TestMemoFilterUpdatedTs(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
memo := tc.CreateMemo(NewMemoBuilder("memo-updated", tc.User.ID).Content("Will be updated"))
|
||||
|
||||
// Update the memo
|
||||
newContent := "Updated content"
|
||||
err := tc.Store.UpdateMemo(tc.Ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Content: &newContent,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test: updated_ts >= now() - 60 (updated in last minute)
|
||||
memos := tc.ListWithFilter(`updated_ts >= now() - 60`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: updated_ts > now() + 3600 (should be empty)
|
||||
memos = tc.ListWithFilter(`updated_ts > now() + 3600`)
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
func TestMemoFilterAllComparisonOperators(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-ops", tc.User.ID).Content("Comparison operators test"))
|
||||
|
||||
// Test: < (less than)
|
||||
memos := tc.ListWithFilter(`created_ts < now() + 3600`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: <= (less than or equal) with buffer for clock drift
|
||||
memos = tc.ListWithFilter(`created_ts < now() + 5`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: > (greater than)
|
||||
memos = tc.ListWithFilter(`created_ts > now() - 3600`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: >= (greater than or equal)
|
||||
memos = tc.ListWithFilter(`created_ts >= now() - 60`)
|
||||
require.Len(t, memos, 1)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Logical Operator Tests
|
||||
// Operators: && (AND), || (OR), ! (NOT)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterLogicalAnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned-public", tc.User.ID).Content("Pinned public"))
|
||||
tc.PinMemo(pinnedMemo.ID)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-unpinned-public", tc.User.ID).Content("Unpinned public"))
|
||||
|
||||
memos := tc.ListWithFilter(`pinned && visibility == "PUBLIC"`)
|
||||
require.Len(t, memos, 1)
|
||||
require.True(t, memos[0].Pinned)
|
||||
}
|
||||
|
||||
func TestMemoFilterLogicalOr(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-public", tc.User.ID).Visibility(store.Public))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-private", tc.User.ID).Visibility(store.Private))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-protected", tc.User.ID).Visibility(store.Protected))
|
||||
|
||||
memos := tc.ListWithFilter(`visibility == "PUBLIC" || visibility == "PRIVATE"`)
|
||||
require.Len(t, memos, 2)
|
||||
}
|
||||
|
||||
func TestMemoFilterLogicalNot(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned", tc.User.ID).Content("Pinned"))
|
||||
tc.PinMemo(pinnedMemo.ID)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-unpinned", tc.User.ID).Content("Unpinned"))
|
||||
|
||||
memos := tc.ListWithFilter(`!pinned`)
|
||||
require.Len(t, memos, 1)
|
||||
require.False(t, memos[0].Pinned)
|
||||
}
|
||||
|
||||
func TestMemoFilterNegatedComparison(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-public", tc.User.ID).Visibility(store.Public))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-private", tc.User.ID).Visibility(store.Private))
|
||||
|
||||
memos := tc.ListWithFilter(`!(visibility == "PUBLIC")`)
|
||||
require.Len(t, memos, 1)
|
||||
require.Equal(t, store.Private, memos[0].Visibility)
|
||||
}
|
||||
|
||||
func TestMemoFilterComplexLogical(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create pinned public memo with tags
|
||||
pinnedMemo := tc.CreateMemo(NewMemoBuilder("memo-pinned-tagged", tc.User.ID).
|
||||
Content("Pinned and tagged").Tags("important"))
|
||||
tc.PinMemo(pinnedMemo.ID)
|
||||
|
||||
// Create unpinned memo with same tag
|
||||
tc.CreateMemo(NewMemoBuilder("memo-unpinned-tagged", tc.User.ID).
|
||||
Content("Unpinned but tagged").Tags("important"))
|
||||
|
||||
// Create pinned memo without tag
|
||||
pinned2 := tc.CreateMemo(NewMemoBuilder("memo-pinned-untagged", tc.User.ID).Content("Pinned but untagged"))
|
||||
tc.PinMemo(pinned2.ID)
|
||||
|
||||
// Test: pinned && tag in ["important"]
|
||||
memos := tc.ListWithFilter(`pinned && tag in ["important"]`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Test: (pinned || tag in ["important"]) && visibility == "PUBLIC"
|
||||
memos = tc.ListWithFilter(`(pinned || tag in ["important"]) && visibility == "PUBLIC"`)
|
||||
require.Len(t, memos, 3)
|
||||
|
||||
// Test: De Morgan's Law ! (A || B) == !A && !B
|
||||
// ! (pinned || has_task_list)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-no-props", tc.User.ID).Content("Nothing special"))
|
||||
memos = tc.ListWithFilter(`!(pinned || has_task_list)`)
|
||||
require.Len(t, memos, 2) // Unpinned-tagged + Nothing special (pinned-untagged is pinned)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tag Comprehension Tests (exists macro)
|
||||
// Schema: tags (list of strings, supports exists/all macros with predicates)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterTagsExistsStartsWith(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with different tags
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archive1", tc.User.ID).
|
||||
Content("Archived project memo").
|
||||
Tags("archive/project", "done"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archive2", tc.User.ID).
|
||||
Content("Archived work memo").
|
||||
Tags("archive/work", "old"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-active", tc.User.ID).
|
||||
Content("Active project memo").
|
||||
Tags("project/active", "todo"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID).
|
||||
Content("Homelab memo").
|
||||
Tags("homelab/memos", "tech"))
|
||||
|
||||
// Test: tags.exists(t, t.startsWith("archive")) - should match archived memos
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`)
|
||||
require.Len(t, memos, 2, "Should find 2 archived memos")
|
||||
for _, memo := range memos {
|
||||
hasArchiveTag := false
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
if len(tag) >= 7 && tag[:7] == "archive" {
|
||||
hasArchiveTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasArchiveTag, "Memo should have tag starting with 'archive'")
|
||||
}
|
||||
|
||||
// Test: !tags.exists(t, t.startsWith("archive")) - should match non-archived memos
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`)
|
||||
require.Len(t, memos, 2, "Should find 2 non-archived memos")
|
||||
|
||||
// Test: tags.exists(t, t.startsWith("project")) - should match project memos
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("project"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 project memo")
|
||||
|
||||
// Test: tags.exists(t, t.startsWith("homelab")) - should match homelab memos
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("homelab"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 homelab memo")
|
||||
|
||||
// Test: tags.exists(t, t.startsWith("nonexistent")) - should match nothing
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("nonexistent"))`)
|
||||
require.Len(t, memos, 0, "Should find no memos")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsContains(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with different tags
|
||||
tc.CreateMemo(NewMemoBuilder("memo-todo1", tc.User.ID).
|
||||
Content("Todo task 1").
|
||||
Tags("project/todo", "urgent"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-todo2", tc.User.ID).
|
||||
Content("Todo task 2").
|
||||
Tags("work/todo-list", "pending"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-done", tc.User.ID).
|
||||
Content("Done task").
|
||||
Tags("project/completed", "done"))
|
||||
|
||||
// Test: tags.exists(t, t.contains("todo")) - should match todos
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.contains("todo"))`)
|
||||
require.Len(t, memos, 2, "Should find 2 todo memos")
|
||||
|
||||
// Test: tags.exists(t, t.contains("done")) - should match done
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.contains("done"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 done memo")
|
||||
|
||||
// Test: !tags.exists(t, t.contains("todo")) - should exclude todos
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t.contains("todo"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 non-todo memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEndsWith(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with different tag endings
|
||||
tc.CreateMemo(NewMemoBuilder("memo-bug", tc.User.ID).
|
||||
Content("Bug report").
|
||||
Tags("project/bug", "critical"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-debug", tc.User.ID).
|
||||
Content("Debug session").
|
||||
Tags("work/debug", "dev"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-feature", tc.User.ID).
|
||||
Content("New feature").
|
||||
Tags("project/feature", "new"))
|
||||
|
||||
// Test: tags.exists(t, t.endsWith("bug")) - should match bug-related tags
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.endsWith("bug"))`)
|
||||
require.Len(t, memos, 2, "Should find 2 bug-related memos")
|
||||
|
||||
// Test: tags.exists(t, t.endsWith("feature")) - should match feature
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.endsWith("feature"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 feature memo")
|
||||
|
||||
// Test: !tags.exists(t, t.endsWith("bug")) - should exclude bug-related
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t.endsWith("bug"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 non-bug memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsCombinedWithOtherFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with tags and other properties
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archived-old", tc.User.ID).
|
||||
Content("Old archived memo").
|
||||
Tags("archive/old", "done"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archived-recent", tc.User.ID).
|
||||
Content("Recent archived memo with TODO").
|
||||
Tags("archive/recent", "done"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-active-todo", tc.User.ID).
|
||||
Content("Active TODO").
|
||||
Tags("project/active", "todo"))
|
||||
|
||||
// Test: Combine tag filter with content filter
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && content.contains("TODO")`)
|
||||
require.Len(t, memos, 1, "Should find 1 archived memo with TODO in content")
|
||||
|
||||
// Test: OR condition with tag filters
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) || tags.exists(t, t.contains("todo"))`)
|
||||
require.Len(t, memos, 3, "Should find all memos (archived or with todo tag)")
|
||||
|
||||
// Test: Complex filter - archived but not containing "Recent"
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && !content.contains("Recent")`)
|
||||
require.Len(t, memos, 1, "Should find 1 old archived memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEmptyAndNullCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memo with no tags
|
||||
tc.CreateMemo(NewMemoBuilder("memo-no-tags", tc.User.ID).
|
||||
Content("Memo without tags"))
|
||||
|
||||
// Create memo with tags
|
||||
tc.CreateMemo(NewMemoBuilder("memo-with-tags", tc.User.ID).
|
||||
Content("Memo with tags").
|
||||
Tags("tag1", "tag2"))
|
||||
|
||||
// Test: tags.exists should not match memos without tags
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("tag"))`)
|
||||
require.Len(t, memos, 1, "Should only find memo with tags")
|
||||
|
||||
// Test: Negation should match memos without matching tags
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("tag"))`)
|
||||
require.Len(t, memos, 1, "Should find memo without matching tags")
|
||||
}
|
||||
|
||||
func TestMemoFilterIssue5480_ArchiveWorkflow(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create a realistic scenario as described in issue #5480
|
||||
// User has hierarchical tags and archives memos by prefixing with "archive"
|
||||
|
||||
// Active memos
|
||||
tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID).
|
||||
Content("Setting up Memos").
|
||||
Tags("homelab/memos", "tech"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-project-alpha", tc.User.ID).
|
||||
Content("Project Alpha notes").
|
||||
Tags("work/project-alpha", "active"))
|
||||
|
||||
// Archived memos (user prefixed tags with "archive")
|
||||
tc.CreateMemo(NewMemoBuilder("memo-old-homelab", tc.User.ID).
|
||||
Content("Old homelab setup").
|
||||
Tags("archive/homelab/old-server", "done"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-old-project", tc.User.ID).
|
||||
Content("Old project beta").
|
||||
Tags("archive/work/project-beta", "completed"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archived-personal", tc.User.ID).
|
||||
Content("Archived personal note").
|
||||
Tags("archive/personal/2024", "old"))
|
||||
|
||||
// Test: Filter out ALL archived memos using startsWith
|
||||
memos := tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`)
|
||||
require.Len(t, memos, 2, "Should only show active memos (not archived)")
|
||||
for _, memo := range memos {
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
require.NotContains(t, tag, "archive", "Active memos should not have archive prefix")
|
||||
}
|
||||
}
|
||||
|
||||
// Test: Show ONLY archived memos
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`)
|
||||
require.Len(t, memos, 3, "Should find all archived memos")
|
||||
for _, memo := range memos {
|
||||
hasArchiveTag := false
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
if len(tag) >= 7 && tag[:7] == "archive" {
|
||||
hasArchiveTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasArchiveTag, "All returned memos should have archive prefix")
|
||||
}
|
||||
|
||||
// Test: Filter archived homelab memos specifically
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive/homelab"))`)
|
||||
require.Len(t, memos, 1, "Should find only archived homelab memos")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Multiple Filters Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterMultipleFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-public-hello", tc.User.ID).Content("Hello world").Visibility(store.Public))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-private-hello", tc.User.ID).Content("Hello private").Visibility(store.Private))
|
||||
|
||||
// Test: Multiple filters (applied as AND)
|
||||
memos := tc.ListWithFilters(`content.contains("Hello")`, `visibility == "PUBLIC"`)
|
||||
require.Len(t, memos, 1)
|
||||
require.Contains(t, memos[0].Content, "Hello")
|
||||
require.Equal(t, store.Public, memos[0].Visibility)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Edge Cases
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterNullPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-null-payload", tc.User.ID).Content("Null payload"))
|
||||
|
||||
// Test: has_link should not crash and return no results
|
||||
memos := tc.ListWithFilter(`has_link`)
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
func TestMemoFilterNoMatches(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-test", tc.User.ID).Content("Test content"))
|
||||
|
||||
memos := tc.ListWithFilter(`content.contains("nonexistent12345")`)
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
func TestMemoFilterJSONBooleanLogic(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// 1. Memo with task list (true) and NO link (null)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-task-only", tc.User.ID).
|
||||
Content("Task only").
|
||||
Property(func(p *storepb.MemoPayload_Property) { p.HasTaskList = true }))
|
||||
|
||||
// 2. Memo with link (true) and NO task list (null)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-link-only", tc.User.ID).
|
||||
Content("Link only").
|
||||
Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true }))
|
||||
|
||||
// 3. Memo with both (true)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-both", tc.User.ID).
|
||||
Content("Both").
|
||||
Property(func(p *storepb.MemoPayload_Property) {
|
||||
p.HasTaskList = true
|
||||
p.HasLink = true
|
||||
}))
|
||||
|
||||
// 4. Memo with neither (null)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-neither", tc.User.ID).Content("Neither"))
|
||||
|
||||
// Test A: has_task_list || has_link
|
||||
// Expected: 3 memos (task-only, link-only, both). Neither should be excluded.
|
||||
// This specifically tests the NULL handling in OR logic (NULL || TRUE should be TRUE)
|
||||
memos := tc.ListWithFilter(`has_task_list || has_link`)
|
||||
require.Len(t, memos, 3, "Should find 3 memos with OR logic")
|
||||
|
||||
// Test B: !has_task_list
|
||||
// Expected: 2 memos (link-only, neither). Memos where has_task_list is NULL or FALSE.
|
||||
// Note: If NULL is not handled, !NULL is still NULL (false-y in WHERE), so "neither" might be missed depending on logic.
|
||||
// In our implementation, we want missing fields to behave as false.
|
||||
memos = tc.ListWithFilter(`!has_task_list`)
|
||||
require.Len(t, memos, 2, "Should find 2 memos where task list is false or missing")
|
||||
|
||||
// Test C: has_task_list && !has_link
|
||||
// Expected: 1 memo (task-only).
|
||||
memos = tc.ListWithFilter(`has_task_list && !has_link`)
|
||||
require.Len(t, memos, 1, "Should find 1 memo (task only)")
|
||||
}
|
||||
682
store/test/memo_relation_test.go
Normal file
682
store/test/memo_relation_test.go
Normal file
@@ -0,0 +1,682 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestMemoRelationStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
memoCreate := &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
}
|
||||
memo, err := ts.CreateMemo(ctx, memoCreate)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memoCreate.Content, memo.Content)
|
||||
relatedMemoCreate := &store.Memo{
|
||||
UID: "related-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
}
|
||||
relatedMemo, err := ts.CreateMemo(ctx, relatedMemoCreate)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, relatedMemoCreate.Content, relatedMemo.Content)
|
||||
commentMemoCreate := &store.Memo{
|
||||
UID: "comment-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "comment memo content",
|
||||
Visibility: store.Public,
|
||||
}
|
||||
commentMemo, err := ts.CreateMemo(ctx, commentMemoCreate)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, commentMemoCreate.Content, commentMemo.Content)
|
||||
|
||||
// Reference relation.
|
||||
referenceRelation := &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
}
|
||||
_, err = ts.UpsertMemoRelation(ctx, referenceRelation)
|
||||
require.NoError(t, err)
|
||||
// Comment relation.
|
||||
commentRelation := &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: commentMemo.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
}
|
||||
_, err = ts.UpsertMemoRelation(ctx, commentRelation)
|
||||
require.NoError(t, err)
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListByMemoID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create main memo
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create related memos
|
||||
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relations
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo1.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo2.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by memo ID
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(relations))
|
||||
|
||||
// List by type
|
||||
refType := store.MemoRelationReference
|
||||
refRelations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
Type: &refType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(refRelations))
|
||||
require.Equal(t, store.MemoRelationReference, refRelations[0].Type)
|
||||
|
||||
// List by related memo ID
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &relatedMemo1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(relations))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create memos
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relation
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify relation exists
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(relations))
|
||||
|
||||
// Delete relation by memo ID
|
||||
relType := store.MemoRelationReference
|
||||
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
RelatedMemoID: &relatedMemo.ID,
|
||||
Type: &relType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify relation is deleted
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(relations))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationDifferentTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create reference relation
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create comment relation (same memos, different type - should be allowed)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify both relations exist
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(relations))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationUpsertSameRelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relation
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Upsert the same relation again (should not create duplicate)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one relation exists
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationDeleteByType(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create reference relations
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo1.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create comment relation
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo2.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete only reference type relations
|
||||
refType := store.MemoRelationReference
|
||||
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
Type: &refType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only comment relation remains
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, store.MemoRelationComment, relations[0].Type)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationDeleteByMemoID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memo1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memo2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relations for both memos
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo1.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo2.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete all relations for memo1
|
||||
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &memo1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify memo1's relations are gone
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memo1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 0)
|
||||
|
||||
// Verify memo2's relations still exist
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memo2.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListByRelatedMemoID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a memo that will be referenced by others
|
||||
targetMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "target-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "target memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create memos that reference the target
|
||||
referrer1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "referrer-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "referrer 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
referrer2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "referrer-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "referrer 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relations pointing to target
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: referrer1.ID,
|
||||
RelatedMemoID: targetMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: referrer2.ID,
|
||||
RelatedMemoID: targetMemo.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by related memo ID (find all memos that reference the target)
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &targetMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 2)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListCombinedFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple relations
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo1.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo2.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List with MemoID and Type filter
|
||||
refType := store.MemoRelationReference
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
Type: &refType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, relatedMemo1.ID, relations[0].RelatedMemoID)
|
||||
|
||||
// List with MemoID, RelatedMemoID, and Type filter
|
||||
commentType := store.MemoRelationComment
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
RelatedMemoID: &relatedMemo2.ID,
|
||||
Type: &commentType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-no-relations",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo with no relations",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List relations for memo with none
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 0)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationBidirectional(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memoA, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-a",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo A content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoB, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-b",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo B content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relation A -> B
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoA.ID,
|
||||
RelatedMemoID: memoB.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relation B -> A (reverse direction)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoB.ID,
|
||||
RelatedMemoID: memoA.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify A -> B exists
|
||||
relationsFromA, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memoA.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relationsFromA, 1)
|
||||
require.Equal(t, memoB.ID, relationsFromA[0].RelatedMemoID)
|
||||
|
||||
// Verify B -> A exists
|
||||
relationsFromB, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memoB.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relationsFromB, 1)
|
||||
require.Equal(t, memoA.ID, relationsFromB[0].RelatedMemoID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationMultipleRelationsToSameMemo(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple memos that all relate to the main memo
|
||||
for i := 1; i <= 5; i++ {
|
||||
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-" + string(rune('0'+i)),
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify all 5 relations exist
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 5)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
478
store/test/memo_test.go
Normal file
478
store/test/memo_test.go
Normal file
@@ -0,0 +1,478 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
func TestMemoStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
memoCreate := &store.Memo{
|
||||
UID: "test-resource-name",
|
||||
CreatorID: user.ID,
|
||||
Content: "test_content",
|
||||
Visibility: store.Public,
|
||||
}
|
||||
memo, err := ts.CreateMemo(ctx, memoCreate)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memoCreate.Content, memo.Content)
|
||||
memoPatchContent := "test_content_2"
|
||||
memoPatch := &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Content: &memoPatchContent,
|
||||
}
|
||||
err = ts.UpdateMemo(ctx, memoPatch)
|
||||
require.NoError(t, err)
|
||||
memo, err = ts.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &memo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
memoList, err := ts.ListMemos(ctx, &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(memoList))
|
||||
require.Equal(t, memo, memoList[0])
|
||||
err = ts.DeleteMemo(ctx, &store.DeleteMemo{
|
||||
ID: memo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
memoList, err = ts.ListMemos(ctx, &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(memoList))
|
||||
|
||||
memoList, err = ts.ListMemos(ctx, &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
VisibilityList: []store.Visibility{store.Public},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(memoList))
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoListByTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
memoCreate := &store.Memo{
|
||||
UID: "test-resource-name",
|
||||
CreatorID: user.ID,
|
||||
Content: "test_content",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"test_tag"},
|
||||
},
|
||||
}
|
||||
memo, err := ts.CreateMemo(ctx, memoCreate)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memoCreate.Content, memo.Content)
|
||||
memo, err = ts.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &memo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
memoList, err := ts.ListMemos(ctx, &store.FindMemo{
|
||||
Filters: []string{"tag in [\"test_tag\"]"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(memoList))
|
||||
require.Equal(t, memo, memoList[0])
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestDeleteMemoStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
memoCreate := &store.Memo{
|
||||
UID: "test-resource-name",
|
||||
CreatorID: user.ID,
|
||||
Content: "test_content",
|
||||
Visibility: store.Public,
|
||||
}
|
||||
memo, err := ts.CreateMemo(ctx, memoCreate)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memoCreate.Content, memo.Content)
|
||||
err = ts.DeleteMemo(ctx, &store.DeleteMemo{
|
||||
ID: memo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "test content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by ID
|
||||
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, memo.ID, found.ID)
|
||||
require.Equal(t, memo.Content, found.Content)
|
||||
|
||||
// Get non-existent
|
||||
nonExistentID := int32(99999)
|
||||
notFound, err := ts.GetMemo(ctx, &store.FindMemo{ID: &nonExistentID})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoGetByUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
uid := "unique-memo-uid"
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: uid,
|
||||
CreatorID: user.ID,
|
||||
Content: "test content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by UID
|
||||
found, err := ts.GetMemo(ctx, &store.FindMemo{UID: &uid})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, memo.UID, found.UID)
|
||||
|
||||
// Get non-existent UID
|
||||
nonExistentUID := "non-existent-uid"
|
||||
notFound, err := ts.GetMemo(ctx, &store.FindMemo{UID: &nonExistentUID})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoListByVisibility(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create memos with different visibilities
|
||||
_, err = ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "public-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "public content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "protected-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "protected content",
|
||||
Visibility: store.Protected,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "private-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "private content",
|
||||
Visibility: store.Private,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List public memos only
|
||||
publicMemos, err := ts.ListMemos(ctx, &store.FindMemo{
|
||||
VisibilityList: []store.Visibility{store.Public},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(publicMemos))
|
||||
require.Equal(t, store.Public, publicMemos[0].Visibility)
|
||||
|
||||
// List protected memos only
|
||||
protectedMemos, err := ts.ListMemos(ctx, &store.FindMemo{
|
||||
VisibilityList: []store.Visibility{store.Protected},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(protectedMemos))
|
||||
require.Equal(t, store.Protected, protectedMemos[0].Visibility)
|
||||
|
||||
// List public and protected (multiple visibility)
|
||||
publicAndProtected, err := ts.ListMemos(ctx, &store.FindMemo{
|
||||
VisibilityList: []store.Visibility{store.Public, store.Protected},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(publicAndProtected))
|
||||
|
||||
// List all
|
||||
allMemos, err := ts.ListMemos(ctx, &store.FindMemo{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(allMemos))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoListWithPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create 10 memos
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: fmt.Sprintf("memo-%d", i),
|
||||
CreatorID: user.ID,
|
||||
Content: fmt.Sprintf("content %d", i),
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test limit
|
||||
limit := 5
|
||||
limitedMemos, err := ts.ListMemos(ctx, &store.FindMemo{Limit: &limit})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, len(limitedMemos))
|
||||
|
||||
// Test offset
|
||||
offset := 3
|
||||
offsetMemos, err := ts.ListMemos(ctx, &store.FindMemo{Limit: &limit, Offset: &offset})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, len(offsetMemos))
|
||||
|
||||
// Verify offset works correctly (different memos)
|
||||
require.NotEqual(t, limitedMemos[0].ID, offsetMemos[0].ID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoUpdatePinned(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "pinnable-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, memo.Pinned)
|
||||
|
||||
// Pin the memo
|
||||
pinned := true
|
||||
err = ts.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Pinned: &pinned,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify pinned
|
||||
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
require.NoError(t, err)
|
||||
require.True(t, found.Pinned)
|
||||
|
||||
// Unpin
|
||||
unpinned := false
|
||||
err = ts.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Pinned: &unpinned,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err = ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
require.NoError(t, err)
|
||||
require.False(t, found.Pinned)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoUpdateVisibility(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "visibility-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.Public, memo.Visibility)
|
||||
|
||||
// Change to private
|
||||
privateVisibility := store.Private
|
||||
err = ts.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Visibility: &privateVisibility,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.Private, found.Visibility)
|
||||
|
||||
// Change to protected
|
||||
protectedVisibility := store.Protected
|
||||
err = ts.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Visibility: &protectedVisibility,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err = ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.Protected, found.Visibility)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoInvalidUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create memo with invalid UID (contains special characters)
|
||||
_, err = ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "invalid uid with spaces",
|
||||
CreatorID: user.ID,
|
||||
Content: "content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid uid")
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoCreateWithCustomTimestamps(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
customCreatedTs := int64(1700000000) // 2023-11-14 22:13:20 UTC
|
||||
customUpdatedTs := int64(1700000001)
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "custom-timestamp-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "content with custom timestamps",
|
||||
Visibility: store.Public,
|
||||
CreatedTs: customCreatedTs,
|
||||
UpdatedTs: customUpdatedTs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, customCreatedTs, memo.CreatedTs)
|
||||
require.Equal(t, customUpdatedTs, memo.UpdatedTs)
|
||||
|
||||
// Fetch and verify timestamps are preserved
|
||||
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, customCreatedTs, found.CreatedTs)
|
||||
require.Equal(t, customUpdatedTs, found.UpdatedTs)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoCreateWithOnlyCreatedTs(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
customCreatedTs := int64(1609459200) // 2021-01-01 00:00:00 UTC
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "custom-created-ts-only",
|
||||
CreatorID: user.ID,
|
||||
Content: "content with custom created_ts only",
|
||||
Visibility: store.Public,
|
||||
CreatedTs: customCreatedTs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, customCreatedTs, memo.CreatedTs)
|
||||
|
||||
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, customCreatedTs, found.CreatedTs)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoWithPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create memo with tags in payload
|
||||
tags := []string{"tag1", "tag2", "tag3"}
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-with-payload",
|
||||
CreatorID: user.ID,
|
||||
Content: "content with tags",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: tags,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo.Payload)
|
||||
require.Equal(t, tags, memo.Payload.Tags)
|
||||
|
||||
// Fetch and verify
|
||||
found, err := ts.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found.Payload)
|
||||
require.Equal(t, tags, found.Payload.Tags)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
200
store/test/migrator_test.go
Normal file
200
store/test/migrator_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// TestFreshInstall verifies that LATEST.sql applies correctly on a fresh database.
|
||||
// This is essentially what NewTestingStore already does, but we make it explicit.
|
||||
func TestFreshInstall(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
// NewTestingStore creates a fresh database and runs Migrate()
|
||||
// which applies LATEST.sql for uninitialized databases
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Verify migration completed successfully
|
||||
currentSchemaVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, currentSchemaVersion, "schema version should be set after fresh install")
|
||||
|
||||
// Verify we can read instance settings (basic sanity check)
|
||||
instanceSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, currentSchemaVersion, instanceSetting.SchemaVersion)
|
||||
}
|
||||
|
||||
// TestMigrationReRun verifies that re-running the migration on an already
|
||||
// migrated database does not fail or cause issues. This simulates a
|
||||
// scenario where the server is restarted.
|
||||
func TestMigrationReRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
// Use the shared testing store which already runs migrations on init
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get current version
|
||||
initialVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually trigger migration again
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "re-running migration should not fail")
|
||||
|
||||
// Verify version hasn't changed (or at least is valid)
|
||||
finalVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialVersion, finalVersion, "version should match after re-run")
|
||||
}
|
||||
|
||||
// TestMigrationWithData verifies that migration preserves data integrity.
|
||||
// Creates data, then re-runs migration and verifies data is still accessible.
|
||||
func TestMigrationWithData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create a user and memo before re-running migration
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err, "should create user")
|
||||
|
||||
originalMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "migration-data-test",
|
||||
CreatorID: user.ID,
|
||||
Content: "Data before migration re-run",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err, "should create memo")
|
||||
|
||||
// Re-run migration
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "re-running migration should not fail")
|
||||
|
||||
// Verify data is still accessible
|
||||
memo, err := ts.GetMemo(ctx, &store.FindMemo{UID: &originalMemo.UID})
|
||||
require.NoError(t, err, "should retrieve memo after migration")
|
||||
require.Equal(t, "Data before migration re-run", memo.Content, "memo content should be preserved")
|
||||
}
|
||||
|
||||
// TestMigrationMultipleReRuns verifies that migration is idempotent
|
||||
// even when run multiple times in succession.
|
||||
func TestMigrationMultipleReRuns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get initial version
|
||||
initialVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run migration multiple times
|
||||
for i := 0; i < 3; i++ {
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "migration run %d should not fail", i+1)
|
||||
}
|
||||
|
||||
// Verify version is still correct
|
||||
finalVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialVersion, finalVersion, "version should remain unchanged after multiple re-runs")
|
||||
}
|
||||
|
||||
// TestMigrationFromStableVersion verifies that upgrading from a stable Memos version
|
||||
// to the current version works correctly. This is the critical upgrade path test.
|
||||
//
|
||||
// Test flow:
|
||||
// 1. Start a stable Memos container to create a database with the old schema
|
||||
// 2. Stop the container and wait for cleanup
|
||||
// 3. Use the store directly to run migration with current code
|
||||
// 4. Verify the migration succeeded and data can be written
|
||||
//
|
||||
// Note: This test is skipped when running with -race flag because testcontainers
|
||||
// has known race conditions in its reaper code that are outside our control.
|
||||
func TestMigrationFromStableVersion(t *testing.T) {
|
||||
// Skip for non-SQLite drivers (simplifies the test)
|
||||
if getDriverFromEnv() != "sqlite" {
|
||||
t.Skip("skipping upgrade test for non-sqlite driver")
|
||||
}
|
||||
|
||||
// Skip if explicitly disabled (e.g., in environments without Docker)
|
||||
if os.Getenv("SKIP_CONTAINER_TESTS") == "1" {
|
||||
t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
// 1. Start stable Memos container to create database with old schema
|
||||
cfg := MemosContainerConfig{
|
||||
Driver: "sqlite",
|
||||
DataDir: dataDir,
|
||||
Version: StableMemosVersion,
|
||||
}
|
||||
|
||||
t.Logf("Starting Memos %s container to create old-schema database...", cfg.Version)
|
||||
container, err := StartMemosContainer(ctx, cfg)
|
||||
require.NoError(t, err, "failed to start stable memos container")
|
||||
|
||||
// Wait for the container to fully initialize the database
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
// Stop the container gracefully
|
||||
t.Log("Stopping stable Memos container...")
|
||||
err = container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to stop memos container")
|
||||
|
||||
// Wait for file handles to be released
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 2. Connect to the database directly and run migration with current code
|
||||
dsn := fmt.Sprintf("%s/memos_prod.db", dataDir)
|
||||
t.Logf("Connecting to database at %s...", dsn)
|
||||
|
||||
ts := NewTestingStoreWithDSN(ctx, t, "sqlite", dsn)
|
||||
|
||||
// Get the schema version before migration
|
||||
oldSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Old schema version: %s", oldSetting.SchemaVersion)
|
||||
|
||||
// 3. Run migration with current code
|
||||
t.Log("Running migration with current code...")
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "migration from stable version should succeed")
|
||||
|
||||
// 4. Verify migration succeeded
|
||||
newVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
t.Logf("New schema version: %s", newVersion)
|
||||
|
||||
newSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newVersion, newSetting.SchemaVersion, "schema version should be updated")
|
||||
|
||||
// Verify we can write data to the migrated database
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err, "should create user after migration")
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "post-upgrade-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "Content after upgrade from stable",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err, "should create memo after migration")
|
||||
require.Equal(t, "Content after upgrade from stable", memo.Content)
|
||||
|
||||
t.Logf("Migration successful: %s -> %s", oldSetting.SchemaVersion, newVersion)
|
||||
}
|
||||
189
store/test/reaction_test.go
Normal file
189
store/test/reaction_test.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestReactionStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentID := "test_content_id"
|
||||
reaction, err := ts.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: contentID,
|
||||
ReactionType: "💗",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
require.NotEmpty(t, reaction.ID)
|
||||
|
||||
reactions, err := ts.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &contentID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, reactions, 1)
|
||||
require.Equal(t, reaction, reactions[0])
|
||||
|
||||
// Test GetReaction.
|
||||
gotReaction, err := ts.GetReaction(ctx, &store.FindReaction{
|
||||
ID: &reaction.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gotReaction)
|
||||
require.Equal(t, reaction.ID, gotReaction.ID)
|
||||
require.Equal(t, reaction.CreatorID, gotReaction.CreatorID)
|
||||
require.Equal(t, reaction.ContentID, gotReaction.ContentID)
|
||||
require.Equal(t, reaction.ReactionType, gotReaction.ReactionType)
|
||||
|
||||
// Test GetReaction with non-existent ID.
|
||||
nonExistentID := int32(99999)
|
||||
notFoundReaction, err := ts.GetReaction(ctx, &store.FindReaction{
|
||||
ID: &nonExistentID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFoundReaction)
|
||||
|
||||
err = ts.DeleteReaction(ctx, &store.DeleteReaction{
|
||||
ID: reaction.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
reactions, err = ts.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &contentID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, reactions, 0)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestReactionListByCreatorID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentID := "shared_content"
|
||||
|
||||
// User 1 creates reaction
|
||||
_, err = ts.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user1.ID,
|
||||
ContentID: contentID,
|
||||
ReactionType: "👍",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// User 2 creates reaction
|
||||
_, err = ts.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user2.ID,
|
||||
ContentID: contentID,
|
||||
ReactionType: "❤️",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all reactions for content
|
||||
reactions, err := ts.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &contentID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, reactions, 2)
|
||||
|
||||
// List by creator ID
|
||||
user1Reactions, err := ts.ListReactions(ctx, &store.FindReaction{
|
||||
CreatorID: &user1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, user1Reactions, 1)
|
||||
require.Equal(t, "👍", user1Reactions[0].ReactionType)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestReactionMultipleContentIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentID1 := "content_1"
|
||||
contentID2 := "content_2"
|
||||
|
||||
// Create reactions for different contents
|
||||
_, err = ts.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: contentID1,
|
||||
ReactionType: "👍",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: contentID2,
|
||||
ReactionType: "❤️",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by content ID list
|
||||
reactions, err := ts.ListReactions(ctx, &store.FindReaction{
|
||||
ContentIDList: []string{contentID1, contentID2},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, reactions, 2)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestReactionUpsertDifferentTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentID := "test_content"
|
||||
|
||||
// Create first reaction
|
||||
reaction1, err := ts.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: contentID,
|
||||
ReactionType: "👍",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create second reaction with different type (should create new, not update)
|
||||
reaction2, err := ts.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: contentID,
|
||||
ReactionType: "❤️",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both reactions should exist
|
||||
require.NotEqual(t, reaction1.ID, reaction2.ID)
|
||||
|
||||
reactions, err := ts.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &contentID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, reactions, 2)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
111
store/test/store.go
Normal file
111
store/test/store.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
// sqlite driver.
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/internal/version"
|
||||
"github.com/usememos/memos/store"
|
||||
"github.com/usememos/memos/store/db"
|
||||
)
|
||||
|
||||
// NewTestingStore creates a new testing store with a fresh database.
|
||||
// Each test gets its own isolated database:
|
||||
// - SQLite: new temp file per test
|
||||
// - MySQL/PostgreSQL: new database per test in shared container
|
||||
func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
|
||||
driver := getDriverFromEnv()
|
||||
profile := getTestingProfileForDriver(t, driver)
|
||||
dbDriver, err := db.NewDBDriver(profile)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db driver: %v", err)
|
||||
}
|
||||
|
||||
store := store.New(dbDriver, profile)
|
||||
if err := store.Migrate(ctx); err != nil {
|
||||
t.Fatalf("failed to migrate db: %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
// NewTestingStoreWithDSN creates a testing store connected to a specific DSN.
|
||||
// This is useful for testing migrations on existing data.
|
||||
func NewTestingStoreWithDSN(_ context.Context, t *testing.T, driver, dsn string) *store.Store {
|
||||
profile := &profile.Profile{
|
||||
Port: getUnusedPort(),
|
||||
Data: t.TempDir(), // Dummy dir, DSN matters
|
||||
DSN: dsn,
|
||||
Driver: driver,
|
||||
Version: version.GetCurrentVersion(),
|
||||
}
|
||||
dbDriver, err := db.NewDBDriver(profile)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db driver: %v", err)
|
||||
}
|
||||
|
||||
store := store.New(dbDriver, profile)
|
||||
// Do not run Migrate() automatically, as we might be testing pre-migration state
|
||||
// or want to run it manually.
|
||||
return store
|
||||
}
|
||||
|
||||
func getUnusedPort() int {
|
||||
// Get a random unused port
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// Get the port number
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
return port
|
||||
}
|
||||
|
||||
// getTestingProfileForDriver creates a testing profile for a specific driver.
|
||||
func getTestingProfileForDriver(t *testing.T, driver string) *profile.Profile {
|
||||
// Attempt to load .env file if present (optional, for local development)
|
||||
_ = godotenv.Load(".env")
|
||||
|
||||
// Get a temporary directory for the test data.
|
||||
dir := t.TempDir()
|
||||
mode := "prod"
|
||||
port := getUnusedPort()
|
||||
|
||||
var dsn string
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
dsn = fmt.Sprintf("%s/memos_%s.db", dir, mode)
|
||||
case "mysql":
|
||||
dsn = GetMySQLDSN(t)
|
||||
case "postgres":
|
||||
dsn = GetPostgresDSN(t)
|
||||
default:
|
||||
t.Fatalf("unsupported driver: %s", driver)
|
||||
}
|
||||
|
||||
return &profile.Profile{
|
||||
Port: port,
|
||||
Data: dir,
|
||||
DSN: dsn,
|
||||
Driver: driver,
|
||||
Version: version.GetCurrentVersion(),
|
||||
}
|
||||
}
|
||||
|
||||
func getDriverFromEnv() string {
|
||||
driver := os.Getenv("DRIVER")
|
||||
if driver == "" {
|
||||
driver = "sqlite"
|
||||
}
|
||||
return driver
|
||||
}
|
||||
993
store/test/user_setting_test.go
Normal file
993
store/test/user_setting_test.go
Normal file
@@ -0,0 +1,993 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestUserSettingStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "en"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(list))
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetByUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create setting
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "zh"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by user ID
|
||||
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Equal(t, "zh", setting.GetGeneral().Locale)
|
||||
|
||||
// Get non-existent key
|
||||
nonExistentSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, nonExistentSetting)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingUpsertUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create initial setting
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "en"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update setting
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "fr"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify update
|
||||
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fr", setting.GetGeneral().Locale)
|
||||
|
||||
// Verify only one setting exists
|
||||
list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{UserID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(list))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingRefreshTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially no tokens
|
||||
tokens, err := ts.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tokens)
|
||||
|
||||
// Add a refresh token
|
||||
token1 := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: "token-1",
|
||||
Description: "Chrome browser session",
|
||||
}
|
||||
err = ts.AddUserRefreshToken(ctx, user.ID, token1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify token was added
|
||||
tokens, err = ts.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tokens, 1)
|
||||
require.Equal(t, "token-1", tokens[0].TokenId)
|
||||
|
||||
// Add another token
|
||||
token2 := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: "token-2",
|
||||
Description: "Firefox browser session",
|
||||
}
|
||||
err = ts.AddUserRefreshToken(ctx, user.ID, token2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokens, err = ts.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tokens, 2)
|
||||
|
||||
// Get specific token by ID
|
||||
foundToken, err := ts.GetUserRefreshTokenByID(ctx, user.ID, "token-1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, foundToken)
|
||||
require.Equal(t, "Chrome browser session", foundToken.Description)
|
||||
|
||||
// Get non-existent token
|
||||
notFound, err := ts.GetUserRefreshTokenByID(ctx, user.ID, "non-existent")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
// Remove token
|
||||
err = ts.RemoveUserRefreshToken(ctx, user.ID, "token-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokens, err = ts.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tokens, 1)
|
||||
require.Equal(t, "token-2", tokens[0].TokenId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingPersonalAccessTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially no PATs
|
||||
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, pats)
|
||||
|
||||
// Add a PAT
|
||||
pat1 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-1",
|
||||
TokenHash: "pat-hash-1",
|
||||
Description: "API Token for external access",
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify PAT was added
|
||||
pats, err = ts.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pats, 1)
|
||||
require.Equal(t, "API Token for external access", pats[0].Description)
|
||||
|
||||
// Add another PAT
|
||||
pat2 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-2",
|
||||
TokenHash: "pat-hash-2",
|
||||
Description: "CI Token",
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat2)
|
||||
require.NoError(t, err)
|
||||
|
||||
pats, err = ts.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pats, 2)
|
||||
|
||||
// Remove PAT
|
||||
err = ts.RemoveUserPersonalAccessToken(ctx, user.ID, "pat-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
pats, err = ts.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pats, 1)
|
||||
require.Equal(t, "pat-2", pats[0].TokenId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingWebhooks(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially no webhooks
|
||||
webhooks, err := ts.GetUserWebhooks(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, webhooks)
|
||||
|
||||
// Add a webhook
|
||||
webhook1 := &storepb.WebhooksUserSetting_Webhook{
|
||||
Id: "webhook-1",
|
||||
Title: "Deploy Hook",
|
||||
Url: "https://example.com/webhook",
|
||||
}
|
||||
err = ts.AddUserWebhook(ctx, user.ID, webhook1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify webhook was added
|
||||
webhooks, err = ts.GetUserWebhooks(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, webhooks, 1)
|
||||
require.Equal(t, "Deploy Hook", webhooks[0].Title)
|
||||
|
||||
// Update webhook
|
||||
webhook1Updated := &storepb.WebhooksUserSetting_Webhook{
|
||||
Id: "webhook-1",
|
||||
Title: "Updated Deploy Hook",
|
||||
Url: "https://example.com/webhook/v2",
|
||||
}
|
||||
err = ts.UpdateUserWebhook(ctx, user.ID, webhook1Updated)
|
||||
require.NoError(t, err)
|
||||
|
||||
webhooks, err = ts.GetUserWebhooks(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, webhooks, 1)
|
||||
require.Equal(t, "Updated Deploy Hook", webhooks[0].Title)
|
||||
require.Equal(t, "https://example.com/webhook/v2", webhooks[0].Url)
|
||||
|
||||
// Add another webhook
|
||||
webhook2 := &storepb.WebhooksUserSetting_Webhook{
|
||||
Id: "webhook-2",
|
||||
Title: "Notification Hook",
|
||||
Url: "https://slack.example.com/webhook",
|
||||
}
|
||||
err = ts.AddUserWebhook(ctx, user.ID, webhook2)
|
||||
require.NoError(t, err)
|
||||
|
||||
webhooks, err = ts.GetUserWebhooks(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, webhooks, 2)
|
||||
|
||||
// Remove webhook
|
||||
err = ts.RemoveUserWebhook(ctx, user.ID, "webhook-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
webhooks, err = ts.GetUserWebhooks(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, webhooks, 1)
|
||||
require.Equal(t, "webhook-2", webhooks[0].Id)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingShortcuts(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcuts setting
|
||||
shortcuts := &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "shortcut-1", Title: "Work Notes", Filter: "tag:work"},
|
||||
{Id: "shortcut-2", Title: "Personal", Filter: "tag:personal"},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify
|
||||
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 2)
|
||||
require.Equal(t, "Work Notes", setting.GetShortcuts().Shortcuts[0].Title)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a PAT with a known hash
|
||||
patHash := "test-pat-hash-12345"
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-test-1",
|
||||
TokenHash: patHash,
|
||||
Description: "Test PAT for lookup",
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Lookup user by PAT hash
|
||||
result, err := ts.GetUserByPATHash(ctx, patHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, user.ID, result.UserID)
|
||||
require.NotNil(t, result.User)
|
||||
require.Equal(t, user.Username, result.User.Username)
|
||||
require.NotNil(t, result.PAT)
|
||||
require.Equal(t, "pat-test-1", result.PAT.TokenId)
|
||||
require.Equal(t, "Test PAT for lookup", result.PAT.Description)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
_, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Lookup non-existent PAT hash
|
||||
result, err := ts.GetUserByPATHash(ctx, "non-existent-hash")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashNoTokensKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User exists but has no PERSONAL_ACCESS_TOKENS key at all
|
||||
// This simulates fresh users or users upgraded from v0.25.3
|
||||
result, err := ts.GetUserByPATHash(ctx, "any-hash")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
// Error could be "PAT not found" (Postgres) or "sql: no rows in result set" (SQLite/MySQL)
|
||||
require.True(t,
|
||||
strings.Contains(err.Error(), "PAT not found") || strings.Contains(err.Error(), "no rows"),
|
||||
"expected PAT not found or no rows error, got: %v", err)
|
||||
|
||||
// Now add a PAT for the user
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-new",
|
||||
TokenHash: "hash-new",
|
||||
Description: "New PAT",
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the lookup should succeed
|
||||
result, err = ts.GetUserByPATHash(ctx, "hash-new")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, user.ID, result.UserID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashEmptyTokensArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a PAT setting with empty tokens array
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS,
|
||||
Value: &storepb.UserSetting_PersonalAccessTokens{
|
||||
PersonalAccessTokens: &storepb.PersonalAccessTokensUserSetting{
|
||||
Tokens: []*storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Lookup should fail gracefully, not crash
|
||||
result, err := ts.GetUserByPATHash(ctx, "any-hash")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
// Error could be "PAT not found" (Postgres) or "sql: no rows in result set" (SQLite/MySQL)
|
||||
require.True(t,
|
||||
strings.Contains(err.Error(), "PAT not found") || strings.Contains(err.Error(), "no rows"),
|
||||
"expected PAT not found or no rows error, got: %v", err)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashWithOtherUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple users - some with PATs, some without
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
user3, err := createTestingUserWithRole(ctx, ts, "user3", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User1: Has PAT
|
||||
pat1Hash := "user1-pat-hash-unique"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user1.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-user1",
|
||||
TokenHash: pat1Hash,
|
||||
Description: "User 1 PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// User2: Has no PERSONAL_ACCESS_TOKENS key (fresh user)
|
||||
// User3: Has empty tokens array
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user3.ID,
|
||||
Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS,
|
||||
Value: &storepb.UserSetting_PersonalAccessTokens{
|
||||
PersonalAccessTokens: &storepb.PersonalAccessTokensUserSetting{
|
||||
Tokens: []*storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should find user1's PAT despite user2 having no key and user3 having empty array
|
||||
result, err := ts.GetUserByPATHash(ctx, pat1Hash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, user1.ID, result.UserID)
|
||||
require.Equal(t, "pat-user1", result.PAT.TokenId)
|
||||
|
||||
// Should not find non-existent hash even with mixed user states
|
||||
result, err = ts.GetUserByPATHash(ctx, "non-existent")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashMultipleUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create PATs for both users
|
||||
pat1Hash := "user1-pat-hash"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user1.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-user1",
|
||||
TokenHash: pat1Hash,
|
||||
Description: "User 1 PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pat2Hash := "user2-pat-hash"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user2.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-user2",
|
||||
TokenHash: pat2Hash,
|
||||
Description: "User 2 PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Lookup user1's PAT
|
||||
result1, err := ts.GetUserByPATHash(ctx, pat1Hash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user1.ID, result1.UserID)
|
||||
require.Equal(t, user1.Username, result1.User.Username)
|
||||
|
||||
// Lookup user2's PAT
|
||||
result2, err := ts.GetUserByPATHash(ctx, pat2Hash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user2.ID, result2.UserID)
|
||||
require.Equal(t, user2.Username, result2.User.Username)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashMultiplePATsSameUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple PATs for the same user
|
||||
pat1Hash := "first-pat-hash"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-1",
|
||||
TokenHash: pat1Hash,
|
||||
Description: "First PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pat2Hash := "second-pat-hash"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-2",
|
||||
TokenHash: pat2Hash,
|
||||
Description: "Second PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both PATs should resolve to the same user
|
||||
result1, err := ts.GetUserByPATHash(ctx, pat1Hash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.ID, result1.UserID)
|
||||
require.Equal(t, "pat-1", result1.PAT.TokenId)
|
||||
|
||||
result2, err := ts.GetUserByPATHash(ctx, pat2Hash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.ID, result2.UserID)
|
||||
require.Equal(t, "pat-2", result2.PAT.TokenId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingUpdatePATLastUsed(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a PAT
|
||||
patHash := "pat-hash-for-update"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-update-test",
|
||||
TokenHash: patHash,
|
||||
Description: "PAT for update test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update last used timestamp
|
||||
now := timestamppb.Now()
|
||||
err = ts.UpdatePATLastUsed(ctx, user.ID, "pat-update-test", now)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the update
|
||||
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pats, 1)
|
||||
require.NotNil(t, pats[0].LastUsedAt)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashWithExpiredToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a PAT with expiration info
|
||||
patHash := "pat-hash-with-expiry"
|
||||
expiresAt := timestamppb.Now()
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-expiry-test",
|
||||
TokenHash: patHash,
|
||||
Description: "PAT with expiry",
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should still be able to look up by hash (expiry check is done at auth level)
|
||||
result, err := ts.GetUserByPATHash(ctx, patHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, user.ID, result.UserID)
|
||||
require.NotNil(t, result.PAT.ExpiresAt)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashAfterRemoval(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a PAT
|
||||
patHash := "pat-hash-to-remove"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-remove-test",
|
||||
TokenHash: patHash,
|
||||
Description: "PAT to be removed",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it exists
|
||||
result, err := ts.GetUserByPATHash(ctx, patHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Remove the PAT
|
||||
err = ts.RemoveUserPersonalAccessToken(ctx, user.ID, "pat-remove-test")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should no longer be found
|
||||
result, err = ts.GetUserByPATHash(ctx, patHash)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashSpecialCharacters(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create PATs with special characters in hash (simulating real hash values)
|
||||
testCases := []struct {
|
||||
tokenID string
|
||||
tokenHash string
|
||||
}{
|
||||
{"pat-special-1", "abc123+/=XYZ"},
|
||||
{"pat-special-2", "sha256:abcdef1234567890"},
|
||||
{"pat-special-3", "$2a$10$N9qo8uLOickgx2ZMRZoMy"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tc.tokenID,
|
||||
TokenHash: tc.tokenHash,
|
||||
Description: "PAT with special chars",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify lookup works with special characters
|
||||
result, err := ts.GetUserByPATHash(ctx, tc.tokenHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tc.tokenID, result.PAT.TokenId)
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashLargeTokenCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create many PATs for the same user
|
||||
tokenCount := 10
|
||||
hashes := make([]string, tokenCount)
|
||||
for i := 0; i < tokenCount; i++ {
|
||||
hashes[i] = "pat-hash-" + string(rune('A'+i)) + "-large-test"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-large-" + string(rune('A'+i)),
|
||||
TokenHash: hashes[i],
|
||||
Description: "PAT for large count test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify each hash can be looked up
|
||||
for i, hash := range hashes {
|
||||
result, err := ts.GetUserByPATHash(ctx, hash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, user.ID, result.UserID)
|
||||
require.Equal(t, "pat-large-"+string(rune('A'+i)), result.PAT.TokenId)
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingMultipleSettingTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create GENERAL setting
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "ja"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create SHORTCUTS setting
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s1", Title: "Shortcut 1"},
|
||||
},
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a PAT
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-multi",
|
||||
TokenHash: "hash-multi",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all settings for user
|
||||
settings, err := ts.ListUserSettings(ctx, &store.FindUserSetting{UserID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, settings, 3)
|
||||
|
||||
// Verify each setting type
|
||||
generalSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_GENERAL})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ja", generalSetting.GetGeneral().Locale)
|
||||
|
||||
shortcutsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_SHORTCUTS})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, shortcutsSetting.GetShortcuts().Shortcuts, 1)
|
||||
|
||||
patsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, patsSetting.GetPersonalAccessTokens().Tokens, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingShortcutsEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Case 1: Special characters in Filter and Title
|
||||
// Includes quotes, backslashes, newlines, and other JSON-sensitive characters
|
||||
specialCharsFilter := `tag in ["work", "project"] && content.contains("urgent")`
|
||||
specialCharsTitle := `Work "Urgent" \ Notes`
|
||||
shortcuts := &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s1", Title: specialCharsTitle, Filter: specialCharsFilter},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 1)
|
||||
require.Equal(t, specialCharsTitle, setting.GetShortcuts().Shortcuts[0].Title)
|
||||
require.Equal(t, specialCharsFilter, setting.GetShortcuts().Shortcuts[0].Filter)
|
||||
|
||||
// Case 2: Unicode characters
|
||||
unicodeFilter := `tag in ["你好", "世界"]`
|
||||
unicodeTitle := `My 🚀 Shortcuts`
|
||||
shortcuts = &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s2", Title: unicodeTitle, Filter: unicodeFilter},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 1)
|
||||
require.Equal(t, unicodeTitle, setting.GetShortcuts().Shortcuts[0].Title)
|
||||
require.Equal(t, unicodeFilter, setting.GetShortcuts().Shortcuts[0].Filter)
|
||||
|
||||
// Case 3: Empty shortcuts list
|
||||
// Should allow saving an empty list (clearing shortcuts)
|
||||
shortcuts = &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.NotNil(t, setting.GetShortcuts())
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 0)
|
||||
|
||||
// Case 4: Large filter string
|
||||
// Test reasonable large string handling (e.g. 4KB)
|
||||
largeFilter := strings.Repeat("tag:long_tag_name ", 200)
|
||||
shortcuts = &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s3", Title: "Large Filter", Filter: largeFilter},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Equal(t, largeFilter, setting.GetShortcuts().Shortcuts[0].Filter)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingShortcutsPartialUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initial set
|
||||
shortcuts := &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s1", Title: "Note 1", Filter: "tag:1"},
|
||||
{Id: "s2", Title: "Note 2", Filter: "tag:2"},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update by replacing the whole list (Store Upsert replaces the value for the key)
|
||||
// We want to verify that we can "update" a single item by sending the modified list
|
||||
updatedShortcuts := &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s1", Title: "Note 1 Updated", Filter: "tag:1_updated"},
|
||||
{Id: "s2", Title: "Note 2", Filter: "tag:2"},
|
||||
{Id: "s3", Title: "Note 3", Filter: "tag:3"}, // Add new one
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: updatedShortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 3)
|
||||
|
||||
// Verify updates
|
||||
for _, s := range setting.GetShortcuts().Shortcuts {
|
||||
if s.Id == "s1" {
|
||||
require.Equal(t, "Note 1 Updated", s.Title)
|
||||
require.Equal(t, "tag:1_updated", s.Filter)
|
||||
} else if s.Id == "s2" {
|
||||
require.Equal(t, "Note 2", s.Title)
|
||||
} else if s.Id == "s3" {
|
||||
require.Equal(t, "Note 3", s.Title)
|
||||
}
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingJSONFieldsEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Case 1: Webhook with special characters and Unicode in Title and URL
|
||||
specialWebhook := &storepb.WebhooksUserSetting_Webhook{
|
||||
Id: "wh-special",
|
||||
Title: `My "Special" & <Webhook> 🚀`,
|
||||
Url: "https://example.com/hook?query=你好¶m=\"value\"",
|
||||
}
|
||||
err = ts.AddUserWebhook(ctx, user.ID, specialWebhook)
|
||||
require.NoError(t, err)
|
||||
|
||||
webhooks, err := ts.GetUserWebhooks(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, webhooks, 1)
|
||||
require.Equal(t, specialWebhook.Title, webhooks[0].Title)
|
||||
require.Equal(t, specialWebhook.Url, webhooks[0].Url)
|
||||
|
||||
// Case 2: PAT with special description
|
||||
specialPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-special",
|
||||
TokenHash: "hash-special",
|
||||
Description: "Token for 'CLI' \n & \"API\" \t with unicode 🔑",
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, specialPAT)
|
||||
require.NoError(t, err)
|
||||
|
||||
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pats, 1)
|
||||
require.Equal(t, specialPAT.Description, pats[0].Description)
|
||||
|
||||
// Case 3: Refresh Token with special description
|
||||
specialRefreshToken := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: "rt-special",
|
||||
Description: "Browser: Firefox (Nightly) / OS: Linux 🐧",
|
||||
}
|
||||
err = ts.AddUserRefreshToken(ctx, user.ID, specialRefreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokens, err := ts.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tokens, 1)
|
||||
require.Equal(t, specialRefreshToken.Description, tokens[0].Description)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
256
store/test/user_test.go
Normal file
256
store/test/user_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestUserStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
users, err := ts.ListUsers(ctx, &store.FindUser{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(users))
|
||||
require.Equal(t, store.RoleAdmin, users[0].Role)
|
||||
require.Equal(t, user, users[0])
|
||||
userPatchNickname := "test_nickname_2"
|
||||
userPatch := &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
Nickname: &userPatchNickname,
|
||||
}
|
||||
user, err = ts.UpdateUser(ctx, userPatch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userPatchNickname, user.Nickname)
|
||||
err = ts.DeleteUser(ctx, &store.DeleteUser{
|
||||
ID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
users, err = ts.ListUsers(ctx, &store.FindUser{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(users))
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get user by ID
|
||||
found, err := ts.GetUser(ctx, &store.FindUser{ID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, user.ID, found.ID)
|
||||
require.Equal(t, user.Username, found.Username)
|
||||
|
||||
// Get non-existent user
|
||||
nonExistentID := int32(99999)
|
||||
notFound, err := ts.GetUser(ctx, &store.FindUser{ID: &nonExistentID})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
// Get system bot
|
||||
systemBotID := store.SystemBotID
|
||||
systemBot, err := ts.GetUser(ctx, &store.FindUser{ID: &systemBotID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, systemBot)
|
||||
require.Equal(t, store.SystemBotID, systemBot.ID)
|
||||
require.Equal(t, "system_bot", systemBot.Username)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserGetByUsername(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get user by username
|
||||
found, err := ts.GetUser(ctx, &store.FindUser{Username: &user.Username})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, user.Username, found.Username)
|
||||
|
||||
// Get non-existent username
|
||||
nonExistent := "nonexistent"
|
||||
notFound, err := ts.GetUser(ctx, &store.FindUser{Username: &nonExistent})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserListByRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create users with different roles
|
||||
_, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = createTestingUserWithRole(ctx, ts, "admin_user", store.RoleAdmin)
|
||||
require.NoError(t, err)
|
||||
|
||||
regularUser, err := createTestingUserWithRole(ctx, ts, "regular_user", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all users
|
||||
allUsers, err := ts.ListUsers(ctx, &store.FindUser{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(allUsers))
|
||||
|
||||
// List only ADMIN users
|
||||
adminRole := store.RoleAdmin
|
||||
adminOnlyUsers, err := ts.ListUsers(ctx, &store.FindUser{Role: &adminRole})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(adminOnlyUsers))
|
||||
|
||||
// List only USER role users
|
||||
userRole := store.RoleUser
|
||||
regularUsers, err := ts.ListUsers(ctx, &store.FindUser{Role: &userRole})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(regularUsers))
|
||||
require.Equal(t, regularUser.ID, regularUsers[0].ID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserUpdateRowStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.Normal, user.RowStatus)
|
||||
|
||||
// Archive user
|
||||
archivedStatus := store.Archived
|
||||
updated, err := ts.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
RowStatus: &archivedStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.Archived, updated.RowStatus)
|
||||
|
||||
// Verify by fetching
|
||||
fetched, err := ts.GetUser(ctx, &store.FindUser{ID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.Archived, fetched.RowStatus)
|
||||
|
||||
// Restore to normal
|
||||
normalStatus := store.Normal
|
||||
restored, err := ts.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
RowStatus: &normalStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.Normal, restored.RowStatus)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserUpdateAllFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update all fields
|
||||
newUsername := "updated_username"
|
||||
newEmail := "updated@test.com"
|
||||
newNickname := "Updated Nickname"
|
||||
newAvatarURL := "https://example.com/avatar.png"
|
||||
newDescription := "Updated description"
|
||||
newRole := store.RoleAdmin
|
||||
newPasswordHash := "new_password_hash"
|
||||
|
||||
updated, err := ts.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
Username: &newUsername,
|
||||
Email: &newEmail,
|
||||
Nickname: &newNickname,
|
||||
AvatarURL: &newAvatarURL,
|
||||
Description: &newDescription,
|
||||
Role: &newRole,
|
||||
PasswordHash: &newPasswordHash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newUsername, updated.Username)
|
||||
require.Equal(t, newEmail, updated.Email)
|
||||
require.Equal(t, newNickname, updated.Nickname)
|
||||
require.Equal(t, newAvatarURL, updated.AvatarURL)
|
||||
require.Equal(t, newDescription, updated.Description)
|
||||
require.Equal(t, newRole, updated.Role)
|
||||
require.Equal(t, newPasswordHash, updated.PasswordHash)
|
||||
|
||||
// Verify by fetching again
|
||||
fetched, err := ts.GetUser(ctx, &store.FindUser{ID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newUsername, fetched.Username)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserListWithLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create 5 users
|
||||
for i := 0; i < 5; i++ {
|
||||
role := store.RoleUser
|
||||
if i == 0 {
|
||||
role = store.RoleAdmin
|
||||
}
|
||||
_, err := createTestingUserWithRole(ctx, ts, fmt.Sprintf("user%d", i), role)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// List with limit
|
||||
limit := 3
|
||||
users, err := ts.ListUsers(ctx, &store.FindUser{Limit: &limit})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) {
|
||||
return createTestingUserWithRole(ctx, ts, "test", store.RoleAdmin)
|
||||
}
|
||||
|
||||
func createTestingUserWithRole(ctx context.Context, ts *store.Store, username string, role store.Role) (*store.User, error) {
|
||||
userCreate := &store.User{
|
||||
Username: username,
|
||||
Role: role,
|
||||
Email: username + "@test.com",
|
||||
Nickname: username + "_nickname",
|
||||
Description: username + "_description",
|
||||
}
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte("test_password"), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userCreate.PasswordHash = string(passwordHash)
|
||||
user, err := ts.CreateUser(ctx, userCreate)
|
||||
return user, err
|
||||
}
|
||||
Reference in New Issue
Block a user