first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
263
server/router/api/v1/test/activity_deleted_memo_test.go
Normal file
263
server/router/api/v1/test/activity_deleted_memo_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
v1 "github.com/usememos/memos/server/router/api/v1" //nolint:revive
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// TestListActivitiesWithDeletedMemos verifies that ListActivities gracefully handles
|
||||
// activities that reference deleted memos instead of crashing the entire request.
|
||||
func TestListActivitiesWithDeletedMemos(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users - one to create memo, one to comment
|
||||
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
|
||||
require.NoError(t, err)
|
||||
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
|
||||
|
||||
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
|
||||
require.NoError(t, err)
|
||||
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
|
||||
|
||||
// Create a memo by userOne
|
||||
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Original memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo1)
|
||||
|
||||
// Create a comment on the memo by userTwo (this will create an activity for userOne)
|
||||
comment, err := ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo1.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "This is a comment",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, comment)
|
||||
|
||||
// Verify activity was created for the comment (check from userOne's perspective - they receive the notification)
|
||||
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
initialActivityCount := len(activities.Activities)
|
||||
require.Greater(t, initialActivityCount, 0, "Should have at least one activity")
|
||||
|
||||
// Delete the original memo (this deletes the comment too)
|
||||
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
|
||||
Name: memo1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List activities again - should succeed even though the memo is deleted
|
||||
activities, err = ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
// Activities list should be empty or not contain the deleted memo activity
|
||||
for _, activity := range activities.Activities {
|
||||
if activity.Payload != nil && activity.Payload.GetMemoComment() != nil {
|
||||
require.NotEqual(t, memo1.Name, activity.Payload.GetMemoComment().Memo,
|
||||
"Activity should not reference deleted memo")
|
||||
}
|
||||
}
|
||||
// After deletion, there should be fewer activities
|
||||
require.LessOrEqual(t, len(activities.Activities), initialActivityCount-1,
|
||||
"Should have filtered out the activity for the deleted memo")
|
||||
}
|
||||
|
||||
// TestGetActivityWithDeletedMemo verifies that GetActivity returns a proper error
|
||||
// when trying to fetch an activity that references a deleted memo.
|
||||
func TestGetActivityWithDeletedMemo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
|
||||
require.NoError(t, err)
|
||||
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
|
||||
|
||||
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
|
||||
require.NoError(t, err)
|
||||
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
|
||||
|
||||
// Create a memo by userOne
|
||||
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Original memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo1)
|
||||
|
||||
// Create a comment to trigger activity creation by userTwo
|
||||
comment, err := ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo1.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "Comment",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, comment)
|
||||
|
||||
// Get the activity ID by listing activities from userOne's perspective
|
||||
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(activities.Activities), 0)
|
||||
|
||||
activityName := activities.Activities[0].Name
|
||||
|
||||
// Delete the memo
|
||||
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
|
||||
Name: memo1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get the specific activity - should return NotFound error
|
||||
_, err = ts.Service.GetActivity(userOneCtx, &apiv1.GetActivityRequest{
|
||||
Name: activityName,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "activity references deleted content")
|
||||
}
|
||||
|
||||
// TestActivitiesWithPartiallyDeletedMemos verifies that when some memos are deleted,
|
||||
// other valid activities are still returned.
|
||||
func TestActivitiesWithPartiallyDeletedMemos(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
|
||||
require.NoError(t, err)
|
||||
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
|
||||
|
||||
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
|
||||
require.NoError(t, err)
|
||||
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
|
||||
|
||||
// Create two memos by userOne
|
||||
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "First memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memo2, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Second memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create comments on both by userTwo (creates activities for userOne)
|
||||
_, err = ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo1.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "Comment on first",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo2.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "Comment on second",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have 2 activities from userOne's perspective
|
||||
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(activities.Activities))
|
||||
|
||||
// Delete first memo
|
||||
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
|
||||
Name: memo1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List activities - should still work and return only the second memo's activity
|
||||
activities, err = ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(activities.Activities), "Should have 1 activity remaining")
|
||||
|
||||
// Verify the remaining activity relates to a valid memo
|
||||
require.NotNil(t, activities.Activities[0].Payload.GetMemoComment())
|
||||
require.Contains(t, activities.Activities[0].Payload.GetMemoComment().RelatedMemo, "memos/")
|
||||
}
|
||||
|
||||
// TestActivityStoreDirectDeletion tests the scenario where a memo is deleted directly
|
||||
// from the store (simulating database-level deletion or migration).
|
||||
func TestActivityStoreDirectDeletion(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "test-user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a memo
|
||||
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a comment
|
||||
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo1.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "Test comment",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract memo UID from the comment name
|
||||
commentMemoUID, err := v1.ExtractMemoUIDFromName(comment.Name)
|
||||
require.NoError(t, err)
|
||||
|
||||
commentMemo, err := ts.Store.GetMemo(ctx, &store.FindMemo{
|
||||
UID: &commentMemoUID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, commentMemo)
|
||||
|
||||
// Delete the comment memo directly from store (simulating orphaned activity)
|
||||
err = ts.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: commentMemo.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List activities should still succeed even with orphaned activity
|
||||
activities, err := ts.Service.ListActivities(userCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
// Activities should be empty or not include the orphaned one
|
||||
for _, activity := range activities.Activities {
|
||||
if activity.Payload != nil && activity.Payload.GetMemoComment() != nil {
|
||||
require.NotEqual(t, comment.Name, activity.Payload.GetMemoComment().Memo,
|
||||
"Should not return activity with deleted memo")
|
||||
}
|
||||
}
|
||||
}
|
||||
1
server/router/api/v1/test/assets/1772542534_test.png
Normal file
1
server/router/api/v1/test/assets/1772542534_test.png
Normal file
@@ -0,0 +1 @@
|
||||
fake png content
|
||||
1
server/router/api/v1/test/assets/1772542535_test.png
Normal file
1
server/router/api/v1/test/assets/1772542535_test.png
Normal file
@@ -0,0 +1 @@
|
||||
fake png content
|
||||
2
server/router/api/v1/test/assets/1772542535_test.unknown
Normal file
2
server/router/api/v1/test/assets/1772542535_test.unknown
Normal file
@@ -0,0 +1,2 @@
|
||||
‰PNG
|
||||
|
||||
|
After Width: | Height: | Size: 8 B |
BIN
server/router/api/v1/test/assets/1772542536_test.data
Normal file
BIN
server/router/api/v1/test/assets/1772542536_test.data
Normal file
Binary file not shown.
59
server/router/api/v1/test/attachment_service_test.go
Normal file
59
server/router/api/v1/test/attachment_service_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestCreateAttachment(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "test_user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test case 1: Create attachment with empty type but known extension
|
||||
t.Run("EmptyType_KnownExtension", func(t *testing.T) {
|
||||
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "test.png",
|
||||
Content: []byte("fake png content"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "image/png", attachment.Type)
|
||||
})
|
||||
|
||||
// Test case 2: Create attachment with empty type and unknown extension, but detectable content
|
||||
t.Run("EmptyType_UnknownExtension_ContentSniffing", func(t *testing.T) {
|
||||
// PNG magic header: 89 50 4E 47 0D 0A 1A 0A
|
||||
pngContent := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
|
||||
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "test.unknown",
|
||||
Content: pngContent,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "image/png", attachment.Type)
|
||||
})
|
||||
|
||||
// Test case 3: Empty type, unknown extension, random content -> fallback to application/octet-stream
|
||||
t.Run("EmptyType_Fallback", func(t *testing.T) {
|
||||
randomContent := []byte{0x00, 0x01, 0x02, 0x03}
|
||||
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "test.data",
|
||||
Content: randomContent,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "application/octet-stream", attachment.Type)
|
||||
})
|
||||
}
|
||||
655
server/router/api/v1/test/auth_test.go
Normal file
655
server/router/api/v1/test/auth_test.go
Normal file
@@ -0,0 +1,655 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestAuthenticatorAccessTokenV2(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid access token v2", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate access token v2
|
||||
token, _, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte(ts.Secret),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
claims, err := authenticator.AuthenticateByAccessTokenV2(token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, claims)
|
||||
assert.Equal(t, user.ID, claims.UserID)
|
||||
assert.Equal(t, user.Username, claims.Username)
|
||||
assert.Equal(t, string(user.Role), claims.Role)
|
||||
assert.Equal(t, string(user.RowStatus), claims.Status)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, err := authenticator.AuthenticateByAccessTokenV2("invalid-token")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with wrong secret", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate token with one secret
|
||||
token, _, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte("secret-1"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate with different secret
|
||||
authenticator := auth.NewAuthenticator(ts.Store, "secret-2")
|
||||
_, err = authenticator.AuthenticateByAccessTokenV2(token)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthenticatorRefreshToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create refresh token record in store
|
||||
tokenID := util.GenUUID()
|
||||
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate refresh token JWT
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, returnedTokenID, err := authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.Equal(t, user.ID, authenticatedUser.ID)
|
||||
assert.Equal(t, tokenID, returnedTokenID)
|
||||
})
|
||||
|
||||
t.Run("fails with revoked token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
// Generate refresh token JWT but don't store it in database (simulates revocation)
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "revoked")
|
||||
})
|
||||
|
||||
t.Run("fails with expired token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create expired refresh token record in store
|
||||
tokenID := util.GenUUID()
|
||||
expiredToken := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, expiredToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate refresh token JWT (JWT itself isn't expired yet)
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired")
|
||||
})
|
||||
|
||||
t.Run("fails with archived user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create valid refresh token
|
||||
tokenID := util.GenUUID()
|
||||
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the user
|
||||
archivedStatus := store.Archived
|
||||
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
RowStatus: &archivedStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "archived")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthenticatorPAT(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
// Store PAT in database
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.NotNil(t, pat)
|
||||
assert.Equal(t, user.ID, authenticatedUser.ID)
|
||||
assert.Equal(t, tokenID, pat.TokenId)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid PAT format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err := authenticator.AuthenticateByPAT(ctx, "invalid-token-without-prefix")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid PAT format")
|
||||
})
|
||||
|
||||
t.Run("fails with non-existent PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Generate a PAT but don't store it
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with expired PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store expired PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
expiredPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Expired PAT",
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, expiredPAT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired")
|
||||
})
|
||||
|
||||
t.Run("succeeds with non-expiring PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store PAT without expiration
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Never-expiring PAT",
|
||||
ExpiresAt: nil, // No expiration
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.NotNil(t, pat)
|
||||
assert.Nil(t, pat.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("fails with archived user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the user
|
||||
archivedStatus := store.Archived
|
||||
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
RowStatus: &archivedStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "archived")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRefreshTokenMethods(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("adds and retrieves refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve tokens
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 1)
|
||||
assert.Equal(t, tokenID, tokens[0].TokenId)
|
||||
})
|
||||
|
||||
t.Run("retrieves specific refresh token by ID", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve specific token
|
||||
retrievedToken, err := ts.Store.GetUserRefreshTokenByID(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, retrievedToken)
|
||||
assert.Equal(t, tokenID, retrievedToken.TokenId)
|
||||
})
|
||||
|
||||
t.Run("removes refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove token
|
||||
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify removal
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 0)
|
||||
})
|
||||
|
||||
t.Run("handles multiple refresh tokens", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add multiple tokens
|
||||
tokenID1 := util.GenUUID()
|
||||
tokenID2 := util.GenUUID()
|
||||
|
||||
token1 := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID1,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
token2 := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID2,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token1)
|
||||
require.NoError(t, err)
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve all tokens
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 2)
|
||||
|
||||
// Remove one token
|
||||
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one token remains
|
||||
tokens, err = ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 1)
|
||||
assert.Equal(t, tokenID2, tokens[0].TokenId)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorePersonalAccessTokenMethods(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("adds and retrieves PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve PATs
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.Equal(t, tokenID, pats[0].TokenId)
|
||||
assert.Equal(t, tokenHash, pats[0].TokenHash)
|
||||
})
|
||||
|
||||
t.Run("removes PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove PAT
|
||||
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify removal
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
})
|
||||
|
||||
t.Run("updates PAT last used time", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update last used time
|
||||
lastUsed := timestamppb.Now()
|
||||
err = ts.Store.UpdatePATLastUsed(ctx, user.ID, tokenID, lastUsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify update
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.NotNil(t, pats[0].LastUsedAt)
|
||||
})
|
||||
|
||||
t.Run("handles multiple PATs", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add multiple PATs
|
||||
token1 := auth.GeneratePersonalAccessToken()
|
||||
tokenHash1 := auth.HashPersonalAccessToken(token1)
|
||||
tokenID1 := util.GenUUID()
|
||||
|
||||
token2 := auth.GeneratePersonalAccessToken()
|
||||
tokenHash2 := auth.HashPersonalAccessToken(token2)
|
||||
tokenID2 := util.GenUUID()
|
||||
|
||||
pat1 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID1,
|
||||
TokenHash: tokenHash1,
|
||||
Description: "PAT 1",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
pat2 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID2,
|
||||
TokenHash: tokenHash2,
|
||||
Description: "PAT 2",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat1)
|
||||
require.NoError(t, err)
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve all PATs
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 2)
|
||||
|
||||
// Remove one PAT
|
||||
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one PAT remains
|
||||
pats, err = ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.Equal(t, tokenID2, pats[0].TokenId)
|
||||
})
|
||||
|
||||
t.Run("finds user by PAT hash", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find user by PAT hash
|
||||
result, err := ts.Store.GetUserByPATHash(ctx, tokenHash)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, user.ID, result.UserID)
|
||||
assert.NotNil(t, result.User)
|
||||
assert.Equal(t, user.Username, result.User.Username)
|
||||
assert.NotNil(t, result.PAT)
|
||||
assert.Equal(t, tokenID, result.PAT.TokenId)
|
||||
})
|
||||
}
|
||||
552
server/router/api/v1/test/idp_service_test.go
Normal file
552
server/router/api/v1/test/idp_service_test.go
Normal file
@@ -0,0 +1,552 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestCreateIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
ctx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create OAuth2 identity provider
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test OAuth2 Provider",
|
||||
IdentifierFilter: "",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthUrl: "https://example.com/oauth/authorize",
|
||||
TokenUrl: "https://example.com/oauth/token",
|
||||
UserInfoUrl: "https://example.com/oauth/userinfo",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
AvatarUrl: "avatar_url",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "Test OAuth2 Provider", resp.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||
require.Contains(t, resp.Name, "identity-providers/")
|
||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client-id", resp.Config.GetOauth2Config().ClientId)
|
||||
})
|
||||
|
||||
t.Run("CreateIdentityProvider permission denied for non-host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
ctx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("CreateIdentityProvider unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "user not authenticated")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListIdentityProviders(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ListIdentityProviders empty", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.ListIdentityProvidersRequest{}
|
||||
resp, err := ts.Service.ListIdentityProviders(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.IdentityProviders)
|
||||
})
|
||||
|
||||
t.Run("ListIdentityProviders with providers", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a couple of identity providers
|
||||
createReq1 := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider 1",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client1",
|
||||
AuthUrl: "https://example1.com/auth",
|
||||
TokenUrl: "https://example1.com/token",
|
||||
UserInfoUrl: "https://example1.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
createReq2 := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider 2",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client2",
|
||||
AuthUrl: "https://example2.com/auth",
|
||||
TokenUrl: "https://example2.com/token",
|
||||
UserInfoUrl: "https://example2.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq1)
|
||||
require.NoError(t, err)
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List providers
|
||||
listReq := &v1pb.ListIdentityProvidersRequest{}
|
||||
resp, err := ts.Service.ListIdentityProviders(ctx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.IdentityProviders, 2)
|
||||
|
||||
// Verify response contains expected providers
|
||||
titles := []string{resp.IdentityProviders[0].Title, resp.IdentityProviders[1].Title}
|
||||
require.Contains(t, titles, "Provider 1")
|
||||
require.Contains(t, titles, "Provider 2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
AuthUrl: "https://example.com/auth",
|
||||
TokenUrl: "https://example.com/token",
|
||||
UserInfoUrl: "https://example.com/user",
|
||||
Scopes: []string{"openid", "profile"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get identity provider
|
||||
getReq := &v1pb.GetIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
// Test unauthenticated, should not contain client secret
|
||||
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, created.Name, resp.Name)
|
||||
require.Equal(t, "Test Provider", resp.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)
|
||||
|
||||
// Test as host user, should contain client secret
|
||||
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respHostUser)
|
||||
require.Equal(t, created.Name, respHostUser.Name)
|
||||
require.Equal(t, "Test Provider", respHostUser.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
|
||||
require.NotNil(t, respHostUser.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
|
||||
})
|
||||
|
||||
t.Run("GetIdentityProvider not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.GetIdentityProviderRequest{
|
||||
Name: "identity-providers/999",
|
||||
}
|
||||
|
||||
_, err := ts.Service.GetIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("GetIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.GetIdentityProviderRequest{
|
||||
Name: "invalid-name",
|
||||
}
|
||||
|
||||
_, err := ts.Service.GetIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Original Provider",
|
||||
IdentifierFilter: "",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "original-client",
|
||||
AuthUrl: "https://original.com/auth",
|
||||
TokenUrl: "https://original.com/token",
|
||||
UserInfoUrl: "https://original.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update identity provider
|
||||
updateReq := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: created.Name,
|
||||
Title: "Updated Provider",
|
||||
IdentifierFilter: "test@example.com",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "updated-client",
|
||||
ClientSecret: "updated-secret",
|
||||
AuthUrl: "https://updated.com/auth",
|
||||
TokenUrl: "https://updated.com/token",
|
||||
UserInfoUrl: "https://updated.com/user",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "given_name",
|
||||
Email: "email",
|
||||
AvatarUrl: "picture",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "identifier_filter", "config"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateIdentityProvider(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, "Updated Provider", updated.Title)
|
||||
require.Equal(t, "test@example.com", updated.IdentifierFilter)
|
||||
require.Equal(t, "updated-client", updated.Config.GetOauth2Config().ClientId)
|
||||
})
|
||||
|
||||
t.Run("UpdateIdentityProvider missing update mask", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: "identity-providers/1",
|
||||
Title: "Updated Provider",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update_mask is required")
|
||||
})
|
||||
|
||||
t.Run("UpdateIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: "invalid-name",
|
||||
Title: "Updated Provider",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider to Delete",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client-to-delete",
|
||||
AuthUrl: "https://example.com/auth",
|
||||
TokenUrl: "https://example.com/token",
|
||||
UserInfoUrl: "https://example.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete identity provider
|
||||
deleteReq := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
getReq := &v1pb.GetIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetIdentityProvider(ctx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("DeleteIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: "invalid-name",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
|
||||
t.Run("DeleteIdentityProvider not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: "identity-providers/999",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
// Note: Delete might succeed even if item doesn't exist, depending on store implementation
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviderPermissions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Only host users can create identity providers", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("Authentication required", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "user not authenticated")
|
||||
})
|
||||
}
|
||||
54
server/router/api/v1/test/instance_admin_cache_test.go
Normal file
54
server/router/api/v1/test/instance_admin_cache_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestInstanceAdminRetrieval(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Instance becomes initialized after first admin user is created", func(t *testing.T) {
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Verify instance is not initialized initially
|
||||
profile1, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, profile1.Admin, "Instance should not be initialized before first admin user")
|
||||
|
||||
// Create the first admin user
|
||||
user, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
// Verify instance is now initialized
|
||||
profile2, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, profile2.Admin, "Instance should be initialized after first admin user is created")
|
||||
require.Equal(t, user.Username, profile2.Admin.Username)
|
||||
})
|
||||
|
||||
t.Run("Admin retrieval is cached by Store layer", func(t *testing.T) {
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create admin user
|
||||
user, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Multiple calls should return consistent admin user (from cache)
|
||||
for i := 0; i < 5; i++ {
|
||||
profile, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, profile.Admin)
|
||||
require.Equal(t, user.Username, profile.Admin.Username)
|
||||
}
|
||||
})
|
||||
}
|
||||
204
server/router/api/v1/test/instance_service_test.go
Normal file
204
server/router/api/v1/test/instance_service_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestGetInstanceProfile(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetInstanceProfile returns instance profile", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetInstanceProfile directly
|
||||
req := &v1pb.GetInstanceProfileRequest{}
|
||||
resp, err := ts.Service.GetInstanceProfile(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify the response contains expected data
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
|
||||
// Instance should not be initialized since no admin users are created
|
||||
require.Nil(t, resp.Admin)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceProfile with initialized instance", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user in the store
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, hostUser)
|
||||
|
||||
// Call GetInstanceProfile directly
|
||||
req := &v1pb.GetInstanceProfileRequest{}
|
||||
resp, err := ts.Service.GetInstanceProfile(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify the response contains expected data with initialized flag
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
|
||||
// Instance should be initialized since an admin user exists
|
||||
require.NotNil(t, resp.Admin)
|
||||
require.Equal(t, hostUser.Username, resp.Admin.Username)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetInstanceProfile_Concurrency(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Concurrent access to service", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user
|
||||
_, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make concurrent requests
|
||||
numGoroutines := 10
|
||||
results := make(chan *v1pb.InstanceProfile, numGoroutines)
|
||||
errors := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
req := &v1pb.GetInstanceProfileRequest{}
|
||||
resp, err := ts.Service.GetInstanceProfile(ctx, req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
results <- resp
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect all results
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
select {
|
||||
case err := <-errors:
|
||||
t.Fatalf("Goroutine returned error: %v", err)
|
||||
case resp := <-results:
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
require.NotNil(t, resp.Admin)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetInstanceSetting(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetInstanceSetting - general setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetInstanceSetting for general setting
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "instance/settings/GENERAL",
|
||||
}
|
||||
resp, err := ts.Service.GetInstanceSetting(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "instance/settings/GENERAL", resp.Name)
|
||||
|
||||
// The general setting should have a general_setting field
|
||||
generalSetting := resp.GetGeneralSetting()
|
||||
require.NotNil(t, generalSetting)
|
||||
|
||||
// General setting should have default values
|
||||
require.False(t, generalSetting.DisallowUserRegistration)
|
||||
require.False(t, generalSetting.DisallowPasswordAuth)
|
||||
require.Empty(t, generalSetting.AdditionalScript)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceSetting - storage setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user for storage setting access
|
||||
hostUser, err := ts.CreateHostUser(ctx, "testhost")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add user to context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Call GetInstanceSetting for storage setting
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "instance/settings/STORAGE",
|
||||
}
|
||||
resp, err := ts.Service.GetInstanceSetting(userCtx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "instance/settings/STORAGE", resp.Name)
|
||||
|
||||
// The storage setting should have a storage_setting field
|
||||
storageSetting := resp.GetStorageSetting()
|
||||
require.NotNil(t, storageSetting)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceSetting - memo related setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetInstanceSetting for memo related setting
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "instance/settings/MEMO_RELATED",
|
||||
}
|
||||
resp, err := ts.Service.GetInstanceSetting(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "instance/settings/MEMO_RELATED", resp.Name)
|
||||
|
||||
// The memo related setting should have a memo_related_setting field
|
||||
memoRelatedSetting := resp.GetMemoRelatedSetting()
|
||||
require.NotNil(t, memoRelatedSetting)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceSetting - invalid setting name", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetInstanceSetting with invalid name
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "invalid/setting/name",
|
||||
}
|
||||
_, err := ts.Service.GetInstanceSetting(ctx, req)
|
||||
|
||||
// Should return an error
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid instance setting name")
|
||||
})
|
||||
}
|
||||
166
server/router/api/v1/test/memo_attachment_service_test.go
Normal file
166
server/router/api/v1/test/memo_attachment_service_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestSetMemoAttachments(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMemoAttachments success by memo owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create attachment
|
||||
attachment, err := ts.Service.CreateAttachment(userCtx, &apiv1.CreateAttachmentRequest{
|
||||
Attachment: &apiv1.Attachment{
|
||||
Filename: "test.txt",
|
||||
Size: 5,
|
||||
Type: "text/plain",
|
||||
Content: []byte("hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, attachment)
|
||||
|
||||
// Set memo attachments - should succeed
|
||||
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*apiv1.Attachment{
|
||||
{Name: attachment.Name},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetMemoAttachments success by host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create memo by regular user
|
||||
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Host user can modify attachments - should succeed
|
||||
_, err = ts.Service.SetMemoAttachments(hostCtx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*apiv1.Attachment{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetMemoAttachments permission denied for non-owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user1
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
// Create user2
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
|
||||
// Create memo by user1
|
||||
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// User2 tries to modify attachments - should fail
|
||||
_, err = ts.Service.SetMemoAttachments(user2Ctx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*apiv1.Attachment{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("SetMemoAttachments unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Unauthenticated user tries to modify attachments - should fail
|
||||
_, err = ts.Service.SetMemoAttachments(ctx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*apiv1.Attachment{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not authenticated")
|
||||
})
|
||||
|
||||
t.Run("SetMemoAttachments memo not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Try to set attachments on non-existent memo - should fail
|
||||
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: "memos/nonexistent-uid-12345",
|
||||
Attachments: []*apiv1.Attachment{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
169
server/router/api/v1/test/memo_relation_service_test.go
Normal file
169
server/router/api/v1/test/memo_relation_service_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestSetMemoRelations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMemoRelations success by memo owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo1
|
||||
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo 1",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo1)
|
||||
|
||||
// Create memo2
|
||||
memo2, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo 2",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo2)
|
||||
|
||||
// Set memo relations - should succeed
|
||||
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: memo1.Name,
|
||||
Relations: []*apiv1.MemoRelation{
|
||||
{
|
||||
RelatedMemo: &apiv1.MemoRelation_Memo{
|
||||
Name: memo2.Name,
|
||||
},
|
||||
Type: apiv1.MemoRelation_REFERENCE,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetMemoRelations success by host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create memo by regular user
|
||||
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Host user can modify relations - should succeed
|
||||
_, err = ts.Service.SetMemoRelations(hostCtx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: memo.Name,
|
||||
Relations: []*apiv1.MemoRelation{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetMemoRelations permission denied for non-owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user1
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
// Create user2
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
|
||||
// Create memo by user1
|
||||
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// User2 tries to modify relations - should fail
|
||||
_, err = ts.Service.SetMemoRelations(user2Ctx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: memo.Name,
|
||||
Relations: []*apiv1.MemoRelation{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("SetMemoRelations unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Unauthenticated user tries to modify relations - should fail
|
||||
_, err = ts.Service.SetMemoRelations(ctx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: memo.Name,
|
||||
Relations: []*apiv1.MemoRelation{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not authenticated")
|
||||
})
|
||||
|
||||
t.Run("SetMemoRelations memo not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Try to set relations on non-existent memo - should fail
|
||||
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: "memos/nonexistent-uid-12345",
|
||||
Relations: []*apiv1.MemoRelation{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
369
server/router/api/v1/test/memo_service_test.go
Normal file
369
server/router/api/v1/test/memo_service_test.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestListMemos(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create userOne
|
||||
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, userOne)
|
||||
|
||||
// Create userOne context
|
||||
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
|
||||
|
||||
// Create userTwo
|
||||
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, userTwo)
|
||||
|
||||
// Create userTwo context
|
||||
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
|
||||
|
||||
// Create attachmentOne by userOne
|
||||
attachmentOne, err := ts.Service.CreateAttachment(userOneCtx, &apiv1.CreateAttachmentRequest{
|
||||
Attachment: &apiv1.Attachment{
|
||||
Name: "",
|
||||
Filename: "hello.txt",
|
||||
Size: 5,
|
||||
Type: "text/plain",
|
||||
Content: []byte{
|
||||
104, 101, 108, 108, 111,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, attachmentOne)
|
||||
|
||||
// Create attachmentTwo by userOne
|
||||
attachmentTwo, err := ts.Service.CreateAttachment(userOneCtx, &apiv1.CreateAttachmentRequest{
|
||||
Attachment: &apiv1.Attachment{
|
||||
Name: "",
|
||||
Filename: "world.txt",
|
||||
Size: 5,
|
||||
Type: "text/plain",
|
||||
Content: []byte{
|
||||
119, 111, 114, 108, 100,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, attachmentTwo)
|
||||
|
||||
// Create memoOne with two attachments by userOne
|
||||
memoOne, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Hellooo, any words after this sentence won't be in the snippet. This is the next sentence. And I also have two attachments.",
|
||||
Visibility: apiv1.Visibility_PROTECTED,
|
||||
Attachments: []*apiv1.Attachment{
|
||||
&apiv1.Attachment{
|
||||
Name: attachmentOne.Name,
|
||||
},
|
||||
&apiv1.Attachment{
|
||||
Name: attachmentTwo.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoOne)
|
||||
|
||||
// Create memoTwo by userTwo referencing memoOne
|
||||
memoTwo, err := ts.Service.CreateMemo(userTwoCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This is a memo reminding you to check the attachment attached to memoOne. I have referenced the memo below.⬇️",
|
||||
Visibility: apiv1.Visibility_PROTECTED,
|
||||
Relations: []*apiv1.MemoRelation{
|
||||
&apiv1.MemoRelation{
|
||||
RelatedMemo: &apiv1.MemoRelation_Memo{
|
||||
Name: memoOne.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoTwo)
|
||||
|
||||
// Create memoThree by userOne
|
||||
memoThree, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This is a very popular memo. I have 2 reactions!",
|
||||
Visibility: apiv1.Visibility_PROTECTED,
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoThree)
|
||||
|
||||
// Create reaction from userOne on memoThree
|
||||
reactionOne, err := ts.Service.UpsertMemoReaction(userOneCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memoThree.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memoThree.Name,
|
||||
ReactionType: "❤️",
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reactionOne)
|
||||
|
||||
// Create reaction from userTwo on memoThree
|
||||
reactionTwo, err := ts.Service.UpsertMemoReaction(userTwoCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memoThree.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memoThree.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reactionTwo)
|
||||
|
||||
memos, err := ts.Service.ListMemos(userOneCtx, &apiv1.ListMemosRequest{PageSize: 10})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memos)
|
||||
require.Equal(t, 3, len(memos.Memos))
|
||||
|
||||
// ///////////////
|
||||
// VERIFY MEMO ONE
|
||||
// ///////////////
|
||||
memoOneResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoOne.GetName() })
|
||||
require.NotEqual(t, memoOneResIdx, -1)
|
||||
|
||||
memoOneRes := memos.Memos[memoOneResIdx]
|
||||
require.NotNil(t, memoOneRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoOneRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoOneRes.GetVisibility())
|
||||
require.Equal(t, memoOne.Content, memoOneRes.GetContent())
|
||||
require.Equal(t, memoOne.Content[:64]+"...", memoOneRes.GetSnippet(), "memoOne's content is snipped past the 64 char limit")
|
||||
require.Len(t, memoOneRes.Attachments, 2)
|
||||
require.Len(t, memoOneRes.Relations, 1)
|
||||
require.Empty(t, memoOneRes.Reactions)
|
||||
|
||||
// verify memoOne's attachments
|
||||
// attachment one
|
||||
attachmentOneResIdx := slices.IndexFunc(memoOneRes.Attachments, func(a *apiv1.Attachment) bool { return a.GetName() == attachmentOne.GetName() })
|
||||
require.NotEqual(t, attachmentOneResIdx, -1)
|
||||
|
||||
attachmentOneRes := memoOneRes.Attachments[attachmentOneResIdx]
|
||||
require.NotNil(t, attachmentOneRes)
|
||||
|
||||
require.Equal(t, attachmentOne.GetName(), attachmentOneRes.GetName())
|
||||
require.Equal(t, attachmentOne.GetContent(), attachmentOneRes.GetContent())
|
||||
|
||||
// attachment two
|
||||
attachmentTwoResIdx := slices.IndexFunc(memoOneRes.Attachments, func(a *apiv1.Attachment) bool { return a.GetName() == attachmentTwo.GetName() })
|
||||
require.NotEqual(t, attachmentTwoResIdx, -1)
|
||||
|
||||
attachmentTwoRes := memoOneRes.Attachments[attachmentTwoResIdx]
|
||||
require.NotNil(t, attachmentTwoRes)
|
||||
require.Equal(t, attachmentTwo.GetName(), attachmentTwoRes.GetName())
|
||||
|
||||
require.Equal(t, attachmentTwo.GetName(), attachmentTwoRes.GetName())
|
||||
require.Equal(t, attachmentTwo.GetContent(), attachmentTwoRes.GetContent())
|
||||
|
||||
// verify memoOne's relations
|
||||
require.Len(t, memoOneRes.Relations, 1)
|
||||
memoOneExpectedRelation := &apiv1.MemoRelation{
|
||||
Memo: &apiv1.MemoRelation_Memo{Name: memoTwo.GetName()},
|
||||
RelatedMemo: &apiv1.MemoRelation_Memo{Name: memoOne.GetName()},
|
||||
}
|
||||
require.Equal(t, memoOneExpectedRelation.Memo.GetName(), memoOneRes.Relations[0].Memo.GetName())
|
||||
require.Equal(t, memoOneExpectedRelation.RelatedMemo.GetName(), memoOneRes.Relations[0].RelatedMemo.GetName())
|
||||
|
||||
// ///////////////
|
||||
// VERIFY MEMO TWO
|
||||
// ///////////////
|
||||
memoTwoResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoTwo.GetName() })
|
||||
require.NotEqual(t, memoTwoResIdx, -1)
|
||||
|
||||
memoTwoRes := memos.Memos[memoTwoResIdx]
|
||||
require.NotNil(t, memoTwoRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userTwo.ID), memoTwoRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoTwoRes.GetVisibility())
|
||||
require.Equal(t, memoTwo.Content, memoTwoRes.GetContent())
|
||||
require.Empty(t, memoTwoRes.Attachments)
|
||||
require.Len(t, memoTwoRes.Relations, 1)
|
||||
require.Empty(t, memoTwoRes.Reactions)
|
||||
|
||||
// verify memoTwo's relations
|
||||
require.Len(t, memoTwoRes.Relations, 1)
|
||||
memoTwoExpectedRelation := &apiv1.MemoRelation{
|
||||
Memo: &apiv1.MemoRelation_Memo{Name: memoTwo.GetName()},
|
||||
RelatedMemo: &apiv1.MemoRelation_Memo{Name: memoOne.GetName()},
|
||||
}
|
||||
require.Equal(t, memoTwoExpectedRelation.Memo.GetName(), memoTwoRes.Relations[0].Memo.GetName())
|
||||
require.Equal(t, memoTwoExpectedRelation.RelatedMemo.GetName(), memoTwoRes.Relations[0].RelatedMemo.GetName())
|
||||
|
||||
// ///////////////
|
||||
// VERIFY MEMO THREE
|
||||
// ///////////////
|
||||
memoThreeResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoThree.GetName() })
|
||||
require.NotEqual(t, memoThreeResIdx, -1)
|
||||
|
||||
memoThreeRes := memos.Memos[memoThreeResIdx]
|
||||
require.NotNil(t, memoThreeRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoThreeRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoThreeRes.GetVisibility())
|
||||
require.Equal(t, memoThree.Content, memoThreeRes.GetContent())
|
||||
require.Empty(t, memoThreeRes.Attachments)
|
||||
require.Empty(t, memoThreeRes.Relations)
|
||||
require.Len(t, memoThreeRes.Reactions, 2)
|
||||
|
||||
// verify memoThree's reactions
|
||||
require.Len(t, memoThreeRes.Reactions, 2)
|
||||
// userOne's reaction
|
||||
userOneReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userOne.ID) })
|
||||
require.NotEqual(t, userOneReactionIdx, -1)
|
||||
|
||||
userOneReaction := memoThreeRes.Reactions[userOneReactionIdx]
|
||||
require.NotNil(t, userOneReaction)
|
||||
require.Equal(t, "❤️", userOneReaction.ReactionType)
|
||||
|
||||
// userTwo's reaction
|
||||
userTwoReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userTwo.ID) })
|
||||
require.NotEqual(t, userTwoReactionIdx, -1)
|
||||
|
||||
userTwoReaction := memoThreeRes.Reactions[userTwoReactionIdx]
|
||||
require.NotNil(t, userTwoReaction)
|
||||
require.Equal(t, "👍", userTwoReaction.ReactionType)
|
||||
}
|
||||
|
||||
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
|
||||
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483
|
||||
func TestCreateMemoWithCustomTimestamps(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "test-user-timestamps")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Define custom timestamps (January 1, 2020)
|
||||
customCreateTime := time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
customUpdateTime := time.Date(2020, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
customDisplayTime := time.Date(2020, 1, 3, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Test 1: Create a memo with custom create_time
|
||||
memoWithCreateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom creation time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCreateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithCreateTime)
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithCreateTime.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
|
||||
|
||||
// Test 2: Create a memo with custom update_time
|
||||
memoWithUpdateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom update time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
UpdateTime: timestamppb.New(customUpdateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithUpdateTime)
|
||||
require.Equal(t, customUpdateTime.Unix(), memoWithUpdateTime.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
|
||||
|
||||
// Test 3: Create a memo with custom display_time
|
||||
// Note: display_time is computed from either created_ts or updated_ts based on instance setting
|
||||
// Since DisplayWithUpdateTime defaults to false, display_time maps to created_ts
|
||||
memoWithDisplayTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom display time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
DisplayTime: timestamppb.New(customDisplayTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithDisplayTime)
|
||||
// Since DisplayWithUpdateTime is false by default, display_time sets created_ts
|
||||
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.DisplayTime.AsTime().Unix(), "display_time should match the custom timestamp")
|
||||
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.CreateTime.AsTime().Unix(), "create_time should also match since display_time maps to created_ts")
|
||||
|
||||
// Test 4: Create a memo with all custom timestamps
|
||||
// When both display_time and create_time are provided, create_time takes precedence
|
||||
memoWithAllTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has all custom timestamps",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCreateTime),
|
||||
UpdateTime: timestamppb.New(customUpdateTime),
|
||||
DisplayTime: timestamppb.New(customDisplayTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithAllTimestamps)
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
|
||||
require.Equal(t, customUpdateTime.Unix(), memoWithAllTimestamps.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
|
||||
// display_time is computed from created_ts when DisplayWithUpdateTime is false
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.DisplayTime.AsTime().Unix(), "display_time should be derived from create_time")
|
||||
|
||||
// Test 5: Create a comment (memo relation) with custom timestamps
|
||||
parentMemo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This is the parent memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parentMemo)
|
||||
|
||||
customCommentCreateTime := time.Date(2021, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: parentMemo.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "This is a comment with custom create time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCommentCreateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, comment)
|
||||
require.Equal(t, customCommentCreateTime.Unix(), comment.CreateTime.AsTime().Unix(), "comment create_time should match the custom timestamp")
|
||||
|
||||
// Test 6: Verify that memos without custom timestamps still get auto-generated ones
|
||||
memoWithoutTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has auto-generated timestamps",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithoutTimestamps)
|
||||
require.NotNil(t, memoWithoutTimestamps.CreateTime, "create_time should be auto-generated")
|
||||
require.NotNil(t, memoWithoutTimestamps.UpdateTime, "update_time should be auto-generated")
|
||||
require.True(t, time.Now().Unix()-memoWithoutTimestamps.CreateTime.AsTime().Unix() < 5, "create_time should be recent (within 5 seconds)")
|
||||
}
|
||||
194
server/router/api/v1/test/reaction_service_test.go
Normal file
194
server/router/api/v1/test/reaction_service_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestDeleteMemoReaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteMemoReaction success by reaction owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create reaction
|
||||
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
|
||||
// Delete reaction - should succeed
|
||||
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("DeleteMemoReaction success by host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create memo by regular user
|
||||
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create reaction by regular user
|
||||
reaction, err := ts.Service.UpsertMemoReaction(regularUserCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
|
||||
// Host user can delete reaction - should succeed
|
||||
_, err = ts.Service.DeleteMemoReaction(hostCtx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("DeleteMemoReaction permission denied for non-owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user1
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
// Create user2
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
|
||||
// Create memo by user1
|
||||
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create reaction by user1
|
||||
reaction, err := ts.Service.UpsertMemoReaction(user1Ctx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
|
||||
// User2 tries to delete reaction - should fail with permission denied
|
||||
_, err = ts.Service.DeleteMemoReaction(user2Ctx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("DeleteMemoReaction unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create reaction
|
||||
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
|
||||
// Unauthenticated user tries to delete reaction - should fail
|
||||
_, err = ts.Service.DeleteMemoReaction(ctx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not authenticated")
|
||||
})
|
||||
|
||||
t.Run("DeleteMemoReaction not found returns permission denied", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Try to delete non-existent reaction - should fail with permission denied
|
||||
// (not "not found" to avoid information disclosure)
|
||||
// Use new nested resource format: memos/{memo}/reactions/{reaction}
|
||||
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: "memos/nonexistent/reactions/99999",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
require.NotContains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
819
server/router/api/v1/test/shortcut_service_test.go
Normal file
819
server/router/api/v1/test/shortcut_service_test.go
Normal file
@@ -0,0 +1,819 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestListShortcuts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ListShortcuts success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// List shortcuts (should be empty initially)
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
resp, err := ts.Service.ListShortcuts(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.Shortcuts)
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to list user2's shortcuts
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListShortcuts(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts invalid parent format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: "invalid-parent-format",
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListShortcuts(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: "users/1",
|
||||
}
|
||||
|
||||
_, err := ts.Service.ListShortcuts(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// First create a shortcut
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now get the shortcut
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
resp, err := ts.Service.GetShortcut(userCtx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, created.Name, resp.Name)
|
||||
require.Equal(t, "Test Shortcut", resp.Title)
|
||||
require.Equal(t, "tag in [\"test\"]", resp.Filter)
|
||||
})
|
||||
|
||||
t.Run("GetShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(user2Ctx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("GetShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: "invalid-shortcut-name",
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("GetShortcut not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "My Shortcut",
|
||||
Filter: "tag in [\"important\"]",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ts.Service.CreateShortcut(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "My Shortcut", resp.Title)
|
||||
require.Equal(t, "tag in [\"important\"]", resp.Filter)
|
||||
require.Contains(t, resp.Name, fmt.Sprintf("users/%d/shortcuts/", user.ID))
|
||||
|
||||
// Verify the shortcut was created by listing
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listResp.Shortcuts, 1)
|
||||
require.Equal(t, "My Shortcut", listResp.Shortcuts[0].Title)
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to create shortcut for user2
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Forbidden Shortcut",
|
||||
Filter: "tag in [\"forbidden\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut invalid parent format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: "invalid-parent",
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut invalid filter", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Invalid Filter Shortcut",
|
||||
Filter: "invalid||filter))syntax",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid filter")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut missing title", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "title is required")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Original Title",
|
||||
Filter: "tag in [\"original\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the shortcut
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Title: "Updated Title",
|
||||
Filter: "tag in [\"updated\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, "Updated Title", updated.Title)
|
||||
require.Equal(t, "tag in [\"updated\"]", updated.Filter)
|
||||
require.Equal(t, created.Name, updated.Name)
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to update shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Title: "Hacked Title",
|
||||
Filter: "tag in [\"hacked\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(user2Ctx, updateReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut missing update mask", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user and context for authentication
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: fmt.Sprintf("users/%d/shortcuts/test", user.ID),
|
||||
Title: "Updated Title",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update mask is required")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: "invalid-shortcut-name",
|
||||
Title: "Updated Title",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.UpdateShortcut(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut invalid filter", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to update with invalid filter
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Filter: "invalid||filter))syntax",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"filter"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid filter")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Shortcut to Delete",
|
||||
Filter: "tag in [\"delete\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the shortcut
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion by listing shortcuts
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, listResp.Shortcuts)
|
||||
|
||||
// Also verify by trying to get the deleted shortcut
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to delete shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(user2Ctx, deleteReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.DeleteShortcutRequest{
|
||||
Name: "invalid-shortcut-name",
|
||||
}
|
||||
|
||||
_, err := ts.Service.DeleteShortcut(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestShortcutFiltering(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateShortcut with valid filters", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test various valid filter formats
|
||||
validFilters := []string{
|
||||
"tag in [\"work\"]",
|
||||
"content.contains(\"meeting\")",
|
||||
"tag in [\"work\"] && content.contains(\"meeting\")",
|
||||
"tag in [\"work\"] || tag in [\"personal\"]",
|
||||
"creator_id == 1",
|
||||
"visibility == \"PUBLIC\"",
|
||||
"has_task_list == true",
|
||||
"has_task_list == false",
|
||||
}
|
||||
|
||||
for i, filter := range validFilters {
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Valid Filter " + string(rune(i)),
|
||||
Filter: filter,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.NoError(t, err, "Filter should be valid: %s", filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut with invalid filters", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test various invalid filter formats
|
||||
invalidFilters := []string{
|
||||
"tag in ", // incomplete expression
|
||||
"invalid_field @in [\"value\"]", // unknown field
|
||||
"tag in [\"work\"] &&", // incomplete expression
|
||||
"tag in [\"work\"] || || tag in [\"test\"]", // double operator
|
||||
"((tag in [\"work\"]", // unmatched parentheses
|
||||
"tag in [\"work\"] && )", // mismatched parentheses
|
||||
"tag == \"work\"", // wrong operator (== not supported for tags)
|
||||
"tag in work", // missing brackets
|
||||
}
|
||||
|
||||
for _, filter := range invalidFilters {
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Invalid Filter Test",
|
||||
Filter: filter,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err, "Filter should be invalid: %s", filter)
|
||||
require.Contains(t, err.Error(), "invalid filter", "Error should mention invalid filter for: %s", filter)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShortcutCRUDComplete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Complete CRUD lifecycle", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// 1. Create multiple shortcuts
|
||||
shortcut1Req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Work Notes",
|
||||
Filter: "tag in [\"work\"]",
|
||||
},
|
||||
}
|
||||
|
||||
shortcut2Req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Personal Notes",
|
||||
Filter: "tag in [\"personal\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created1, err := ts.Service.CreateShortcut(userCtx, shortcut1Req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work Notes", created1.Title)
|
||||
|
||||
created2, err := ts.Service.CreateShortcut(userCtx, shortcut2Req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Personal Notes", created2.Title)
|
||||
|
||||
// 2. List shortcuts and verify both exist
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listResp.Shortcuts, 2)
|
||||
|
||||
// 3. Get individual shortcuts
|
||||
getReq1 := &v1pb.GetShortcutRequest{Name: created1.Name}
|
||||
getResp1, err := ts.Service.GetShortcut(userCtx, getReq1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, created1.Name, getResp1.Name)
|
||||
require.Equal(t, "Work Notes", getResp1.Title)
|
||||
|
||||
getReq2 := &v1pb.GetShortcutRequest{Name: created2.Name}
|
||||
getResp2, err := ts.Service.GetShortcut(userCtx, getReq2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, created2.Name, getResp2.Name)
|
||||
require.Equal(t, "Personal Notes", getResp2.Title)
|
||||
|
||||
// 4. Update one shortcut
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created1.Name,
|
||||
Title: "Work & Meeting Notes",
|
||||
Filter: "tag in [\"work\"] || tag in [\"meeting\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work & Meeting Notes", updated.Title)
|
||||
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", updated.Filter)
|
||||
|
||||
// 5. Verify update by getting it again
|
||||
getUpdatedReq := &v1pb.GetShortcutRequest{Name: created1.Name}
|
||||
getUpdatedResp, err := ts.Service.GetShortcut(userCtx, getUpdatedReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work & Meeting Notes", getUpdatedResp.Title)
|
||||
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", getUpdatedResp.Filter)
|
||||
|
||||
// 6. Delete one shortcut
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created2.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 7. Verify deletion by listing (should only have 1 left)
|
||||
finalListResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, finalListResp.Shortcuts, 1)
|
||||
require.Equal(t, "Work & Meeting Notes", finalListResp.Shortcuts[0].Title)
|
||||
|
||||
// 8. Verify deleted shortcut can't be accessed
|
||||
getDeletedReq := &v1pb.GetShortcutRequest{Name: created2.Name}
|
||||
_, err = ts.Service.GetShortcut(userCtx, getDeletedReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
86
server/router/api/v1/test/test_helper.go
Normal file
86
server/router/api/v1/test/test_helper.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/plugin/markdown"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
teststore "github.com/usememos/memos/store/test"
|
||||
)
|
||||
|
||||
// TestService holds the test service setup for API v1 services.
|
||||
type TestService struct {
|
||||
Service *apiv1.APIV1Service
|
||||
Store *store.Store
|
||||
Profile *profile.Profile
|
||||
Secret string
|
||||
}
|
||||
|
||||
// NewTestService creates a new test service with SQLite database.
|
||||
func NewTestService(t *testing.T) *TestService {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a test store with SQLite
|
||||
testStore := teststore.NewTestingStore(ctx, t)
|
||||
|
||||
// Create a test profile
|
||||
testProfile := &profile.Profile{
|
||||
Demo: true,
|
||||
Version: "test-1.0.0",
|
||||
InstanceURL: "http://localhost:8080",
|
||||
Driver: "sqlite",
|
||||
DSN: ":memory:",
|
||||
}
|
||||
|
||||
// Create APIV1Service with nil grpcServer since we're testing direct calls
|
||||
secret := "test-secret"
|
||||
markdownService := markdown.NewService(
|
||||
markdown.WithTagExtension(),
|
||||
)
|
||||
service := &apiv1.APIV1Service{
|
||||
Secret: secret,
|
||||
Profile: testProfile,
|
||||
Store: testStore,
|
||||
MarkdownService: markdownService,
|
||||
}
|
||||
|
||||
return &TestService{
|
||||
Service: service,
|
||||
Store: testStore,
|
||||
Profile: testProfile,
|
||||
Secret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup closes resources after test.
|
||||
func (ts *TestService) Cleanup() {
|
||||
ts.Store.Close()
|
||||
}
|
||||
|
||||
// CreateHostUser creates an admin user for testing.
|
||||
func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) {
|
||||
return ts.Store.CreateUser(ctx, &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleAdmin,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateRegularUser creates a regular user for testing.
|
||||
func (ts *TestService) CreateRegularUser(ctx context.Context, username string) (*store.User, error) {
|
||||
return ts.Store.CreateUser(ctx, &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleUser,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateUserContext creates a context with the given user's ID for authentication.
|
||||
func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context {
|
||||
// Use the context key from the auth package
|
||||
return context.WithValue(ctx, auth.UserIDContextKey, userID)
|
||||
}
|
||||
173
server/router/api/v1/test/user_service_registration_test.go
Normal file
173
server/router/api/v1/test/user_service_registration_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
func TestCreateUserRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateUser success when registration enabled", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// User registration is enabled by default, no need to set it explicitly
|
||||
|
||||
// Create user without authentication - should succeed
|
||||
_, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateUser blocked when registration disabled", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user first so we're not in first-user setup mode
|
||||
_, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Disable user registration
|
||||
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
DisallowUserRegistration: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to create user without authentication - should fail
|
||||
_, err = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not allowed")
|
||||
})
|
||||
|
||||
t.Run("CreateUser succeeds for superuser even when registration disabled", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Disable user registration
|
||||
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
DisallowUserRegistration: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Host user can create users even when registration is disabled - should succeed
|
||||
_, err = ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateUser regular user cannot create users when registration disabled", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
|
||||
require.NoError(t, err)
|
||||
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Disable user registration
|
||||
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
DisallowUserRegistration: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Regular user tries to create user when registration is disabled - should fail
|
||||
_, err = ts.Service.CreateUser(regularUserCtx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not allowed")
|
||||
})
|
||||
|
||||
t.Run("CreateUser host can assign roles", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Host user can create user with specific role - should succeed
|
||||
createdUser, err := ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newadmin",
|
||||
Email: "newadmin@example.com",
|
||||
Password: "password123",
|
||||
Role: apiv1.User_ADMIN,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdUser)
|
||||
require.Equal(t, apiv1.User_ADMIN, createdUser.Role)
|
||||
})
|
||||
|
||||
t.Run("CreateUser unauthenticated user can only create regular user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user first so we're not in first-user setup mode
|
||||
_, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// User registration is enabled by default
|
||||
|
||||
// Unauthenticated user tries to create admin user - role should be ignored
|
||||
createdUser, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "wannabeadmin",
|
||||
Email: "wannabeadmin@example.com",
|
||||
Password: "password123",
|
||||
Role: apiv1.User_ADMIN, // This should be ignored
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdUser)
|
||||
require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role")
|
||||
})
|
||||
}
|
||||
105
server/router/api/v1/test/user_service_stats_test.go
Normal file
105
server/router/api/v1/test/user_service_stats_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestGetUserStats_TagCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test host user
|
||||
user, err := ts.CreateHostUser(ctx, "test_user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user context for authentication
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a memo with a single tag
|
||||
memo, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "This is a test memo with #test tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"test"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Test GetUserStats
|
||||
userName := fmt.Sprintf("users/%d", user.ID)
|
||||
response, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
// Check that the tag count is exactly 1, not 2
|
||||
require.Contains(t, response.TagCount, "test")
|
||||
require.Equal(t, int32(1), response.TagCount["test"], "Tag count should be 1 for a single occurrence")
|
||||
|
||||
// Create another memo with the same tag
|
||||
memo2, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "Another memo with #test tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"test"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo2)
|
||||
|
||||
// Test GetUserStats again
|
||||
response2, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response2)
|
||||
|
||||
// Check that the tag count is exactly 2, not 3
|
||||
require.Contains(t, response2.TagCount, "test")
|
||||
require.Equal(t, int32(2), response2.TagCount["test"], "Tag count should be 2 for two occurrences")
|
||||
|
||||
// Test with a new unique tag
|
||||
memo3, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-3",
|
||||
CreatorID: user.ID,
|
||||
Content: "Memo with #unique tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"unique"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo3)
|
||||
|
||||
// Test GetUserStats for the new tag
|
||||
response3, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response3)
|
||||
|
||||
// Check that the unique tag count is exactly 1
|
||||
require.Contains(t, response3.TagCount, "unique")
|
||||
require.Equal(t, int32(1), response3.TagCount["unique"], "New tag count should be 1 for first occurrence")
|
||||
|
||||
// The original test tag should still be 2
|
||||
require.Contains(t, response3.TagCount, "test")
|
||||
require.Equal(t, int32(2), response3.TagCount["test"], "Original tag count should remain 2")
|
||||
}
|
||||
Reference in New Issue
Block a user