first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
42
server/router/api/v1/acl_config.go
Normal file
42
server/router/api/v1/acl_config.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package v1
|
||||
|
||||
// PublicMethods defines API endpoints that don't require authentication.
|
||||
// All other endpoints require a valid session or access token.
|
||||
//
|
||||
// This is the SINGLE SOURCE OF TRUTH for public endpoints.
|
||||
// Both Connect interceptor and gRPC-Gateway interceptor use this map.
|
||||
//
|
||||
// Format: Full gRPC procedure path as returned by req.Spec().Procedure (Connect)
|
||||
// or info.FullMethod (gRPC interceptor).
|
||||
var PublicMethods = map[string]struct{}{
|
||||
// Auth Service - login/token endpoints must be accessible without auth
|
||||
"/memos.api.v1.AuthService/SignIn": {},
|
||||
"/memos.api.v1.AuthService/RefreshToken": {}, // Token refresh uses cookie, must be accessible when access token expired
|
||||
|
||||
// Instance Service - needed before login to show instance info
|
||||
"/memos.api.v1.InstanceService/GetInstanceProfile": {},
|
||||
"/memos.api.v1.InstanceService/GetInstanceSetting": {},
|
||||
|
||||
// User Service - public user profiles and stats
|
||||
"/memos.api.v1.UserService/CreateUser": {}, // Allow first user registration
|
||||
"/memos.api.v1.UserService/GetUser": {},
|
||||
"/memos.api.v1.UserService/GetUserAvatar": {},
|
||||
"/memos.api.v1.UserService/GetUserStats": {},
|
||||
"/memos.api.v1.UserService/ListAllUserStats": {},
|
||||
"/memos.api.v1.UserService/SearchUsers": {},
|
||||
|
||||
// Identity Provider Service - SSO buttons on login page
|
||||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": {},
|
||||
|
||||
// Memo Service - public memos (visibility filtering done in service layer)
|
||||
"/memos.api.v1.MemoService/GetMemo": {},
|
||||
"/memos.api.v1.MemoService/ListMemos": {},
|
||||
"/memos.api.v1.MemoService/ListMemoComments": {},
|
||||
}
|
||||
|
||||
// IsPublicMethod checks if a procedure path is public (no authentication required).
|
||||
// Returns true for public methods, false for protected methods.
|
||||
func IsPublicMethod(procedure string) bool {
|
||||
_, ok := PublicMethods[procedure]
|
||||
return ok
|
||||
}
|
||||
88
server/router/api/v1/acl_config_test.go
Normal file
88
server/router/api/v1/acl_config_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestPublicMethodsArePublic verifies that methods in PublicMethods are recognized as public.
|
||||
func TestPublicMethodsArePublic(t *testing.T) {
|
||||
publicMethods := []string{
|
||||
// Auth Service
|
||||
"/memos.api.v1.AuthService/SignIn",
|
||||
"/memos.api.v1.AuthService/RefreshToken",
|
||||
// Instance Service
|
||||
"/memos.api.v1.InstanceService/GetInstanceProfile",
|
||||
"/memos.api.v1.InstanceService/GetInstanceSetting",
|
||||
// User Service
|
||||
"/memos.api.v1.UserService/CreateUser",
|
||||
"/memos.api.v1.UserService/GetUser",
|
||||
"/memos.api.v1.UserService/GetUserAvatar",
|
||||
"/memos.api.v1.UserService/GetUserStats",
|
||||
"/memos.api.v1.UserService/ListAllUserStats",
|
||||
"/memos.api.v1.UserService/SearchUsers",
|
||||
// Identity Provider Service
|
||||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders",
|
||||
// Memo Service
|
||||
"/memos.api.v1.MemoService/GetMemo",
|
||||
"/memos.api.v1.MemoService/ListMemos",
|
||||
}
|
||||
|
||||
for _, method := range publicMethods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
assert.True(t, IsPublicMethod(method), "Expected %s to be public", method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProtectedMethodsRequireAuth verifies that non-public methods are recognized as protected.
|
||||
func TestProtectedMethodsRequireAuth(t *testing.T) {
|
||||
protectedMethods := []string{
|
||||
// Auth Service - logout and get current user require auth
|
||||
"/memos.api.v1.AuthService/SignOut",
|
||||
"/memos.api.v1.AuthService/GetCurrentUser",
|
||||
// Instance Service - admin operations
|
||||
"/memos.api.v1.InstanceService/UpdateInstanceSetting",
|
||||
// User Service - modification operations
|
||||
"/memos.api.v1.UserService/ListUsers",
|
||||
"/memos.api.v1.UserService/UpdateUser",
|
||||
"/memos.api.v1.UserService/DeleteUser",
|
||||
// Memo Service - write operations
|
||||
"/memos.api.v1.MemoService/CreateMemo",
|
||||
"/memos.api.v1.MemoService/UpdateMemo",
|
||||
"/memos.api.v1.MemoService/DeleteMemo",
|
||||
// Attachment Service - write operations
|
||||
"/memos.api.v1.AttachmentService/CreateAttachment",
|
||||
"/memos.api.v1.AttachmentService/DeleteAttachment",
|
||||
// Shortcut Service
|
||||
"/memos.api.v1.ShortcutService/CreateShortcut",
|
||||
"/memos.api.v1.ShortcutService/ListShortcuts",
|
||||
"/memos.api.v1.ShortcutService/UpdateShortcut",
|
||||
"/memos.api.v1.ShortcutService/DeleteShortcut",
|
||||
// Activity Service
|
||||
"/memos.api.v1.ActivityService/GetActivity",
|
||||
}
|
||||
|
||||
for _, method := range protectedMethods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
assert.False(t, IsPublicMethod(method), "Expected %s to require auth", method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnknownMethodsRequireAuth verifies that unknown methods default to requiring auth.
|
||||
func TestUnknownMethodsRequireAuth(t *testing.T) {
|
||||
unknownMethods := []string{
|
||||
"/unknown.Service/Method",
|
||||
"/memos.api.v1.UnknownService/Method",
|
||||
"",
|
||||
"invalid",
|
||||
}
|
||||
|
||||
for _, method := range unknownMethods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
assert.False(t, IsPublicMethod(method), "Unknown method %q should require auth", method)
|
||||
})
|
||||
}
|
||||
}
|
||||
155
server/router/api/v1/activity_service.go
Normal file
155
server/router/api/v1/activity_service.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListActivities(ctx context.Context, request *v1pb.ListActivitiesRequest) (*v1pb.ListActivitiesResponse, error) {
|
||||
// Set default page size if not specified
|
||||
pageSize := request.PageSize
|
||||
if pageSize <= 0 || pageSize > 1000 {
|
||||
pageSize = 100
|
||||
}
|
||||
|
||||
// TODO: Implement pagination with page_token and use pageSize for limiting
|
||||
// For now, we'll fetch all activities and the pageSize will be used in future pagination implementation
|
||||
_ = pageSize // Acknowledge pageSize variable to avoid linter warning
|
||||
|
||||
activities, err := s.Store.ListActivities(ctx, &store.FindActivity{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list activities: %v", err)
|
||||
}
|
||||
|
||||
var activityMessages []*v1pb.Activity
|
||||
for _, activity := range activities {
|
||||
activityMessage, err := s.convertActivityFromStore(ctx, activity)
|
||||
if err != nil {
|
||||
// Skip activities that reference deleted memos instead of failing the entire list
|
||||
continue
|
||||
}
|
||||
if activityMessage != nil {
|
||||
activityMessages = append(activityMessages, activityMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return &v1pb.ListActivitiesResponse{
|
||||
Activities: activityMessages,
|
||||
// TODO: Implement next_page_token for pagination
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetActivity(ctx context.Context, request *v1pb.GetActivityRequest) (*v1pb.Activity, error) {
|
||||
activityID, err := ExtractActivityIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid activity name: %v", err)
|
||||
}
|
||||
activity, err := s.Store.GetActivity(ctx, &store.FindActivity{
|
||||
ID: &activityID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get activity: %v", err)
|
||||
}
|
||||
|
||||
activityMessage, err := s.convertActivityFromStore(ctx, activity)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
|
||||
}
|
||||
if activityMessage == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "activity references deleted content")
|
||||
}
|
||||
return activityMessage, nil
|
||||
}
|
||||
|
||||
// convertActivityFromStore converts a storage-layer activity to an API activity.
|
||||
// This handles the mapping between internal activity representation and the public API,
|
||||
// including proper type and level conversions.
|
||||
// Returns nil if the activity references deleted content (to allow graceful skipping).
|
||||
func (s *APIV1Service) convertActivityFromStore(ctx context.Context, activity *store.Activity) (*v1pb.Activity, error) {
|
||||
payload, err := s.convertActivityPayloadFromStore(ctx, activity.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Skip activities that reference deleted memos
|
||||
if payload == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Convert store activity type to proto enum
|
||||
var activityType v1pb.Activity_Type
|
||||
switch activity.Type {
|
||||
case store.ActivityTypeMemoComment:
|
||||
activityType = v1pb.Activity_MEMO_COMMENT
|
||||
default:
|
||||
activityType = v1pb.Activity_TYPE_UNSPECIFIED
|
||||
}
|
||||
|
||||
// Convert store activity level to proto enum
|
||||
var activityLevel v1pb.Activity_Level
|
||||
switch activity.Level {
|
||||
case store.ActivityLevelInfo:
|
||||
activityLevel = v1pb.Activity_INFO
|
||||
default:
|
||||
activityLevel = v1pb.Activity_LEVEL_UNSPECIFIED
|
||||
}
|
||||
|
||||
return &v1pb.Activity{
|
||||
Name: fmt.Sprintf("%s%d", ActivityNamePrefix, activity.ID),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, activity.CreatorID),
|
||||
Type: activityType,
|
||||
Level: activityLevel,
|
||||
CreateTime: timestamppb.New(time.Unix(activity.CreatedTs, 0)),
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// convertActivityPayloadFromStore converts a storage-layer activity payload to an API payload.
|
||||
// This resolves references (e.g., memo IDs) to resource names for the API.
|
||||
// Returns nil if the activity references deleted content (to allow graceful skipping).
|
||||
func (s *APIV1Service) convertActivityPayloadFromStore(ctx context.Context, payload *storepb.ActivityPayload) (*v1pb.ActivityPayload, error) {
|
||||
v2Payload := &v1pb.ActivityPayload{}
|
||||
if payload.MemoComment != nil {
|
||||
// Fetch the comment memo
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &payload.MemoComment.MemoId,
|
||||
ExcludeContent: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
// If the comment memo was deleted, skip this activity gracefully
|
||||
if memo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Fetch the related memo (the one being commented on)
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &payload.MemoComment.RelatedMemoId,
|
||||
ExcludeContent: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get related memo: %v", err)
|
||||
}
|
||||
// If the related memo was deleted, skip this activity gracefully
|
||||
if relatedMemo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
v2Payload.Payload = &v1pb.ActivityPayload_MemoComment{
|
||||
MemoComment: &v1pb.ActivityMemoCommentPayload{
|
||||
Memo: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
|
||||
RelatedMemo: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
|
||||
},
|
||||
}
|
||||
}
|
||||
return v2Payload, nil
|
||||
}
|
||||
397
server/router/api/v1/ai_http.go
Normal file
397
server/router/api/v1/ai_http.go
Normal file
@@ -0,0 +1,397 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
groqai "github.com/usememos/memos/plugin/ai/groq"
|
||||
ollamaai "github.com/usememos/memos/plugin/ai/ollama"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const aiSettingsStoreKey = "ai_settings_v1"
|
||||
|
||||
// RegisterAIHTTPHandlers registers plain HTTP JSON handlers for the AI service.
|
||||
// Uses direct raw DB storage to avoid proto generation dependency.
|
||||
func (s *APIV1Service) RegisterAIHTTPHandlers(g *echo.Group) {
|
||||
g.POST("/api/ai/settings", s.handleGetAISettings)
|
||||
g.POST("/api/ai/settings/update", s.handleUpdateAISettings)
|
||||
g.POST("/api/ai/complete", s.handleGenerateCompletion)
|
||||
g.POST("/api/ai/test", s.handleTestAIProvider)
|
||||
g.POST("/api/ai/models/ollama", s.handleListOllamaModels)
|
||||
}
|
||||
|
||||
type aiSettingsJSON struct {
|
||||
Groq *groqSettingsJSON `json:"groq,omitempty"`
|
||||
Ollama *ollamaSettingsJSON `json:"ollama,omitempty"`
|
||||
}
|
||||
|
||||
type groqSettingsJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
DefaultModel string `json:"defaultModel"`
|
||||
}
|
||||
|
||||
type ollamaSettingsJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Host string `json:"host"`
|
||||
DefaultModel string `json:"defaultModel"`
|
||||
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
|
||||
}
|
||||
|
||||
func defaultAISettings() *aiSettingsJSON {
|
||||
return &aiSettingsJSON{
|
||||
Groq: &groqSettingsJSON{DefaultModel: "llama-3.1-8b-instant"},
|
||||
Ollama: &ollamaSettingsJSON{Host: "http://localhost:11434", DefaultModel: "llama3", TimeoutSeconds: 120},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) loadAISettings(ctx context.Context) (*aiSettingsJSON, error) {
|
||||
settings := defaultAISettings()
|
||||
list, err := s.Store.GetDriver().ListInstanceSettings(ctx, &store.FindInstanceSetting{Name: aiSettingsStoreKey})
|
||||
if err != nil || len(list) == 0 {
|
||||
return settings, nil
|
||||
}
|
||||
_ = json.Unmarshal([]byte(list[0].Value), settings)
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) saveAISettings(ctx context.Context, settings *aiSettingsJSON) error {
|
||||
b, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.Store.GetDriver().UpsertInstanceSetting(ctx, &store.InstanceSetting{
|
||||
Name: aiSettingsStoreKey,
|
||||
Value: string(b),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleGetAISettings(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
settings, err := s.loadAISettings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, settings)
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleUpdateAISettings(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
var req aiSettingsJSON
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||
}
|
||||
if err := s.saveAISettings(ctx, &req); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, req)
|
||||
}
|
||||
|
||||
type generateCompletionRequest struct {
|
||||
Provider int `json:"provider"`
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
AutoTag bool `json:"autoTag,omitempty"`
|
||||
SpellCheck bool `json:"spellCheck,omitempty"`
|
||||
}
|
||||
|
||||
type generateCompletionResponse struct {
|
||||
Text string `json:"text"`
|
||||
ModelUsed string `json:"modelUsed"`
|
||||
PromptTokens int `json:"promptTokens"`
|
||||
CompletionTokens int `json:"completionTokens"`
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleGenerateCompletion(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
var req generateCompletionRequest
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "prompt is required"})
|
||||
}
|
||||
|
||||
settings, err := s.loadAISettings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
temp := float32(0.7)
|
||||
if req.Temperature > 0 {
|
||||
temp = float32(req.Temperature)
|
||||
}
|
||||
maxTokens := 1000
|
||||
if req.MaxTokens > 0 {
|
||||
maxTokens = req.MaxTokens
|
||||
}
|
||||
|
||||
switch req.Provider {
|
||||
case 1: // Groq
|
||||
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Groq is not enabled or API key is missing. Configure it in Settings > AI."})
|
||||
}
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = settings.Groq.DefaultModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "llama-3.1-8b-instant"
|
||||
}
|
||||
client := groqai.NewGroqClient(groqai.GroqConfig{
|
||||
APIKey: settings.Groq.APIKey,
|
||||
DefaultModel: model,
|
||||
})
|
||||
resp, err := client.GenerateCompletion(ctx, groqai.CompletionRequest{
|
||||
Model: model,
|
||||
Prompt: req.Prompt,
|
||||
Temperature: temp,
|
||||
MaxTokens: maxTokens,
|
||||
})
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Apply post-processing if requested
|
||||
text := resp.Text
|
||||
if req.SpellCheck {
|
||||
text = s.correctSpelling(ctx, text)
|
||||
}
|
||||
if req.AutoTag {
|
||||
text = s.addAutoTags(ctx, text)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, generateCompletionResponse{
|
||||
Text: text, ModelUsed: resp.Model,
|
||||
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
|
||||
})
|
||||
|
||||
case 2: // Ollama
|
||||
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
|
||||
}
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = settings.Ollama.DefaultModel
|
||||
}
|
||||
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
||||
Host: settings.Ollama.Host,
|
||||
DefaultModel: model,
|
||||
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
||||
})
|
||||
resp, err := client.GenerateCompletion(ctx, ollamaai.CompletionRequest{
|
||||
Model: model,
|
||||
Prompt: req.Prompt,
|
||||
Temperature: temp,
|
||||
MaxTokens: maxTokens,
|
||||
})
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Apply post-processing if requested
|
||||
text := resp.Text
|
||||
if req.SpellCheck {
|
||||
text = s.correctSpelling(ctx, text)
|
||||
}
|
||||
if req.AutoTag {
|
||||
text = s.addAutoTags(ctx, text)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, generateCompletionResponse{
|
||||
Text: text, ModelUsed: resp.Model,
|
||||
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
|
||||
})
|
||||
|
||||
default:
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider. Use 1 for Groq, 2 for Ollama"})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleTestAIProvider(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
var req struct {
|
||||
Provider int `json:"provider"`
|
||||
}
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||
}
|
||||
settings, err := s.loadAISettings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
switch req.Provider {
|
||||
case 1:
|
||||
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Groq not enabled or API key missing"})
|
||||
}
|
||||
client := groqai.NewGroqClient(groqai.GroqConfig{APIKey: settings.Groq.APIKey, DefaultModel: settings.Groq.DefaultModel})
|
||||
if err := client.TestConnection(ctx); err != nil {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Groq connected successfully"})
|
||||
case 2:
|
||||
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Ollama not enabled"})
|
||||
}
|
||||
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
||||
Host: settings.Ollama.Host,
|
||||
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
||||
})
|
||||
if err := client.TestConnection(ctx); err != nil {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Ollama connected successfully"})
|
||||
default:
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider"})
|
||||
}
|
||||
}
|
||||
|
||||
// correctSpelling applies basic spelling corrections to the text
|
||||
func (s *APIV1Service) correctSpelling(ctx context.Context, text string) string {
|
||||
// TODO: Implement actual spell checking using a spell checker library
|
||||
// For now, return the text as-is
|
||||
return text
|
||||
}
|
||||
|
||||
// addAutoTags analyzes the text and adds relevant #tags
|
||||
func (s *APIV1Service) addAutoTags(ctx context.Context, text string) string {
|
||||
// Common programming languages and technologies
|
||||
tagMap := map[string][]string{
|
||||
"javascript": {"javascript", "js"},
|
||||
"python": {"python", "py"},
|
||||
"java": {"java"},
|
||||
"go": {"go", "golang"},
|
||||
"rust": {"rust"},
|
||||
"dart": {"dart"},
|
||||
"flutter": {"flutter"},
|
||||
"react": {"react", "reactjs"},
|
||||
"vue": {"vue", "vuejs"},
|
||||
"angular": {"angular"},
|
||||
"node": {"nodejs", "node"},
|
||||
"express": {"express"},
|
||||
"mongodb": {"mongodb"},
|
||||
"postgresql": {"postgresql", "postgres"},
|
||||
"mysql": {"mysql"},
|
||||
"redis": {"redis"},
|
||||
"docker": {"docker"},
|
||||
"kubernetes": {"kubernetes", "k8s"},
|
||||
"aws": {"aws", "amazon"},
|
||||
"azure": {"azure"},
|
||||
"gcp": {"gcp", "google-cloud"},
|
||||
"git": {"git"},
|
||||
"github": {"github"},
|
||||
"gitlab": {"gitlab"},
|
||||
"ci/cd": {"ci-cd"},
|
||||
"testing": {"testing"},
|
||||
"api": {"api"},
|
||||
"rest": {"rest", "api"},
|
||||
"graphql": {"graphql"},
|
||||
"database": {"database", "db"},
|
||||
"frontend": {"frontend"},
|
||||
"backend": {"backend"},
|
||||
"fullstack": {"fullstack"},
|
||||
"mobile": {"mobile"},
|
||||
"web": {"web"},
|
||||
"devops": {"devops"},
|
||||
"security": {"security"},
|
||||
"machine learning": {"ml", "machine-learning"},
|
||||
"ai": {"ai", "artificial-intelligence"},
|
||||
"blockchain": {"blockchain"},
|
||||
"crypto": {"crypto", "cryptocurrency"},
|
||||
}
|
||||
|
||||
// Convert text to lowercase for matching
|
||||
lowerText := strings.ToLower(text)
|
||||
var tags []string
|
||||
tagSet := make(map[string]bool)
|
||||
|
||||
// Check for programming languages and technologies
|
||||
for keyword, tagList := range tagMap {
|
||||
if strings.Contains(lowerText, keyword) {
|
||||
for _, tag := range tagList {
|
||||
if !tagSet[tag] {
|
||||
tags = append(tags, tag)
|
||||
tagSet[tag] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add tags to the end of the text if any were found
|
||||
if len(tags) > 0 {
|
||||
tagString := "\n\n"
|
||||
for i, tag := range tags {
|
||||
if i > 0 {
|
||||
tagString += " "
|
||||
}
|
||||
tagString += "#" + tag
|
||||
}
|
||||
text += tagString
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleListOllamaModels(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
|
||||
settings, err := s.loadAISettings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
|
||||
}
|
||||
|
||||
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
||||
Host: settings.Ollama.Host,
|
||||
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
||||
})
|
||||
|
||||
models, err := client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": fmt.Sprintf("failed to list Ollama models: %v", err)})
|
||||
}
|
||||
|
||||
// Convert to the same format as frontend expects
|
||||
var modelList []map[string]string
|
||||
for _, model := range models {
|
||||
modelList = append(modelList, map[string]string{
|
||||
"id": model,
|
||||
"name": model,
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"models": modelList,
|
||||
})
|
||||
}
|
||||
191
server/router/api/v1/attachment_exif_test.go
Normal file
191
server/router/api/v1/attachment_exif_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/jpeg"
|
||||
"testing"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShouldStripExif(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mimeType string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "JPEG should strip EXIF",
|
||||
mimeType: "image/jpeg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JPG should strip EXIF",
|
||||
mimeType: "image/jpg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "TIFF should strip EXIF",
|
||||
mimeType: "image/tiff",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "WebP should strip EXIF",
|
||||
mimeType: "image/webp",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "HEIC should strip EXIF",
|
||||
mimeType: "image/heic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "HEIF should strip EXIF",
|
||||
mimeType: "image/heif",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PNG should not strip EXIF",
|
||||
mimeType: "image/png",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GIF should not strip EXIF",
|
||||
mimeType: "image/gif",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "text file should not strip EXIF",
|
||||
mimeType: "text/plain",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "PDF should not strip EXIF",
|
||||
mimeType: "application/pdf",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := shouldStripExif(tt.mimeType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripImageExif(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a simple test image
|
||||
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
|
||||
// Fill with red color
|
||||
for y := 0; y < 100; y++ {
|
||||
for x := 0; x < 100; x++ {
|
||||
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
|
||||
}
|
||||
}
|
||||
|
||||
// Encode as JPEG
|
||||
var buf bytes.Buffer
|
||||
err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90})
|
||||
require.NoError(t, err)
|
||||
originalData := buf.Bytes()
|
||||
|
||||
t.Run("strip JPEG metadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strippedData, err := stripImageExif(originalData, "image/jpeg")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's still a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, decodedImg.Bounds().Dx())
|
||||
assert.Equal(t, 100, decodedImg.Bounds().Dy())
|
||||
})
|
||||
|
||||
t.Run("strip JPG metadata (alternate extension)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strippedData, err := stripImageExif(originalData, "image/jpg")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's still a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, decodedImg)
|
||||
})
|
||||
|
||||
t.Run("strip PNG metadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Encode as PNG first
|
||||
var pngBuf bytes.Buffer
|
||||
err := imaging.Encode(&pngBuf, img, imaging.PNG)
|
||||
require.NoError(t, err)
|
||||
|
||||
strippedData, err := stripImageExif(pngBuf.Bytes(), "image/png")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's still a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, decodedImg.Bounds().Dx())
|
||||
assert.Equal(t, 100, decodedImg.Bounds().Dy())
|
||||
})
|
||||
|
||||
t.Run("handle WebP format by converting to JPEG", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// WebP format will be converted to JPEG
|
||||
strippedData, err := stripImageExif(originalData, "image/webp")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, decodedImg)
|
||||
})
|
||||
|
||||
t.Run("handle HEIC format by converting to JPEG", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strippedData, err := stripImageExif(originalData, "image/heic")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, decodedImg)
|
||||
})
|
||||
|
||||
t.Run("return error for invalid image data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
invalidData := []byte("not an image")
|
||||
_, err := stripImageExif(invalidData, "image/jpeg")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to decode image")
|
||||
})
|
||||
|
||||
t.Run("return error for empty image data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
emptyData := []byte{}
|
||||
_, err := stripImageExif(emptyData, "image/jpeg")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
646
server/router/api/v1/attachment_service.go
Normal file
646
server/router/api/v1/attachment_service.go
Normal file
@@ -0,0 +1,646 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/plugin/storage/s3"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// The upload memory buffer is 32 MiB.
|
||||
// It should be kept low, so RAM usage doesn't get out of control.
|
||||
// This is unrelated to maximum upload size limit, which is now set through system setting.
|
||||
MaxUploadBufferSizeBytes = 32 << 20
|
||||
MebiByte = 1024 * 1024
|
||||
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
|
||||
ThumbnailCacheFolder = ".thumbnail_cache"
|
||||
|
||||
// defaultJPEGQuality is the JPEG quality used when re-encoding images for EXIF stripping.
|
||||
// Quality 95 maintains visual quality while ensuring metadata is removed.
|
||||
defaultJPEGQuality = 95
|
||||
)
|
||||
|
||||
var SupportedThumbnailMimeTypes = []string{
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
}
|
||||
|
||||
// exifCapableImageTypes defines image formats that may contain EXIF metadata.
|
||||
// These formats will have their EXIF metadata stripped on upload for privacy.
|
||||
var exifCapableImageTypes = map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true,
|
||||
"image/tiff": true,
|
||||
"image/webp": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if request.Attachment == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "attachment is required")
|
||||
}
|
||||
if request.Attachment.Filename == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "filename is required")
|
||||
}
|
||||
if !validateFilename(request.Attachment.Filename) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format")
|
||||
}
|
||||
if request.Attachment.Type == "" {
|
||||
ext := filepath.Ext(request.Attachment.Filename)
|
||||
mimeType := mime.TypeByExtension(ext)
|
||||
if mimeType == "" {
|
||||
mimeType = http.DetectContentType(request.Attachment.Content)
|
||||
}
|
||||
// ParseMediaType to strip parameters
|
||||
mediaType, _, err := mime.ParseMediaType(mimeType)
|
||||
if err == nil {
|
||||
request.Attachment.Type = mediaType
|
||||
}
|
||||
}
|
||||
if request.Attachment.Type == "" {
|
||||
request.Attachment.Type = "application/octet-stream"
|
||||
}
|
||||
if !isValidMimeType(request.Attachment.Type) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format")
|
||||
}
|
||||
|
||||
// Use provided attachment_id or generate a new one
|
||||
attachmentUID := request.AttachmentId
|
||||
if attachmentUID == "" {
|
||||
attachmentUID = shortuuid.New()
|
||||
}
|
||||
|
||||
create := &store.Attachment{
|
||||
UID: attachmentUID,
|
||||
CreatorID: user.ID,
|
||||
Filename: request.Attachment.Filename,
|
||||
Type: request.Attachment.Type,
|
||||
}
|
||||
|
||||
// No upload size limit - accept files of any size
|
||||
size := binary.Size(request.Attachment.Content)
|
||||
create.Size = int64(size)
|
||||
create.Blob = request.Attachment.Content
|
||||
|
||||
// Strip EXIF metadata from images for privacy protection.
|
||||
// This removes sensitive information like GPS location, device details, etc.
|
||||
if shouldStripExif(create.Type) {
|
||||
if strippedBlob, err := stripImageExif(create.Blob, create.Type); err != nil {
|
||||
// Log warning but continue with original image to ensure uploads don't fail.
|
||||
slog.Warn("failed to strip EXIF metadata from image",
|
||||
slog.String("type", create.Type),
|
||||
slog.String("filename", create.Filename),
|
||||
slog.String("error", err.Error()))
|
||||
} else {
|
||||
create.Blob = strippedBlob
|
||||
create.Size = int64(len(strippedBlob))
|
||||
}
|
||||
}
|
||||
|
||||
if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err)
|
||||
}
|
||||
|
||||
if request.Attachment.Memo != nil {
|
||||
memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo)
|
||||
}
|
||||
create.MemoID = &memo.ID
|
||||
}
|
||||
attachment, err := s.Store.CreateAttachment(ctx, create)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err)
|
||||
}
|
||||
|
||||
return convertAttachmentFromStore(attachment), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Set default page size
|
||||
pageSize := int(request.PageSize)
|
||||
if pageSize <= 0 {
|
||||
pageSize = 50
|
||||
}
|
||||
if pageSize > 1000 {
|
||||
pageSize = 1000
|
||||
}
|
||||
|
||||
// Parse page token for offset
|
||||
offset := 0
|
||||
if request.PageToken != "" {
|
||||
// Simple implementation: page token is the offset as string
|
||||
// In production, you might want to use encrypted tokens
|
||||
if parsed, err := fmt.Sscanf(request.PageToken, "%d", &offset); err != nil || parsed != 1 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid page token")
|
||||
}
|
||||
}
|
||||
|
||||
findAttachment := &store.FindAttachment{
|
||||
CreatorID: &user.ID,
|
||||
Limit: &pageSize,
|
||||
Offset: &offset,
|
||||
}
|
||||
|
||||
// Parse filter if provided
|
||||
if request.Filter != "" {
|
||||
if err := s.validateAttachmentFilter(ctx, request.Filter); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
findAttachment.Filters = append(findAttachment.Filters, request.Filter)
|
||||
}
|
||||
|
||||
attachments, err := s.Store.ListAttachments(ctx, findAttachment)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListAttachmentsResponse{}
|
||||
|
||||
for _, attachment := range attachments {
|
||||
response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment))
|
||||
}
|
||||
|
||||
// For simplicity, set total size to the number of returned attachments.
|
||||
// In a full implementation, you'd want a separate count query
|
||||
response.TotalSize = int32(len(response.Attachments))
|
||||
|
||||
// Set next page token if we got the full page size (indicating there might be more)
|
||||
if len(attachments) == pageSize {
|
||||
response.NextPageToken = fmt.Sprintf("%d", offset+pageSize)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttachmentRequest) (*v1pb.Attachment, error) {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
|
||||
}
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
if attachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found")
|
||||
}
|
||||
|
||||
// Check access permission based on linked memo visibility.
|
||||
if err := s.checkAttachmentAccess(ctx, attachment); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return convertAttachmentFromStore(attachment), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.UpdateAttachmentRequest) (*v1pb.Attachment, error) {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(request.Attachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
if attachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found")
|
||||
}
|
||||
// Only the creator or admin can update the attachment.
|
||||
if attachment.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateAttachment{
|
||||
ID: attachment.ID,
|
||||
UpdatedTs: ¤tTs,
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "filename" {
|
||||
if !validateFilename(request.Attachment.Filename) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format")
|
||||
}
|
||||
update.Filename = &request.Attachment.Filename
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.Store.UpdateAttachment(ctx, update); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
|
||||
}
|
||||
return s.GetAttachment(ctx, &v1pb.GetAttachmentRequest{
|
||||
Name: request.Attachment.Name,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.DeleteAttachmentRequest) (*emptypb.Empty, error) {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
|
||||
}
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
|
||||
UID: &attachmentUID,
|
||||
CreatorID: &user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to find attachment: %v", err)
|
||||
}
|
||||
if attachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found")
|
||||
}
|
||||
// Delete the attachment from the database.
|
||||
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
|
||||
ID: attachment.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete attachment: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertAttachmentFromStore(attachment *store.Attachment) *v1pb.Attachment {
|
||||
attachmentMessage := &v1pb.Attachment{
|
||||
Name: fmt.Sprintf("%s%s", AttachmentNamePrefix, attachment.UID),
|
||||
CreateTime: timestamppb.New(time.Unix(attachment.CreatedTs, 0)),
|
||||
Filename: attachment.Filename,
|
||||
Type: attachment.Type,
|
||||
Size: attachment.Size,
|
||||
}
|
||||
if attachment.MemoUID != nil && *attachment.MemoUID != "" {
|
||||
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, *attachment.MemoUID)
|
||||
attachmentMessage.Memo = &memoName
|
||||
}
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
|
||||
attachmentMessage.ExternalLink = attachment.Reference
|
||||
}
|
||||
|
||||
return attachmentMessage
|
||||
}
|
||||
|
||||
// SaveAttachmentBlob save the blob of attachment based on the storage config.
|
||||
func SaveAttachmentBlob(ctx context.Context, profile *profile.Profile, stores *store.Store, create *store.Attachment) error {
|
||||
instanceStorageSetting, err := stores.GetInstanceStorageSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to find instance storage setting")
|
||||
}
|
||||
|
||||
if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_LOCAL {
|
||||
filepathTemplate := "assets/{timestamp}_{filename}"
|
||||
if instanceStorageSetting.FilepathTemplate != "" {
|
||||
filepathTemplate = instanceStorageSetting.FilepathTemplate
|
||||
}
|
||||
|
||||
internalPath := filepathTemplate
|
||||
if !strings.Contains(internalPath, "{filename}") {
|
||||
internalPath = filepath.Join(internalPath, "{filename}")
|
||||
}
|
||||
internalPath = replaceFilenameWithPathTemplate(internalPath, create.Filename)
|
||||
internalPath = filepath.ToSlash(internalPath)
|
||||
|
||||
// Ensure the directory exists.
|
||||
osPath := filepath.FromSlash(internalPath)
|
||||
if !filepath.IsAbs(osPath) {
|
||||
osPath = filepath.Join(profile.Data, osPath)
|
||||
}
|
||||
dir := filepath.Dir(osPath)
|
||||
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||
return errors.Wrap(err, "Failed to create directory")
|
||||
}
|
||||
|
||||
// Write the blob to the file.
|
||||
if err := os.WriteFile(osPath, create.Blob, 0644); err != nil {
|
||||
return errors.Wrap(err, "Failed to write file")
|
||||
}
|
||||
create.Reference = internalPath
|
||||
create.Blob = nil
|
||||
create.StorageType = storepb.AttachmentStorageType_LOCAL
|
||||
} else if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_S3 {
|
||||
s3Config := instanceStorageSetting.S3Config
|
||||
if s3Config == nil {
|
||||
return errors.Errorf("No activated external storage found")
|
||||
}
|
||||
s3Client, err := s3.NewClient(ctx, s3Config)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to create s3 client")
|
||||
}
|
||||
|
||||
filepathTemplate := instanceStorageSetting.FilepathTemplate
|
||||
if !strings.Contains(filepathTemplate, "{filename}") {
|
||||
filepathTemplate = filepath.Join(filepathTemplate, "{filename}")
|
||||
}
|
||||
filepathTemplate = replaceFilenameWithPathTemplate(filepathTemplate, create.Filename)
|
||||
key, err := s3Client.UploadObject(ctx, filepathTemplate, create.Type, bytes.NewReader(create.Blob))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to upload via s3 client")
|
||||
}
|
||||
presignURL, err := s3Client.PresignGetObject(ctx, key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to presign via s3 client")
|
||||
}
|
||||
|
||||
create.Reference = presignURL
|
||||
create.Blob = nil
|
||||
create.StorageType = storepb.AttachmentStorageType_S3
|
||||
create.Payload = &storepb.AttachmentPayload{
|
||||
Payload: &storepb.AttachmentPayload_S3Object_{
|
||||
S3Object: &storepb.AttachmentPayload_S3Object{
|
||||
S3Config: s3Config,
|
||||
Key: key,
|
||||
LastPresignedTime: timestamppb.New(time.Now()),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
|
||||
// For local storage, read the file from the local disk.
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
|
||||
attachmentPath := filepath.FromSlash(attachment.Reference)
|
||||
if !filepath.IsAbs(attachmentPath) {
|
||||
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
|
||||
}
|
||||
|
||||
file, err := os.Open(attachmentPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "file not found")
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to open the file")
|
||||
}
|
||||
defer file.Close()
|
||||
blob, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read the file")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
// For S3 storage, download the file from S3.
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_S3 {
|
||||
if attachment.Payload == nil {
|
||||
return nil, errors.New("attachment payload is missing")
|
||||
}
|
||||
s3Object := attachment.Payload.GetS3Object()
|
||||
if s3Object == nil {
|
||||
return nil, errors.New("S3 object payload is missing")
|
||||
}
|
||||
if s3Object.S3Config == nil {
|
||||
return nil, errors.New("S3 config is missing")
|
||||
}
|
||||
if s3Object.Key == "" {
|
||||
return nil, errors.New("S3 object key is missing")
|
||||
}
|
||||
|
||||
s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create S3 client")
|
||||
}
|
||||
|
||||
blob, err := s3Client.GetObject(context.Background(), s3Object.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get object from S3")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
// For database storage, return the blob from the database.
|
||||
return attachment.Blob, nil
|
||||
}
|
||||
|
||||
var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`)
|
||||
|
||||
func replaceFilenameWithPathTemplate(path, filename string) string {
|
||||
t := time.Now()
|
||||
path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string {
|
||||
switch s {
|
||||
case "{filename}":
|
||||
return filename
|
||||
case "{timestamp}":
|
||||
return fmt.Sprintf("%d", t.Unix())
|
||||
case "{year}":
|
||||
return fmt.Sprintf("%d", t.Year())
|
||||
case "{month}":
|
||||
return fmt.Sprintf("%02d", t.Month())
|
||||
case "{day}":
|
||||
return fmt.Sprintf("%02d", t.Day())
|
||||
case "{hour}":
|
||||
return fmt.Sprintf("%02d", t.Hour())
|
||||
case "{minute}":
|
||||
return fmt.Sprintf("%02d", t.Minute())
|
||||
case "{second}":
|
||||
return fmt.Sprintf("%02d", t.Second())
|
||||
case "{uuid}":
|
||||
return util.GenUUID()
|
||||
default:
|
||||
return s
|
||||
}
|
||||
})
|
||||
return path
|
||||
}
|
||||
|
||||
func validateFilename(filename string) bool {
|
||||
// Reject path traversal attempts and make sure no additional directories are created
|
||||
if !filepath.IsLocal(filename) || strings.ContainsAny(filename, "/\\") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Reject filenames starting or ending with spaces or periods
|
||||
if strings.HasPrefix(filename, " ") || strings.HasSuffix(filename, " ") ||
|
||||
strings.HasPrefix(filename, ".") || strings.HasSuffix(filename, ".") {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidMimeType(mimeType string) bool {
|
||||
// Reject empty or excessively long MIME types
|
||||
if mimeType == "" || len(mimeType) > 255 {
|
||||
return false
|
||||
}
|
||||
|
||||
// MIME type must match the pattern: type/subtype
|
||||
// Allow common characters in MIME types per RFC 2045
|
||||
matched, _ := regexp.MatchString(`^[a-zA-Z0-9][a-zA-Z0-9!#$&^_.+-]{0,126}/[a-zA-Z0-9][a-zA-Z0-9!#$&^_.+-]{0,126}$`, mimeType)
|
||||
return matched
|
||||
}
|
||||
|
||||
func (s *APIV1Service) validateAttachmentFilter(ctx context.Context, filterStr string) error {
|
||||
if filterStr == "" {
|
||||
return errors.New("filter cannot be empty")
|
||||
}
|
||||
|
||||
engine, err := filter.DefaultAttachmentEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var dialect filter.DialectName
|
||||
switch s.Profile.Driver {
|
||||
case "mysql":
|
||||
dialect = filter.DialectMySQL
|
||||
case "postgres":
|
||||
dialect = filter.DialectPostgres
|
||||
default:
|
||||
dialect = filter.DialectSQLite
|
||||
}
|
||||
|
||||
if _, err := engine.CompileToStatement(ctx, filterStr, filter.RenderOptions{Dialect: dialect}); err != nil {
|
||||
return errors.Wrap(err, "failed to compile filter")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAttachmentAccess verifies the user has permission to access the attachment.
|
||||
// For unlinked attachments (no memo), only the creator can access.
|
||||
// For linked attachments, access follows the memo's visibility rules.
|
||||
func (s *APIV1Service) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment) error {
|
||||
user, _ := s.fetchCurrentUser(ctx)
|
||||
|
||||
// For unlinked attachments, only the creator can access.
|
||||
if attachment.MemoID == nil {
|
||||
if user == nil {
|
||||
return status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if attachment.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For linked attachments, check memo visibility.
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
if memo.Visibility == store.Public {
|
||||
return nil
|
||||
}
|
||||
if user == nil {
|
||||
return status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldStripExif checks if the MIME type is an image format that may contain EXIF metadata.
|
||||
// Returns true for formats like JPEG, TIFF, WebP, HEIC, and HEIF which commonly contain
|
||||
// privacy-sensitive metadata such as GPS coordinates, camera settings, and device information.
|
||||
func shouldStripExif(mimeType string) bool {
|
||||
return exifCapableImageTypes[mimeType]
|
||||
}
|
||||
|
||||
// stripImageExif removes EXIF metadata from image files by decoding and re-encoding them.
|
||||
// This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps.
|
||||
//
|
||||
// The function preserves the correct image orientation by applying EXIF orientation tags
|
||||
// during decoding before stripping all metadata. Images are re-encoded with high quality
|
||||
// to minimize visual degradation.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JPEG/JPG: Re-encoded as JPEG with quality 95
|
||||
// - PNG: Re-encoded as PNG (lossless)
|
||||
// - TIFF/WebP/HEIC/HEIF: Re-encoded as JPEG with quality 95
|
||||
//
|
||||
// Returns the cleaned image data without any EXIF metadata, or an error if processing fails.
|
||||
func stripImageExif(imageData []byte, mimeType string) ([]byte, error) {
|
||||
// Decode image with automatic EXIF orientation correction.
|
||||
// This ensures the image displays correctly after metadata removal.
|
||||
img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to decode image")
|
||||
}
|
||||
|
||||
// Re-encode the image without EXIF metadata.
|
||||
var buf bytes.Buffer
|
||||
var encodeErr error
|
||||
|
||||
if mimeType == "image/png" {
|
||||
// Preserve PNG format for lossless encoding
|
||||
encodeErr = imaging.Encode(&buf, img, imaging.PNG)
|
||||
} else {
|
||||
// For JPEG, TIFF, WebP, HEIC, HEIF - re-encode as JPEG.
|
||||
// This ensures EXIF is stripped and provides good compression.
|
||||
encodeErr = imaging.Encode(&buf, img, imaging.JPEG, imaging.JPEGQuality(defaultJPEGQuality))
|
||||
}
|
||||
|
||||
if encodeErr != nil {
|
||||
return nil, errors.Wrap(encodeErr, "failed to encode image")
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
612
server/router/api/v1/auth_service.go
Normal file
612
server/router/api/v1/auth_service.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
"github.com/usememos/memos/plugin/idp/oauth2"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
unmatchedUsernameAndPasswordError = "unmatched username and password"
|
||||
)
|
||||
|
||||
// GetCurrentUser returns the authenticated user's information.
|
||||
// Validates the access token and returns user details.
|
||||
//
|
||||
// Authentication: Required (access token).
|
||||
// Returns: User information.
|
||||
func (s *APIV1Service) GetCurrentUser(ctx context.Context, _ *v1pb.GetCurrentUserRequest) (*v1pb.GetCurrentUserResponse, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
// Clear auth cookies
|
||||
if err := s.clearAuthCookies(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies: %v", err)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not found")
|
||||
}
|
||||
|
||||
return &v1pb.GetCurrentUserResponse{
|
||||
User: convertUserFromStore(user),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SignIn authenticates a user with credentials and returns tokens.
|
||||
// On success, returns an access token and sets a refresh token cookie.
|
||||
//
|
||||
// Supports two authentication methods:
|
||||
// 1. Password-based authentication (username + password).
|
||||
// 2. SSO authentication (OAuth2 authorization code).
|
||||
//
|
||||
// Authentication: Not required (public endpoint).
|
||||
// Returns: User info, access token, and token expiry.
|
||||
func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.SignInResponse, error) {
|
||||
var existingUser *store.User
|
||||
|
||||
// Authentication Method 1: Password-based authentication
|
||||
if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil {
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &passwordCredentials.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
|
||||
}
|
||||
// Compare the stored hashed password, with the hashed version of the password that was received.
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(passwordCredentials.Password)); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
|
||||
}
|
||||
instanceGeneralSetting, err := s.Store.GetInstanceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance general setting, error: %v", err)
|
||||
}
|
||||
// Check if the password auth in is allowed.
|
||||
if instanceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
|
||||
}
|
||||
existingUser = user
|
||||
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
|
||||
// Authentication Method 2: SSO (OAuth2) authentication
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &ssoCredentials.IdpId,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "identity provider not found")
|
||||
}
|
||||
|
||||
var userInfo *idp.IdentityProviderUserInfo
|
||||
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
|
||||
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
|
||||
}
|
||||
// Pass code_verifier for PKCE support (empty string if not provided for backward compatibility)
|
||||
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code, ssoCredentials.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
|
||||
}
|
||||
userInfo, err = oauth2IdentityProvider.UserInfo(token)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
identifierFilter := identityProvider.IdentifierFilter
|
||||
if identifierFilter != "" {
|
||||
identifierFilterRegex, err := regexp.Compile(identifierFilter)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err)
|
||||
}
|
||||
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &userInfo.Identifier,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
// Check if the user is allowed to sign up.
|
||||
instanceGeneralSetting, err := s.Store.GetInstanceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance general setting, error: %v", err)
|
||||
}
|
||||
if instanceGeneralSetting.DisallowUserRegistration {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
|
||||
}
|
||||
|
||||
// Create a new user with the user info from the identity provider.
|
||||
userCreate := &store.User{
|
||||
Username: userInfo.Identifier,
|
||||
// The new signup user should be normal user by default.
|
||||
Role: store.RoleUser,
|
||||
Nickname: userInfo.DisplayName,
|
||||
Email: userInfo.Email,
|
||||
AvatarURL: userInfo.AvatarURL,
|
||||
}
|
||||
password, err := util.RandomString(20)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err)
|
||||
}
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
|
||||
}
|
||||
userCreate.PasswordHash = string(passwordHash)
|
||||
user, err = s.Store.CreateUser(ctx, userCreate)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
|
||||
}
|
||||
}
|
||||
existingUser = user
|
||||
}
|
||||
|
||||
if existingUser == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid credentials")
|
||||
}
|
||||
if existingUser.RowStatus == store.Archived {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", existingUser.Username)
|
||||
}
|
||||
|
||||
accessToken, accessExpiresAt, err := s.doSignIn(ctx, existingUser)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to sign in: %v", err)
|
||||
}
|
||||
|
||||
return &v1pb.SignInResponse{
|
||||
User: convertUserFromStore(existingUser),
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpiresAt: timestamppb.New(accessExpiresAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// doSignIn performs the actual sign-in operation by creating a session and setting the cookie.
|
||||
//
|
||||
// This function:
|
||||
// 1. Generates refresh token and access token.
|
||||
// 2. Stores refresh token metadata in user_setting.
|
||||
// 3. Sets refresh token as HttpOnly cookie.
|
||||
// 4. Returns access token and its expiry time.
|
||||
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User) (string, time.Time, error) {
|
||||
// Generate refresh token
|
||||
tokenID := util.GenUUID()
|
||||
refreshToken, refreshExpiresAt, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(s.Secret))
|
||||
if err != nil {
|
||||
return "", time.Time{}, status.Errorf(codes.Internal, "failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
// Store refresh token metadata
|
||||
clientInfo := s.extractClientInfo(ctx)
|
||||
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(refreshExpiresAt),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
ClientInfo: clientInfo,
|
||||
}
|
||||
if err := s.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord); err != nil {
|
||||
slog.Error("failed to store refresh token", "error", err)
|
||||
}
|
||||
|
||||
// Set refresh token cookie
|
||||
refreshCookie := s.buildRefreshTokenCookie(ctx, refreshToken, refreshExpiresAt)
|
||||
if err := SetResponseHeader(ctx, "Set-Cookie", refreshCookie); err != nil {
|
||||
return "", time.Time{}, status.Errorf(codes.Internal, "failed to set refresh token cookie: %v", err)
|
||||
}
|
||||
|
||||
// Generate access token
|
||||
accessToken, accessExpiresAt, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte(s.Secret),
|
||||
)
|
||||
if err != nil {
|
||||
return "", time.Time{}, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
|
||||
}
|
||||
|
||||
return accessToken, accessExpiresAt, nil
|
||||
}
|
||||
|
||||
// SignOut terminates the user's authentication.
|
||||
// Revokes the refresh token and clears the authentication cookie.
|
||||
//
|
||||
// Authentication: Required (access token).
|
||||
// Returns: Empty response on success.
|
||||
func (s *APIV1Service) SignOut(ctx context.Context, _ *v1pb.SignOutRequest) (*emptypb.Empty, error) {
|
||||
// Get user from access token claims
|
||||
claims := auth.GetUserClaims(ctx)
|
||||
if claims != nil {
|
||||
// Revoke refresh token if we can identify it
|
||||
refreshToken := ""
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
if cookies := md.Get("cookie"); len(cookies) > 0 {
|
||||
refreshToken = auth.ExtractRefreshTokenFromCookie(cookies[0])
|
||||
}
|
||||
}
|
||||
if refreshToken != "" {
|
||||
refreshClaims, err := auth.ParseRefreshToken(refreshToken, []byte(s.Secret))
|
||||
if err == nil {
|
||||
// Remove refresh token from user_setting by token_id
|
||||
_ = s.Store.RemoveUserRefreshToken(ctx, claims.UserID, refreshClaims.TokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear refresh token cookie
|
||||
if err := s.clearAuthCookies(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies, error: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// RefreshToken exchanges a valid refresh token for a new access token.
|
||||
//
|
||||
// This endpoint implements refresh token rotation with sliding window sessions:
|
||||
// 1. Extracts the refresh token from the HttpOnly cookie (memos_refresh)
|
||||
// 2. Validates the refresh token against the database (checking expiry and revocation)
|
||||
// 3. Rotates the refresh token: generates a new one with fresh 30-day expiry
|
||||
// 4. Generates a new short-lived access token (15 minutes)
|
||||
// 5. Sets the new refresh token as HttpOnly cookie
|
||||
// 6. Returns the new access token and its expiry time
|
||||
//
|
||||
// Token rotation provides:
|
||||
// - Sliding window sessions: active users stay logged in indefinitely
|
||||
// - Better security: stolen refresh tokens become invalid after legitimate refresh
|
||||
//
|
||||
// Authentication: Requires valid refresh token in cookie (public endpoint)
|
||||
// Returns: New access token and expiry timestamp.
|
||||
func (s *APIV1Service) RefreshToken(ctx context.Context, _ *v1pb.RefreshTokenRequest) (*v1pb.RefreshTokenResponse, error) {
|
||||
// Extract refresh token from cookie
|
||||
refreshToken := ""
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
if cookies := md.Get("cookie"); len(cookies) > 0 {
|
||||
refreshToken = auth.ExtractRefreshTokenFromCookie(cookies[0])
|
||||
}
|
||||
}
|
||||
|
||||
if refreshToken == "" {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "refresh token not found")
|
||||
}
|
||||
|
||||
// Validate refresh token and get old token ID for rotation
|
||||
authenticator := auth.NewAuthenticator(s.Store, s.Secret)
|
||||
user, oldTokenID, err := authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid refresh token: %v", err)
|
||||
}
|
||||
|
||||
// --- Refresh Token Rotation ---
|
||||
// Generate new refresh token with fresh 30-day expiry (sliding window)
|
||||
newTokenID := util.GenUUID()
|
||||
newRefreshToken, newRefreshExpiresAt, err := auth.GenerateRefreshToken(user.ID, newTokenID, []byte(s.Secret))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
// Store new refresh token (add before remove to handle race conditions)
|
||||
clientInfo := s.extractClientInfo(ctx)
|
||||
newRefreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: newTokenID,
|
||||
ExpiresAt: timestamppb.New(newRefreshExpiresAt),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
ClientInfo: clientInfo,
|
||||
}
|
||||
if err := s.Store.AddUserRefreshToken(ctx, user.ID, newRefreshTokenRecord); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to store refresh token: %v", err)
|
||||
}
|
||||
|
||||
// Remove old refresh token
|
||||
if err := s.Store.RemoveUserRefreshToken(ctx, user.ID, oldTokenID); err != nil {
|
||||
// Log but don't fail - old token will expire naturally
|
||||
slog.Warn("failed to remove old refresh token", "error", err, "userID", user.ID, "tokenID", oldTokenID)
|
||||
}
|
||||
|
||||
// Set new refresh token cookie
|
||||
newRefreshCookie := s.buildRefreshTokenCookie(ctx, newRefreshToken, newRefreshExpiresAt)
|
||||
if err := SetResponseHeader(ctx, "Set-Cookie", newRefreshCookie); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to set refresh token cookie: %v", err)
|
||||
}
|
||||
// --- End Rotation ---
|
||||
|
||||
// Generate new access token
|
||||
accessToken, expiresAt, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte(s.Secret),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
|
||||
}
|
||||
|
||||
return &v1pb.RefreshTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
ExpiresAt: timestamppb.New(expiresAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) clearAuthCookies(ctx context.Context) error {
|
||||
// Clear refresh token cookie
|
||||
refreshCookie := s.buildRefreshTokenCookie(ctx, "", time.Time{})
|
||||
if err := SetResponseHeader(ctx, "Set-Cookie", refreshCookie); err != nil {
|
||||
return errors.Wrap(err, "failed to set refresh cookie")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken string, expireTime time.Time) string {
|
||||
attrs := []string{
|
||||
fmt.Sprintf("%s=%s", auth.RefreshTokenCookieName, refreshToken),
|
||||
"Path=/",
|
||||
"HttpOnly",
|
||||
}
|
||||
if expireTime.IsZero() {
|
||||
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
|
||||
} else {
|
||||
// RFC 6265 requires cookie expiration dates to use GMT timezone
|
||||
// Convert to UTC and format with explicit "GMT" to ensure browser compatibility
|
||||
attrs = append(attrs, "Expires="+expireTime.UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT"))
|
||||
}
|
||||
|
||||
// Try to determine if the request is HTTPS by checking the origin header
|
||||
// Default to non-HTTPS (Lax SameSite) if metadata is not available
|
||||
isHTTPS := false
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
for _, v := range md.Get("origin") {
|
||||
if strings.HasPrefix(v, "https://") {
|
||||
isHTTPS = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isHTTPS {
|
||||
attrs = append(attrs, "SameSite=Lax", "Secure")
|
||||
} else {
|
||||
attrs = append(attrs, "SameSite=Lax")
|
||||
}
|
||||
return strings.Join(attrs, "; ")
|
||||
}
|
||||
|
||||
func (s *APIV1Service) fetchCurrentUser(ctx context.Context) (*store.User, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
if userID == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.Errorf("user %d not found", userID)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// extractClientInfo extracts comprehensive client information from the request context.
|
||||
//
|
||||
// This function parses metadata from the gRPC context to extract:
|
||||
// - User Agent: Raw user agent string for detailed parsing
|
||||
// - IP Address: Client IP from X-Forwarded-For or X-Real-IP headers
|
||||
// - Device Type: "mobile", "tablet", or "desktop" (parsed from user agent)
|
||||
// - Operating System: OS name and version (e.g., "iOS 17.1", "Windows 10/11")
|
||||
// - Browser: Browser name and version (e.g., "Chrome 120.0.0.0")
|
||||
//
|
||||
// This information enables users to:
|
||||
// - See all active sessions with device details
|
||||
// - Identify suspicious login attempts
|
||||
// - Revoke specific sessions from unknown devices.
|
||||
func (s *APIV1Service) extractClientInfo(ctx context.Context) *storepb.RefreshTokensUserSetting_ClientInfo {
|
||||
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
|
||||
|
||||
// Extract user agent from metadata if available
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
if userAgents := md.Get("user-agent"); len(userAgents) > 0 {
|
||||
userAgent := userAgents[0]
|
||||
clientInfo.UserAgent = userAgent
|
||||
|
||||
// Parse user agent to extract device type, OS, browser info
|
||||
s.parseUserAgent(userAgent, clientInfo)
|
||||
}
|
||||
if forwardedFor := md.Get("x-forwarded-for"); len(forwardedFor) > 0 {
|
||||
ipAddress := strings.Split(forwardedFor[0], ",")[0] // Get the first IP in case of multiple
|
||||
ipAddress = strings.TrimSpace(ipAddress)
|
||||
clientInfo.IpAddress = ipAddress
|
||||
} else if realIP := md.Get("x-real-ip"); len(realIP) > 0 {
|
||||
clientInfo.IpAddress = realIP[0]
|
||||
}
|
||||
}
|
||||
|
||||
return clientInfo
|
||||
}
|
||||
|
||||
// parseUserAgent extracts device type, OS, and browser information from user agent string.
|
||||
//
|
||||
// Detection logic:
|
||||
// - Device Type: Checks for keywords like "mobile", "tablet", "ipad"
|
||||
// - OS: Pattern matches for iOS, Android, Windows, macOS, Linux, Chrome OS
|
||||
// - Browser: Identifies Edge, Chrome, Firefox, Safari, Opera
|
||||
//
|
||||
// Note: This is a simplified parser. For production use with high accuracy requirements,
|
||||
// consider using a dedicated user agent parsing library.
|
||||
func (*APIV1Service) parseUserAgent(userAgent string, clientInfo *storepb.RefreshTokensUserSetting_ClientInfo) {
|
||||
if userAgent == "" {
|
||||
return
|
||||
}
|
||||
|
||||
userAgent = strings.ToLower(userAgent)
|
||||
|
||||
// Detect device type
|
||||
if strings.Contains(userAgent, "ipad") || strings.Contains(userAgent, "tablet") {
|
||||
clientInfo.DeviceType = "tablet"
|
||||
} else if strings.Contains(userAgent, "mobile") || strings.Contains(userAgent, "android") ||
|
||||
strings.Contains(userAgent, "iphone") || strings.Contains(userAgent, "ipod") ||
|
||||
strings.Contains(userAgent, "windows phone") || strings.Contains(userAgent, "blackberry") {
|
||||
clientInfo.DeviceType = "mobile"
|
||||
} else {
|
||||
clientInfo.DeviceType = "desktop"
|
||||
}
|
||||
|
||||
// Detect operating system
|
||||
if strings.Contains(userAgent, "iphone os") || strings.Contains(userAgent, "cpu os") {
|
||||
// Extract iOS version
|
||||
if idx := strings.Index(userAgent, "cpu os "); idx != -1 {
|
||||
versionStart := idx + 7
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd != -1 {
|
||||
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
|
||||
clientInfo.Os = "iOS " + version
|
||||
} else {
|
||||
clientInfo.Os = "iOS"
|
||||
}
|
||||
} else if idx := strings.Index(userAgent, "iphone os "); idx != -1 {
|
||||
versionStart := idx + 10
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd != -1 {
|
||||
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
|
||||
clientInfo.Os = "iOS " + version
|
||||
} else {
|
||||
clientInfo.Os = "iOS"
|
||||
}
|
||||
} else {
|
||||
clientInfo.Os = "iOS"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "android") {
|
||||
// Extract Android version
|
||||
if idx := strings.Index(userAgent, "android "); idx != -1 {
|
||||
versionStart := idx + 8
|
||||
versionEnd := strings.Index(userAgent[versionStart:], ";")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = strings.Index(userAgent[versionStart:], ")")
|
||||
}
|
||||
if versionEnd != -1 {
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Os = "Android " + version
|
||||
} else {
|
||||
clientInfo.Os = "Android"
|
||||
}
|
||||
} else {
|
||||
clientInfo.Os = "Android"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "windows nt 10.0") {
|
||||
clientInfo.Os = "Windows 10/11"
|
||||
} else if strings.Contains(userAgent, "windows nt 6.3") {
|
||||
clientInfo.Os = "Windows 8.1"
|
||||
} else if strings.Contains(userAgent, "windows nt 6.1") {
|
||||
clientInfo.Os = "Windows 7"
|
||||
} else if strings.Contains(userAgent, "windows") {
|
||||
clientInfo.Os = "Windows"
|
||||
} else if strings.Contains(userAgent, "mac os x") {
|
||||
// Extract macOS version
|
||||
if idx := strings.Index(userAgent, "mac os x "); idx != -1 {
|
||||
versionStart := idx + 9
|
||||
versionEnd := strings.Index(userAgent[versionStart:], ";")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = strings.Index(userAgent[versionStart:], ")")
|
||||
}
|
||||
if versionEnd != -1 {
|
||||
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
|
||||
clientInfo.Os = "macOS " + version
|
||||
} else {
|
||||
clientInfo.Os = "macOS"
|
||||
}
|
||||
} else {
|
||||
clientInfo.Os = "macOS"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "linux") {
|
||||
clientInfo.Os = "Linux"
|
||||
} else if strings.Contains(userAgent, "cros") {
|
||||
clientInfo.Os = "Chrome OS"
|
||||
}
|
||||
|
||||
// Detect browser
|
||||
if strings.Contains(userAgent, "edg/") {
|
||||
// Extract Edge version
|
||||
if idx := strings.Index(userAgent, "edg/"); idx != -1 {
|
||||
versionStart := idx + 4
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = len(userAgent) - versionStart
|
||||
}
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Browser = "Edge " + version
|
||||
} else {
|
||||
clientInfo.Browser = "Edge"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "chrome/") && !strings.Contains(userAgent, "edg") {
|
||||
// Extract Chrome version
|
||||
if idx := strings.Index(userAgent, "chrome/"); idx != -1 {
|
||||
versionStart := idx + 7
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = len(userAgent) - versionStart
|
||||
}
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Browser = "Chrome " + version
|
||||
} else {
|
||||
clientInfo.Browser = "Chrome"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "firefox/") {
|
||||
// Extract Firefox version
|
||||
if idx := strings.Index(userAgent, "firefox/"); idx != -1 {
|
||||
versionStart := idx + 8
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = len(userAgent) - versionStart
|
||||
}
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Browser = "Firefox " + version
|
||||
} else {
|
||||
clientInfo.Browser = "Firefox"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "safari/") && !strings.Contains(userAgent, "chrome") && !strings.Contains(userAgent, "edg") {
|
||||
// Extract Safari version
|
||||
if idx := strings.Index(userAgent, "version/"); idx != -1 {
|
||||
versionStart := idx + 8
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = len(userAgent) - versionStart
|
||||
}
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Browser = "Safari " + version
|
||||
} else {
|
||||
clientInfo.Browser = "Safari"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "opera/") || strings.Contains(userAgent, "opr/") {
|
||||
clientInfo.Browser = "Opera"
|
||||
}
|
||||
}
|
||||
179
server/router/api/v1/auth_service_client_info_test.go
Normal file
179
server/router/api/v1/auth_service_client_info_test.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
func TestParseUserAgent(t *testing.T) {
|
||||
service := &APIV1Service{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
expectedDevice string
|
||||
expectedOS string
|
||||
expectedBrowser string
|
||||
}{
|
||||
{
|
||||
name: "Chrome on Windows",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
|
||||
expectedDevice: "desktop",
|
||||
expectedOS: "Windows 10/11",
|
||||
expectedBrowser: "Chrome 119.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "Safari on macOS",
|
||||
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15",
|
||||
expectedDevice: "desktop",
|
||||
expectedOS: "macOS 10.15.7",
|
||||
expectedBrowser: "Safari 17.0",
|
||||
},
|
||||
{
|
||||
name: "Chrome on Android Mobile",
|
||||
userAgent: "Mozilla/5.0 (Linux; Android 13; SM-G998B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Mobile Safari/537.36",
|
||||
expectedDevice: "mobile",
|
||||
expectedOS: "Android 13",
|
||||
expectedBrowser: "Chrome 119.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "Safari on iPhone",
|
||||
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
|
||||
expectedDevice: "mobile",
|
||||
expectedOS: "iOS 17.0",
|
||||
expectedBrowser: "Safari 17.0",
|
||||
},
|
||||
{
|
||||
name: "Firefox on Windows",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/119.0",
|
||||
expectedDevice: "desktop",
|
||||
expectedOS: "Windows 10/11",
|
||||
expectedBrowser: "Firefox 119.0",
|
||||
},
|
||||
{
|
||||
name: "Edge on Windows",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
|
||||
expectedDevice: "desktop",
|
||||
expectedOS: "Windows 10/11",
|
||||
expectedBrowser: "Edge 119.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "iPad Safari",
|
||||
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
|
||||
expectedDevice: "tablet",
|
||||
expectedOS: "iOS 17.0",
|
||||
expectedBrowser: "Safari 17.0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
|
||||
service.parseUserAgent(tt.userAgent, clientInfo)
|
||||
|
||||
if clientInfo.DeviceType != tt.expectedDevice {
|
||||
t.Errorf("Expected device type %s, got %s", tt.expectedDevice, clientInfo.DeviceType)
|
||||
}
|
||||
if clientInfo.Os != tt.expectedOS {
|
||||
t.Errorf("Expected OS %s, got %s", tt.expectedOS, clientInfo.Os)
|
||||
}
|
||||
if clientInfo.Browser != tt.expectedBrowser {
|
||||
t.Errorf("Expected browser %s, got %s", tt.expectedBrowser, clientInfo.Browser)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClientInfo(t *testing.T) {
|
||||
service := &APIV1Service{}
|
||||
|
||||
// Test with metadata containing user agent and IP
|
||||
md := metadata.New(map[string]string{
|
||||
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
|
||||
"x-forwarded-for": "203.0.113.1, 198.51.100.1",
|
||||
"x-real-ip": "203.0.113.1",
|
||||
})
|
||||
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
clientInfo := service.extractClientInfo(ctx)
|
||||
|
||||
if clientInfo.UserAgent == "" {
|
||||
t.Error("Expected user agent to be set")
|
||||
}
|
||||
if clientInfo.IpAddress != "203.0.113.1" {
|
||||
t.Errorf("Expected IP address to be 203.0.113.1, got %s", clientInfo.IpAddress)
|
||||
}
|
||||
if clientInfo.DeviceType != "desktop" {
|
||||
t.Errorf("Expected device type to be desktop, got %s", clientInfo.DeviceType)
|
||||
}
|
||||
if clientInfo.Os != "Windows 10/11" {
|
||||
t.Errorf("Expected OS to be Windows 10/11, got %s", clientInfo.Os)
|
||||
}
|
||||
if clientInfo.Browser != "Chrome 119.0.0.0" {
|
||||
t.Errorf("Expected browser to be Chrome 119.0.0.0, got %s", clientInfo.Browser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientInfoExamples demonstrates the enhanced client info extraction with various user agents.
|
||||
func TestClientInfoExamples(t *testing.T) {
|
||||
service := &APIV1Service{}
|
||||
|
||||
examples := []struct {
|
||||
description string
|
||||
userAgent string
|
||||
}{
|
||||
{
|
||||
description: "Modern Chrome on Windows 11",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
||||
},
|
||||
{
|
||||
description: "Safari on iPhone 15 Pro",
|
||||
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
|
||||
},
|
||||
{
|
||||
description: "Chrome on Samsung Galaxy",
|
||||
userAgent: "Mozilla/5.0 (Linux; Android 14; SM-S918B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36",
|
||||
},
|
||||
{
|
||||
description: "Firefox on Ubuntu",
|
||||
userAgent: "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/120.0",
|
||||
},
|
||||
{
|
||||
description: "Edge on Windows 10",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
|
||||
},
|
||||
{
|
||||
description: "Safari on iPad Air",
|
||||
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, example := range examples {
|
||||
t.Run(example.description, func(t *testing.T) {
|
||||
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
|
||||
service.parseUserAgent(example.userAgent, clientInfo)
|
||||
|
||||
t.Logf("User Agent: %s", example.userAgent)
|
||||
t.Logf("Device Type: %s", clientInfo.DeviceType)
|
||||
t.Logf("Operating System: %s", clientInfo.Os)
|
||||
t.Logf("Browser: %s", clientInfo.Browser)
|
||||
t.Log("---")
|
||||
|
||||
// Ensure all fields are populated
|
||||
if clientInfo.DeviceType == "" {
|
||||
t.Error("Device type should not be empty")
|
||||
}
|
||||
if clientInfo.Os == "" {
|
||||
t.Error("OS should not be empty")
|
||||
}
|
||||
if clientInfo.Browser == "" {
|
||||
t.Error("Browser should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
68
server/router/api/v1/common.go
Normal file
68
server/router/api/v1/common.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultPageSize is the default page size for requests.
|
||||
DefaultPageSize = 10
|
||||
// MaxPageSize is the maximum page size for requests.
|
||||
MaxPageSize = 1000
|
||||
)
|
||||
|
||||
func convertStateFromStore(rowStatus store.RowStatus) v1pb.State {
|
||||
switch rowStatus {
|
||||
case store.Normal:
|
||||
return v1pb.State_NORMAL
|
||||
case store.Archived:
|
||||
return v1pb.State_ARCHIVED
|
||||
default:
|
||||
return v1pb.State_STATE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertStateToStore(state v1pb.State) store.RowStatus {
|
||||
switch state {
|
||||
case v1pb.State_ARCHIVED:
|
||||
return store.Archived
|
||||
default:
|
||||
return store.Normal
|
||||
}
|
||||
}
|
||||
|
||||
func getPageToken(limit int, offset int) (string, error) {
|
||||
return marshalPageToken(&v1pb.PageToken{
|
||||
Limit: int32(limit),
|
||||
Offset: int32(offset),
|
||||
})
|
||||
}
|
||||
|
||||
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
|
||||
b, err := proto.Marshal(pageToken)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to marshal page token")
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
|
||||
b, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to decode page token")
|
||||
}
|
||||
if err := proto.Unmarshal(b, pageToken); err != nil {
|
||||
return errors.Wrapf(err, "failed to unmarshal page token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSuperUser(user *store.User) bool {
|
||||
return user.Role == store.RoleAdmin
|
||||
}
|
||||
80
server/router/api/v1/connect_handler.go
Normal file
80
server/router/api/v1/connect_handler.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/usememos/memos/proto/gen/api/v1/apiv1connect"
|
||||
)
|
||||
|
||||
// ConnectServiceHandler wraps APIV1Service to implement Connect handler interfaces.
|
||||
// It adapts the existing gRPC service implementations to work with Connect's
|
||||
// request/response wrapper types.
|
||||
//
|
||||
// This wrapper pattern allows us to:
|
||||
// - Reuse existing gRPC service implementations
|
||||
// - Support both native gRPC and Connect protocols
|
||||
// - Maintain a single source of truth for business logic.
|
||||
type ConnectServiceHandler struct {
|
||||
*APIV1Service
|
||||
}
|
||||
|
||||
// NewConnectServiceHandler creates a new Connect service handler.
|
||||
func NewConnectServiceHandler(svc *APIV1Service) *ConnectServiceHandler {
|
||||
return &ConnectServiceHandler{APIV1Service: svc}
|
||||
}
|
||||
|
||||
// RegisterConnectHandlers registers all Connect service handlers on the given mux.
|
||||
func (s *ConnectServiceHandler) RegisterConnectHandlers(mux *http.ServeMux, opts ...connect.HandlerOption) {
|
||||
// Register all service handlers
|
||||
handlers := []struct {
|
||||
path string
|
||||
handler http.Handler
|
||||
}{
|
||||
wrap(apiv1connect.NewInstanceServiceHandler(s, opts...)),
|
||||
wrap(apiv1connect.NewAuthServiceHandler(s, opts...)),
|
||||
wrap(apiv1connect.NewUserServiceHandler(s, opts...)),
|
||||
wrap(apiv1connect.NewMemoServiceHandler(s, opts...)),
|
||||
wrap(apiv1connect.NewAttachmentServiceHandler(s, opts...)),
|
||||
wrap(apiv1connect.NewShortcutServiceHandler(s, opts...)),
|
||||
wrap(apiv1connect.NewActivityServiceHandler(s, opts...)),
|
||||
wrap(apiv1connect.NewIdentityProviderServiceHandler(s, opts...)),
|
||||
}
|
||||
|
||||
for _, h := range handlers {
|
||||
mux.Handle(h.path, h.handler)
|
||||
}
|
||||
}
|
||||
|
||||
// wrap converts (path, handler) return value to a struct for cleaner iteration.
|
||||
func wrap(path string, handler http.Handler) struct {
|
||||
path string
|
||||
handler http.Handler
|
||||
} {
|
||||
return struct {
|
||||
path string
|
||||
handler http.Handler
|
||||
}{path, handler}
|
||||
}
|
||||
|
||||
// convertGRPCError converts gRPC status errors to Connect errors.
|
||||
// This preserves the error code semantics between the two protocols.
|
||||
func convertGRPCError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if st, ok := status.FromError(err); ok {
|
||||
return connect.NewError(grpcCodeToConnectCode(st.Code()), err)
|
||||
}
|
||||
return connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
// grpcCodeToConnectCode converts gRPC status codes to Connect error codes.
|
||||
// gRPC and Connect use the same error code semantics, so this is a direct cast.
|
||||
// See: https://connectrpc.com/docs/protocol/#error-codes
|
||||
func grpcCodeToConnectCode(code codes.Code) connect.Code {
|
||||
return connect.Code(code)
|
||||
}
|
||||
237
server/router/api/v1/connect_interceptors.go
Normal file
237
server/router/api/v1/connect_interceptors.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// MetadataInterceptor converts Connect HTTP headers to gRPC metadata.
|
||||
//
|
||||
// This ensures service methods can use metadata.FromIncomingContext() to access
|
||||
// headers like User-Agent, X-Forwarded-For, etc., regardless of whether the
|
||||
// request came via Connect RPC or gRPC-Gateway.
|
||||
type MetadataInterceptor struct{}
|
||||
|
||||
// NewMetadataInterceptor creates a new metadata interceptor.
|
||||
func NewMetadataInterceptor() *MetadataInterceptor {
|
||||
return &MetadataInterceptor{}
|
||||
}
|
||||
|
||||
func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
// Convert HTTP headers to gRPC metadata
|
||||
header := req.Header()
|
||||
md := metadata.MD{}
|
||||
|
||||
// Copy important headers for client info extraction
|
||||
if ua := header.Get("User-Agent"); ua != "" {
|
||||
md.Set("user-agent", ua)
|
||||
}
|
||||
if xff := header.Get("X-Forwarded-For"); xff != "" {
|
||||
md.Set("x-forwarded-for", xff)
|
||||
}
|
||||
if xri := header.Get("X-Real-Ip"); xri != "" {
|
||||
md.Set("x-real-ip", xri)
|
||||
}
|
||||
// Forward Cookie header for authentication methods that need it (e.g., RefreshToken)
|
||||
if cookie := header.Get("Cookie"); cookie != "" {
|
||||
md.Set("cookie", cookie)
|
||||
}
|
||||
|
||||
// Set metadata in context so services can use metadata.FromIncomingContext()
|
||||
ctx = metadata.NewIncomingContext(ctx, md)
|
||||
|
||||
// Execute the request
|
||||
resp, err := next(ctx, req)
|
||||
|
||||
// Prevent browser caching of API responses to avoid stale data issues
|
||||
// See: https://github.com/usememos/memos/issues/5470
|
||||
if !isNilAnyResponse(resp) && resp.Header() != nil {
|
||||
resp.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
resp.Header().Set("Pragma", "no-cache")
|
||||
resp.Header().Set("Expires", "0")
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func isNilAnyResponse(resp connect.AnyResponse) bool {
|
||||
if resp == nil {
|
||||
return true
|
||||
}
|
||||
val := reflect.ValueOf(resp)
|
||||
return val.Kind() == reflect.Ptr && val.IsNil()
|
||||
}
|
||||
|
||||
func (*MetadataInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
||||
return next
|
||||
}
|
||||
|
||||
func (*MetadataInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
||||
return next
|
||||
}
|
||||
|
||||
// LoggingInterceptor logs Connect RPC requests with appropriate log levels.
|
||||
//
|
||||
// Log levels:
|
||||
// - INFO: Successful requests and expected client errors (not found, permission denied, etc.)
|
||||
// - ERROR: Server errors (internal, unavailable, etc.)
|
||||
type LoggingInterceptor struct {
|
||||
logStacktrace bool
|
||||
}
|
||||
|
||||
// NewLoggingInterceptor creates a new logging interceptor.
|
||||
func NewLoggingInterceptor(logStacktrace bool) *LoggingInterceptor {
|
||||
return &LoggingInterceptor{logStacktrace: logStacktrace}
|
||||
}
|
||||
|
||||
func (in *LoggingInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
resp, err := next(ctx, req)
|
||||
in.log(req.Spec().Procedure, err)
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func (*LoggingInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
||||
return next // No-op for server-side interceptor
|
||||
}
|
||||
|
||||
func (*LoggingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
||||
return next // Streaming not used in this service
|
||||
}
|
||||
|
||||
func (in *LoggingInterceptor) log(procedure string, err error) {
|
||||
level, msg := in.classifyError(err)
|
||||
attrs := []slog.Attr{slog.String("method", procedure)}
|
||||
if err != nil {
|
||||
attrs = append(attrs, slog.String("error", err.Error()))
|
||||
if in.logStacktrace {
|
||||
attrs = append(attrs, slog.String("stacktrace", fmt.Sprintf("%+v", err)))
|
||||
}
|
||||
}
|
||||
slog.LogAttrs(context.Background(), level, msg, attrs...)
|
||||
}
|
||||
|
||||
func (*LoggingInterceptor) classifyError(err error) (slog.Level, string) {
|
||||
if err == nil {
|
||||
return slog.LevelInfo, "OK"
|
||||
}
|
||||
|
||||
var connectErr *connect.Error
|
||||
if !pkgerrors.As(err, &connectErr) {
|
||||
return slog.LevelError, "unknown error"
|
||||
}
|
||||
|
||||
// Client errors (expected, log at INFO)
|
||||
switch connectErr.Code() {
|
||||
case connect.CodeCanceled,
|
||||
connect.CodeInvalidArgument,
|
||||
connect.CodeNotFound,
|
||||
connect.CodeAlreadyExists,
|
||||
connect.CodePermissionDenied,
|
||||
connect.CodeUnauthenticated,
|
||||
connect.CodeResourceExhausted,
|
||||
connect.CodeFailedPrecondition,
|
||||
connect.CodeAborted,
|
||||
connect.CodeOutOfRange:
|
||||
return slog.LevelInfo, "client error"
|
||||
default:
|
||||
// Server errors
|
||||
return slog.LevelError, "server error"
|
||||
}
|
||||
}
|
||||
|
||||
// RecoveryInterceptor recovers from panics in Connect handlers and returns an internal error.
|
||||
type RecoveryInterceptor struct {
|
||||
logStacktrace bool
|
||||
}
|
||||
|
||||
// NewRecoveryInterceptor creates a new recovery interceptor.
|
||||
func NewRecoveryInterceptor(logStacktrace bool) *RecoveryInterceptor {
|
||||
return &RecoveryInterceptor{logStacktrace: logStacktrace}
|
||||
}
|
||||
|
||||
func (in *RecoveryInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (resp connect.AnyResponse, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
in.logPanic(req.Spec().Procedure, r)
|
||||
err = connect.NewError(connect.CodeInternal, pkgerrors.New("internal server error"))
|
||||
}
|
||||
}()
|
||||
return next(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (*RecoveryInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
||||
return next
|
||||
}
|
||||
|
||||
func (*RecoveryInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
||||
return next
|
||||
}
|
||||
|
||||
func (in *RecoveryInterceptor) logPanic(procedure string, panicValue any) {
|
||||
attrs := []slog.Attr{
|
||||
slog.String("method", procedure),
|
||||
slog.Any("panic", panicValue),
|
||||
}
|
||||
if in.logStacktrace {
|
||||
attrs = append(attrs, slog.String("stacktrace", string(debug.Stack())))
|
||||
}
|
||||
slog.LogAttrs(context.Background(), slog.LevelError, "panic recovered in Connect handler", attrs...)
|
||||
}
|
||||
|
||||
// AuthInterceptor handles authentication for Connect handlers.
|
||||
//
|
||||
// It enforces authentication for all endpoints except those listed in PublicMethods.
|
||||
// Role-based authorization (admin checks) remains in the service layer.
|
||||
type AuthInterceptor struct {
|
||||
authenticator *auth.Authenticator
|
||||
}
|
||||
|
||||
// NewAuthInterceptor creates a new auth interceptor.
|
||||
func NewAuthInterceptor(store *store.Store, secret string) *AuthInterceptor {
|
||||
return &AuthInterceptor{
|
||||
authenticator: auth.NewAuthenticator(store, secret),
|
||||
}
|
||||
}
|
||||
|
||||
func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
header := req.Header()
|
||||
authHeader := header.Get("Authorization")
|
||||
|
||||
result := in.authenticator.Authenticate(ctx, authHeader)
|
||||
|
||||
// Enforce authentication for non-public methods
|
||||
if result == nil && !IsPublicMethod(req.Spec().Procedure) {
|
||||
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required"))
|
||||
}
|
||||
|
||||
ctx = auth.ApplyToContext(ctx, result)
|
||||
|
||||
return next(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (*AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
||||
return next
|
||||
}
|
||||
|
||||
func (*AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
||||
return next
|
||||
}
|
||||
490
server/router/api/v1/connect_services.go
Normal file
490
server/router/api/v1/connect_services.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
// This file contains all Connect service handler method implementations.
|
||||
// Each method delegates to the underlying gRPC service implementation,
|
||||
// converting between Connect and gRPC request/response types.
|
||||
|
||||
// InstanceService
|
||||
|
||||
func (s *ConnectServiceHandler) GetInstanceProfile(ctx context.Context, req *connect.Request[v1pb.GetInstanceProfileRequest]) (*connect.Response[v1pb.InstanceProfile], error) {
|
||||
resp, err := s.APIV1Service.GetInstanceProfile(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetInstanceSetting(ctx context.Context, req *connect.Request[v1pb.GetInstanceSettingRequest]) (*connect.Response[v1pb.InstanceSetting], error) {
|
||||
resp, err := s.APIV1Service.GetInstanceSetting(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateInstanceSetting(ctx context.Context, req *connect.Request[v1pb.UpdateInstanceSettingRequest]) (*connect.Response[v1pb.InstanceSetting], error) {
|
||||
resp, err := s.APIV1Service.UpdateInstanceSetting(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
// AuthService
|
||||
//
|
||||
// Auth service methods need special handling for response headers (cookies).
|
||||
// We use connectWithHeaderCarrier helper to inject a header carrier into the context,
|
||||
// which allows the service to set headers in a protocol-agnostic way.
|
||||
|
||||
func (s *ConnectServiceHandler) GetCurrentUser(ctx context.Context, req *connect.Request[v1pb.GetCurrentUserRequest]) (*connect.Response[v1pb.GetCurrentUserResponse], error) {
|
||||
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.GetCurrentUserResponse, error) {
|
||||
return s.APIV1Service.GetCurrentUser(ctx, req.Msg)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) SignIn(ctx context.Context, req *connect.Request[v1pb.SignInRequest]) (*connect.Response[v1pb.SignInResponse], error) {
|
||||
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.SignInResponse, error) {
|
||||
return s.APIV1Service.SignIn(ctx, req.Msg)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) SignOut(ctx context.Context, req *connect.Request[v1pb.SignOutRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*emptypb.Empty, error) {
|
||||
return s.APIV1Service.SignOut(ctx, req.Msg)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) RefreshToken(ctx context.Context, req *connect.Request[v1pb.RefreshTokenRequest]) (*connect.Response[v1pb.RefreshTokenResponse], error) {
|
||||
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.RefreshTokenResponse, error) {
|
||||
return s.APIV1Service.RefreshToken(ctx, req.Msg)
|
||||
})
|
||||
}
|
||||
|
||||
// UserService
|
||||
|
||||
func (s *ConnectServiceHandler) ListUsers(ctx context.Context, req *connect.Request[v1pb.ListUsersRequest]) (*connect.Response[v1pb.ListUsersResponse], error) {
|
||||
resp, err := s.APIV1Service.ListUsers(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetUser(ctx context.Context, req *connect.Request[v1pb.GetUserRequest]) (*connect.Response[v1pb.User], error) {
|
||||
resp, err := s.APIV1Service.GetUser(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) CreateUser(ctx context.Context, req *connect.Request[v1pb.CreateUserRequest]) (*connect.Response[v1pb.User], error) {
|
||||
resp, err := s.APIV1Service.CreateUser(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateUser(ctx context.Context, req *connect.Request[v1pb.UpdateUserRequest]) (*connect.Response[v1pb.User], error) {
|
||||
resp, err := s.APIV1Service.UpdateUser(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeleteUser(ctx context.Context, req *connect.Request[v1pb.DeleteUserRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeleteUser(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListAllUserStats(ctx context.Context, req *connect.Request[v1pb.ListAllUserStatsRequest]) (*connect.Response[v1pb.ListAllUserStatsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListAllUserStats(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetUserStats(ctx context.Context, req *connect.Request[v1pb.GetUserStatsRequest]) (*connect.Response[v1pb.UserStats], error) {
|
||||
resp, err := s.APIV1Service.GetUserStats(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetUserSetting(ctx context.Context, req *connect.Request[v1pb.GetUserSettingRequest]) (*connect.Response[v1pb.UserSetting], error) {
|
||||
resp, err := s.APIV1Service.GetUserSetting(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateUserSetting(ctx context.Context, req *connect.Request[v1pb.UpdateUserSettingRequest]) (*connect.Response[v1pb.UserSetting], error) {
|
||||
resp, err := s.APIV1Service.UpdateUserSetting(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListUserSettings(ctx context.Context, req *connect.Request[v1pb.ListUserSettingsRequest]) (*connect.Response[v1pb.ListUserSettingsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListUserSettings(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListPersonalAccessTokens(ctx context.Context, req *connect.Request[v1pb.ListPersonalAccessTokensRequest]) (*connect.Response[v1pb.ListPersonalAccessTokensResponse], error) {
|
||||
resp, err := s.APIV1Service.ListPersonalAccessTokens(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) CreatePersonalAccessToken(ctx context.Context, req *connect.Request[v1pb.CreatePersonalAccessTokenRequest]) (*connect.Response[v1pb.CreatePersonalAccessTokenResponse], error) {
|
||||
resp, err := s.APIV1Service.CreatePersonalAccessToken(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeletePersonalAccessToken(ctx context.Context, req *connect.Request[v1pb.DeletePersonalAccessTokenRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeletePersonalAccessToken(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListUserWebhooks(ctx context.Context, req *connect.Request[v1pb.ListUserWebhooksRequest]) (*connect.Response[v1pb.ListUserWebhooksResponse], error) {
|
||||
resp, err := s.APIV1Service.ListUserWebhooks(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) CreateUserWebhook(ctx context.Context, req *connect.Request[v1pb.CreateUserWebhookRequest]) (*connect.Response[v1pb.UserWebhook], error) {
|
||||
resp, err := s.APIV1Service.CreateUserWebhook(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateUserWebhook(ctx context.Context, req *connect.Request[v1pb.UpdateUserWebhookRequest]) (*connect.Response[v1pb.UserWebhook], error) {
|
||||
resp, err := s.APIV1Service.UpdateUserWebhook(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeleteUserWebhook(ctx context.Context, req *connect.Request[v1pb.DeleteUserWebhookRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeleteUserWebhook(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListUserNotifications(ctx context.Context, req *connect.Request[v1pb.ListUserNotificationsRequest]) (*connect.Response[v1pb.ListUserNotificationsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListUserNotifications(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateUserNotification(ctx context.Context, req *connect.Request[v1pb.UpdateUserNotificationRequest]) (*connect.Response[v1pb.UserNotification], error) {
|
||||
resp, err := s.APIV1Service.UpdateUserNotification(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeleteUserNotification(ctx context.Context, req *connect.Request[v1pb.DeleteUserNotificationRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeleteUserNotification(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
// MemoService
|
||||
|
||||
func (s *ConnectServiceHandler) CreateMemo(ctx context.Context, req *connect.Request[v1pb.CreateMemoRequest]) (*connect.Response[v1pb.Memo], error) {
|
||||
resp, err := s.APIV1Service.CreateMemo(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListMemos(ctx context.Context, req *connect.Request[v1pb.ListMemosRequest]) (*connect.Response[v1pb.ListMemosResponse], error) {
|
||||
resp, err := s.APIV1Service.ListMemos(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetMemo(ctx context.Context, req *connect.Request[v1pb.GetMemoRequest]) (*connect.Response[v1pb.Memo], error) {
|
||||
resp, err := s.APIV1Service.GetMemo(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateMemo(ctx context.Context, req *connect.Request[v1pb.UpdateMemoRequest]) (*connect.Response[v1pb.Memo], error) {
|
||||
resp, err := s.APIV1Service.UpdateMemo(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeleteMemo(ctx context.Context, req *connect.Request[v1pb.DeleteMemoRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeleteMemo(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) SetMemoAttachments(ctx context.Context, req *connect.Request[v1pb.SetMemoAttachmentsRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.SetMemoAttachments(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListMemoAttachments(ctx context.Context, req *connect.Request[v1pb.ListMemoAttachmentsRequest]) (*connect.Response[v1pb.ListMemoAttachmentsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListMemoAttachments(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) SetMemoRelations(ctx context.Context, req *connect.Request[v1pb.SetMemoRelationsRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.SetMemoRelations(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListMemoRelations(ctx context.Context, req *connect.Request[v1pb.ListMemoRelationsRequest]) (*connect.Response[v1pb.ListMemoRelationsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListMemoRelations(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) CreateMemoComment(ctx context.Context, req *connect.Request[v1pb.CreateMemoCommentRequest]) (*connect.Response[v1pb.Memo], error) {
|
||||
resp, err := s.APIV1Service.CreateMemoComment(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListMemoComments(ctx context.Context, req *connect.Request[v1pb.ListMemoCommentsRequest]) (*connect.Response[v1pb.ListMemoCommentsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListMemoComments(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListMemoReactions(ctx context.Context, req *connect.Request[v1pb.ListMemoReactionsRequest]) (*connect.Response[v1pb.ListMemoReactionsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListMemoReactions(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpsertMemoReaction(ctx context.Context, req *connect.Request[v1pb.UpsertMemoReactionRequest]) (*connect.Response[v1pb.Reaction], error) {
|
||||
resp, err := s.APIV1Service.UpsertMemoReaction(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeleteMemoReaction(ctx context.Context, req *connect.Request[v1pb.DeleteMemoReactionRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeleteMemoReaction(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
// AttachmentService
|
||||
|
||||
func (s *ConnectServiceHandler) CreateAttachment(ctx context.Context, req *connect.Request[v1pb.CreateAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
|
||||
resp, err := s.APIV1Service.CreateAttachment(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) ListAttachments(ctx context.Context, req *connect.Request[v1pb.ListAttachmentsRequest]) (*connect.Response[v1pb.ListAttachmentsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListAttachments(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetAttachment(ctx context.Context, req *connect.Request[v1pb.GetAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
|
||||
resp, err := s.APIV1Service.GetAttachment(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateAttachment(ctx context.Context, req *connect.Request[v1pb.UpdateAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
|
||||
resp, err := s.APIV1Service.UpdateAttachment(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeleteAttachment(ctx context.Context, req *connect.Request[v1pb.DeleteAttachmentRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeleteAttachment(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
// ShortcutService
|
||||
|
||||
func (s *ConnectServiceHandler) ListShortcuts(ctx context.Context, req *connect.Request[v1pb.ListShortcutsRequest]) (*connect.Response[v1pb.ListShortcutsResponse], error) {
|
||||
resp, err := s.APIV1Service.ListShortcuts(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetShortcut(ctx context.Context, req *connect.Request[v1pb.GetShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
|
||||
resp, err := s.APIV1Service.GetShortcut(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) CreateShortcut(ctx context.Context, req *connect.Request[v1pb.CreateShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
|
||||
resp, err := s.APIV1Service.CreateShortcut(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateShortcut(ctx context.Context, req *connect.Request[v1pb.UpdateShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
|
||||
resp, err := s.APIV1Service.UpdateShortcut(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeleteShortcut(ctx context.Context, req *connect.Request[v1pb.DeleteShortcutRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeleteShortcut(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
// ActivityService
|
||||
|
||||
func (s *ConnectServiceHandler) ListActivities(ctx context.Context, req *connect.Request[v1pb.ListActivitiesRequest]) (*connect.Response[v1pb.ListActivitiesResponse], error) {
|
||||
resp, err := s.APIV1Service.ListActivities(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetActivity(ctx context.Context, req *connect.Request[v1pb.GetActivityRequest]) (*connect.Response[v1pb.Activity], error) {
|
||||
resp, err := s.APIV1Service.GetActivity(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
// IdentityProviderService
|
||||
|
||||
func (s *ConnectServiceHandler) ListIdentityProviders(ctx context.Context, req *connect.Request[v1pb.ListIdentityProvidersRequest]) (*connect.Response[v1pb.ListIdentityProvidersResponse], error) {
|
||||
resp, err := s.APIV1Service.ListIdentityProviders(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) GetIdentityProvider(ctx context.Context, req *connect.Request[v1pb.GetIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
|
||||
resp, err := s.APIV1Service.GetIdentityProvider(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) CreateIdentityProvider(ctx context.Context, req *connect.Request[v1pb.CreateIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
|
||||
resp, err := s.APIV1Service.CreateIdentityProvider(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) UpdateIdentityProvider(ctx context.Context, req *connect.Request[v1pb.UpdateIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
|
||||
resp, err := s.APIV1Service.UpdateIdentityProvider(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *ConnectServiceHandler) DeleteIdentityProvider(ctx context.Context, req *connect.Request[v1pb.DeleteIdentityProviderRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||
resp, err := s.APIV1Service.DeleteIdentityProvider(ctx, req.Msg)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
return connect.NewResponse(resp), nil
|
||||
}
|
||||
124
server/router/api/v1/header_carrier.go
Normal file
124
server/router/api/v1/header_carrier.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// headerCarrierKey is the context key for storing headers to be set in the response.
|
||||
type headerCarrierKey struct{}
|
||||
|
||||
// HeaderCarrier stores headers that need to be set in the response.
|
||||
//
|
||||
// Problem: The codebase supports two protocols simultaneously:
|
||||
// - Native gRPC: Uses grpc.SetHeader() to set response headers
|
||||
// - Connect-RPC: Uses connect.Response.Header().Set() to set response headers
|
||||
//
|
||||
// Solution: HeaderCarrier provides a protocol-agnostic way to set headers.
|
||||
// - Service methods call SetResponseHeader() regardless of protocol
|
||||
// - For gRPC requests: SetResponseHeader uses grpc.SetHeader directly
|
||||
// - For Connect requests: SetResponseHeader stores headers in HeaderCarrier
|
||||
// - Connect wrappers extract headers from HeaderCarrier and apply to response
|
||||
//
|
||||
// This allows service methods to work with both protocols without knowing which one is being used.
|
||||
type HeaderCarrier struct {
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
// newHeaderCarrier creates a new header carrier.
|
||||
func newHeaderCarrier() *HeaderCarrier {
|
||||
return &HeaderCarrier{
|
||||
headers: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds a header to the carrier.
|
||||
func (h *HeaderCarrier) Set(key, value string) {
|
||||
h.headers[key] = value
|
||||
}
|
||||
|
||||
// Get retrieves a header from the carrier.
|
||||
func (h *HeaderCarrier) Get(key string) string {
|
||||
return h.headers[key]
|
||||
}
|
||||
|
||||
// All returns all headers.
|
||||
func (h *HeaderCarrier) All() map[string]string {
|
||||
return h.headers
|
||||
}
|
||||
|
||||
// WithHeaderCarrier adds a header carrier to the context.
|
||||
func WithHeaderCarrier(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, headerCarrierKey{}, newHeaderCarrier())
|
||||
}
|
||||
|
||||
// GetHeaderCarrier retrieves the header carrier from the context.
|
||||
// Returns nil if no carrier is present.
|
||||
func GetHeaderCarrier(ctx context.Context) *HeaderCarrier {
|
||||
if carrier, ok := ctx.Value(headerCarrierKey{}).(*HeaderCarrier); ok {
|
||||
return carrier
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetResponseHeader sets a header in the response.
|
||||
//
|
||||
// This function works for both gRPC and Connect protocols:
|
||||
// - For gRPC: Uses grpc.SetHeader to set headers in gRPC metadata
|
||||
// - For Connect: Stores in HeaderCarrier for Connect wrapper to apply later
|
||||
//
|
||||
// The protocol is automatically detected based on whether a HeaderCarrier
|
||||
// exists in the context (injected by Connect wrappers).
|
||||
func SetResponseHeader(ctx context.Context, key, value string) error {
|
||||
// Try Connect first (check if we have a header carrier)
|
||||
if carrier := GetHeaderCarrier(ctx); carrier != nil {
|
||||
carrier.Set(key, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fall back to gRPC
|
||||
return grpc.SetHeader(ctx, metadata.New(map[string]string{
|
||||
key: value,
|
||||
}))
|
||||
}
|
||||
|
||||
// connectWithHeaderCarrier is a helper for Connect service wrappers that need to set response headers.
|
||||
//
|
||||
// It injects a HeaderCarrier into the context, calls the service method,
|
||||
// and applies any headers from the carrier to the Connect response.
|
||||
//
|
||||
// The generic parameter T is the non-pointer protobuf message type (e.g., v1pb.CreateSessionResponse),
|
||||
// while fn returns *T (the pointer type) as is standard for protobuf messages.
|
||||
//
|
||||
// Usage in Connect wrappers:
|
||||
//
|
||||
// func (s *ConnectServiceHandler) CreateSession(ctx context.Context, req *connect.Request[v1pb.CreateSessionRequest]) (*connect.Response[v1pb.CreateSessionResponse], error) {
|
||||
// return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.CreateSessionResponse, error) {
|
||||
// return s.APIV1Service.CreateSession(ctx, req.Msg)
|
||||
// })
|
||||
// }
|
||||
func connectWithHeaderCarrier[T any](ctx context.Context, fn func(context.Context) (*T, error)) (*connect.Response[T], error) {
|
||||
// Inject header carrier for Connect protocol
|
||||
ctx = WithHeaderCarrier(ctx)
|
||||
|
||||
// Call the service method
|
||||
resp, err := fn(ctx)
|
||||
if err != nil {
|
||||
return nil, convertGRPCError(err)
|
||||
}
|
||||
|
||||
// Create Connect response
|
||||
connectResp := connect.NewResponse(resp)
|
||||
|
||||
// Apply any headers set via the header carrier
|
||||
if carrier := GetHeaderCarrier(ctx); carrier != nil {
|
||||
for key, value := range carrier.All() {
|
||||
connectResp.Header().Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
return connectResp, nil
|
||||
}
|
||||
25
server/router/api/v1/health_service.go
Normal file
25
server/router/api/v1/health_service.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) Check(ctx context.Context,
|
||||
_ *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) {
|
||||
// Check if database is initialized by verifying instance basic setting exists
|
||||
instanceBasicSetting, err := s.Store.GetInstanceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "database not initialized: %v", err)
|
||||
}
|
||||
|
||||
// Verify schema version is set (empty means database not properly initialized)
|
||||
if instanceBasicSetting.SchemaVersion == "" {
|
||||
return nil, status.Errorf(codes.Unavailable, "schema version not set")
|
||||
}
|
||||
|
||||
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil
|
||||
}
|
||||
238
server/router/api/v1/idp_service.go
Normal file
238
server/router/api/v1/idp_service.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb.CreateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListIdentityProvidersRequest) (*v1pb.ListIdentityProvidersResponse, error) {
|
||||
identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListIdentityProvidersResponse{
|
||||
IdentityProviders: []*v1pb.IdentityProvider{},
|
||||
}
|
||||
|
||||
// Default to lowest-privilege role, update later based on real role
|
||||
currentUserRole := store.RoleUser
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err == nil && currentUser != nil {
|
||||
currentUserRole = currentUser.Role
|
||||
}
|
||||
|
||||
for _, identityProvider := range identityProviders {
|
||||
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
|
||||
response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
|
||||
// Default to lowest-privilege role, update later based on real role
|
||||
currentUserRole := store.RoleUser
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err == nil && currentUser != nil {
|
||||
currentUserRole = currentUser.Role
|
||||
}
|
||||
|
||||
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
|
||||
return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
||||
}
|
||||
|
||||
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
update := &store.UpdateIdentityProviderV1{
|
||||
ID: id,
|
||||
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
switch field {
|
||||
case "title":
|
||||
update.Name = &request.IdentityProvider.Title
|
||||
case "identifier_filter":
|
||||
update.IdentifierFilter = &request.IdentityProvider.IdentifierFilter
|
||||
case "config":
|
||||
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
|
||||
default:
|
||||
// Ignore unsupported fields
|
||||
}
|
||||
}
|
||||
|
||||
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
|
||||
// Check if the identity provider exists before trying to delete it
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err)
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
|
||||
temp := &v1pb.IdentityProvider{
|
||||
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
|
||||
Title: identityProvider.Name,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
|
||||
}
|
||||
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
|
||||
oauth2Config := identityProvider.Config.GetOauth2Config()
|
||||
temp.Config = &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: oauth2Config.ClientId,
|
||||
ClientSecret: oauth2Config.ClientSecret,
|
||||
AuthUrl: oauth2Config.AuthUrl,
|
||||
TokenUrl: oauth2Config.TokenUrl,
|
||||
UserInfoUrl: oauth2Config.UserInfoUrl,
|
||||
Scopes: oauth2Config.Scopes,
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: oauth2Config.FieldMapping.Identifier,
|
||||
DisplayName: oauth2Config.FieldMapping.DisplayName,
|
||||
Email: oauth2Config.FieldMapping.Email,
|
||||
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return temp
|
||||
}
|
||||
|
||||
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
|
||||
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
|
||||
|
||||
temp := &storepb.IdentityProvider{
|
||||
Id: id,
|
||||
Name: identityProvider.Title,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
|
||||
Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
|
||||
}
|
||||
return temp
|
||||
}
|
||||
|
||||
func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProvider_Type, config *v1pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
|
||||
if identityProviderType == v1pb.IdentityProvider_OAUTH2 {
|
||||
oauth2Config := config.GetOauth2Config()
|
||||
return &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: oauth2Config.ClientId,
|
||||
ClientSecret: oauth2Config.ClientSecret,
|
||||
AuthUrl: oauth2Config.AuthUrl,
|
||||
TokenUrl: oauth2Config.TokenUrl,
|
||||
UserInfoUrl: oauth2Config.UserInfoUrl,
|
||||
Scopes: oauth2Config.Scopes,
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: oauth2Config.FieldMapping.Identifier,
|
||||
DisplayName: oauth2Config.FieldMapping.DisplayName,
|
||||
Email: oauth2Config.FieldMapping.Email,
|
||||
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
|
||||
if userRole != store.RoleAdmin {
|
||||
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
|
||||
identityProvider.Config.GetOauth2Config().ClientSecret = ""
|
||||
}
|
||||
}
|
||||
|
||||
return identityProvider
|
||||
}
|
||||
285
server/router/api/v1/instance_service.go
Normal file
285
server/router/api/v1/instance_service.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// GetInstanceProfile returns the instance profile.
|
||||
func (s *APIV1Service) GetInstanceProfile(ctx context.Context, _ *v1pb.GetInstanceProfileRequest) (*v1pb.InstanceProfile, error) {
|
||||
admin, err := s.GetInstanceAdmin(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance admin: %v", err)
|
||||
}
|
||||
|
||||
instanceProfile := &v1pb.InstanceProfile{
|
||||
Version: s.Profile.Version,
|
||||
Demo: s.Profile.Demo,
|
||||
InstanceUrl: s.Profile.InstanceURL,
|
||||
Admin: admin, // nil when not initialized
|
||||
}
|
||||
return instanceProfile, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.GetInstanceSettingRequest) (*v1pb.InstanceSetting, error) {
|
||||
instanceSettingKeyString, err := ExtractInstanceSettingKeyFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid instance setting name: %v", err)
|
||||
}
|
||||
|
||||
instanceSettingKey := storepb.InstanceSettingKey(storepb.InstanceSettingKey_value[instanceSettingKeyString])
|
||||
// Get instance setting from store with default value.
|
||||
switch instanceSettingKey {
|
||||
case storepb.InstanceSettingKey_BASIC:
|
||||
_, err = s.Store.GetInstanceBasicSetting(ctx)
|
||||
case storepb.InstanceSettingKey_GENERAL:
|
||||
_, err = s.Store.GetInstanceGeneralSetting(ctx)
|
||||
case storepb.InstanceSettingKey_MEMO_RELATED:
|
||||
_, err = s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
case storepb.InstanceSettingKey_STORAGE:
|
||||
_, err = s.Store.GetInstanceStorageSetting(ctx)
|
||||
default:
|
||||
return nil, status.Errorf(codes.InvalidArgument, "unsupported instance setting key: %v", instanceSettingKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance setting: %v", err)
|
||||
}
|
||||
|
||||
instanceSetting, err := s.Store.GetInstanceSetting(ctx, &store.FindInstanceSetting{
|
||||
Name: instanceSettingKey.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance setting: %v", err)
|
||||
}
|
||||
if instanceSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "instance setting not found")
|
||||
}
|
||||
|
||||
// For storage setting, only admin can get it.
|
||||
if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if user.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
return convertInstanceSettingFromStore(instanceSetting), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb.UpdateInstanceSettingRequest) (*v1pb.InstanceSetting, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if user.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// TODO: Apply update_mask if specified
|
||||
_ = request.UpdateMask
|
||||
|
||||
updateSetting := convertInstanceSettingToStore(request.Setting)
|
||||
instanceSetting, err := s.Store.UpsertInstanceSetting(ctx, updateSetting)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert instance setting: %v", err)
|
||||
}
|
||||
|
||||
return convertInstanceSettingFromStore(instanceSetting), nil
|
||||
}
|
||||
|
||||
func convertInstanceSettingFromStore(setting *storepb.InstanceSetting) *v1pb.InstanceSetting {
|
||||
instanceSetting := &v1pb.InstanceSetting{
|
||||
Name: fmt.Sprintf("instance/settings/%s", setting.Key.String()),
|
||||
}
|
||||
switch setting.Value.(type) {
|
||||
case *storepb.InstanceSetting_GeneralSetting:
|
||||
instanceSetting.Value = &v1pb.InstanceSetting_GeneralSetting_{
|
||||
GeneralSetting: convertInstanceGeneralSettingFromStore(setting.GetGeneralSetting()),
|
||||
}
|
||||
case *storepb.InstanceSetting_StorageSetting:
|
||||
instanceSetting.Value = &v1pb.InstanceSetting_StorageSetting_{
|
||||
StorageSetting: convertInstanceStorageSettingFromStore(setting.GetStorageSetting()),
|
||||
}
|
||||
case *storepb.InstanceSetting_MemoRelatedSetting:
|
||||
instanceSetting.Value = &v1pb.InstanceSetting_MemoRelatedSetting_{
|
||||
MemoRelatedSetting: convertInstanceMemoRelatedSettingFromStore(setting.GetMemoRelatedSetting()),
|
||||
}
|
||||
}
|
||||
return instanceSetting
|
||||
}
|
||||
|
||||
func convertInstanceSettingToStore(setting *v1pb.InstanceSetting) *storepb.InstanceSetting {
|
||||
settingKeyString, _ := ExtractInstanceSettingKeyFromName(setting.Name)
|
||||
instanceSetting := &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey(storepb.InstanceSettingKey_value[settingKeyString]),
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: convertInstanceGeneralSettingToStore(setting.GetGeneralSetting()),
|
||||
},
|
||||
}
|
||||
switch instanceSetting.Key {
|
||||
case storepb.InstanceSettingKey_GENERAL:
|
||||
instanceSetting.Value = &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: convertInstanceGeneralSettingToStore(setting.GetGeneralSetting()),
|
||||
}
|
||||
case storepb.InstanceSettingKey_STORAGE:
|
||||
instanceSetting.Value = &storepb.InstanceSetting_StorageSetting{
|
||||
StorageSetting: convertInstanceStorageSettingToStore(setting.GetStorageSetting()),
|
||||
}
|
||||
case storepb.InstanceSettingKey_MEMO_RELATED:
|
||||
instanceSetting.Value = &storepb.InstanceSetting_MemoRelatedSetting{
|
||||
MemoRelatedSetting: convertInstanceMemoRelatedSettingToStore(setting.GetMemoRelatedSetting()),
|
||||
}
|
||||
default:
|
||||
// Keep the default GeneralSetting value
|
||||
}
|
||||
return instanceSetting
|
||||
}
|
||||
|
||||
func convertInstanceGeneralSettingFromStore(setting *storepb.InstanceGeneralSetting) *v1pb.InstanceSetting_GeneralSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
generalSetting := &v1pb.InstanceSetting_GeneralSetting{
|
||||
DisallowUserRegistration: setting.DisallowUserRegistration,
|
||||
DisallowPasswordAuth: setting.DisallowPasswordAuth,
|
||||
AdditionalScript: setting.AdditionalScript,
|
||||
AdditionalStyle: setting.AdditionalStyle,
|
||||
WeekStartDayOffset: setting.WeekStartDayOffset,
|
||||
DisallowChangeUsername: setting.DisallowChangeUsername,
|
||||
DisallowChangeNickname: setting.DisallowChangeNickname,
|
||||
}
|
||||
if setting.CustomProfile != nil {
|
||||
generalSetting.CustomProfile = &v1pb.InstanceSetting_GeneralSetting_CustomProfile{
|
||||
Title: setting.CustomProfile.Title,
|
||||
Description: setting.CustomProfile.Description,
|
||||
LogoUrl: setting.CustomProfile.LogoUrl,
|
||||
}
|
||||
}
|
||||
return generalSetting
|
||||
}
|
||||
|
||||
func convertInstanceGeneralSettingToStore(setting *v1pb.InstanceSetting_GeneralSetting) *storepb.InstanceGeneralSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
generalSetting := &storepb.InstanceGeneralSetting{
|
||||
DisallowUserRegistration: setting.DisallowUserRegistration,
|
||||
DisallowPasswordAuth: setting.DisallowPasswordAuth,
|
||||
AdditionalScript: setting.AdditionalScript,
|
||||
AdditionalStyle: setting.AdditionalStyle,
|
||||
WeekStartDayOffset: setting.WeekStartDayOffset,
|
||||
DisallowChangeUsername: setting.DisallowChangeUsername,
|
||||
DisallowChangeNickname: setting.DisallowChangeNickname,
|
||||
}
|
||||
if setting.CustomProfile != nil {
|
||||
generalSetting.CustomProfile = &storepb.InstanceCustomProfile{
|
||||
Title: setting.CustomProfile.Title,
|
||||
Description: setting.CustomProfile.Description,
|
||||
LogoUrl: setting.CustomProfile.LogoUrl,
|
||||
}
|
||||
}
|
||||
return generalSetting
|
||||
}
|
||||
|
||||
func convertInstanceStorageSettingFromStore(settingpb *storepb.InstanceStorageSetting) *v1pb.InstanceSetting_StorageSetting {
|
||||
if settingpb == nil {
|
||||
return nil
|
||||
}
|
||||
setting := &v1pb.InstanceSetting_StorageSetting{
|
||||
StorageType: v1pb.InstanceSetting_StorageSetting_StorageType(settingpb.StorageType),
|
||||
FilepathTemplate: settingpb.FilepathTemplate,
|
||||
UploadSizeLimitMb: settingpb.UploadSizeLimitMb,
|
||||
}
|
||||
if settingpb.S3Config != nil {
|
||||
setting.S3Config = &v1pb.InstanceSetting_StorageSetting_S3Config{
|
||||
AccessKeyId: settingpb.S3Config.AccessKeyId,
|
||||
AccessKeySecret: settingpb.S3Config.AccessKeySecret,
|
||||
Endpoint: settingpb.S3Config.Endpoint,
|
||||
Region: settingpb.S3Config.Region,
|
||||
Bucket: settingpb.S3Config.Bucket,
|
||||
UsePathStyle: settingpb.S3Config.UsePathStyle,
|
||||
}
|
||||
}
|
||||
return setting
|
||||
}
|
||||
|
||||
func convertInstanceStorageSettingToStore(setting *v1pb.InstanceSetting_StorageSetting) *storepb.InstanceStorageSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
settingpb := &storepb.InstanceStorageSetting{
|
||||
StorageType: storepb.InstanceStorageSetting_StorageType(setting.StorageType),
|
||||
FilepathTemplate: setting.FilepathTemplate,
|
||||
UploadSizeLimitMb: setting.UploadSizeLimitMb,
|
||||
}
|
||||
if setting.S3Config != nil {
|
||||
settingpb.S3Config = &storepb.StorageS3Config{
|
||||
AccessKeyId: setting.S3Config.AccessKeyId,
|
||||
AccessKeySecret: setting.S3Config.AccessKeySecret,
|
||||
Endpoint: setting.S3Config.Endpoint,
|
||||
Region: setting.S3Config.Region,
|
||||
Bucket: setting.S3Config.Bucket,
|
||||
UsePathStyle: setting.S3Config.UsePathStyle,
|
||||
}
|
||||
}
|
||||
return settingpb
|
||||
}
|
||||
|
||||
func convertInstanceMemoRelatedSettingFromStore(setting *storepb.InstanceMemoRelatedSetting) *v1pb.InstanceSetting_MemoRelatedSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
return &v1pb.InstanceSetting_MemoRelatedSetting{
|
||||
DisallowPublicVisibility: setting.DisallowPublicVisibility,
|
||||
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
|
||||
ContentLengthLimit: setting.ContentLengthLimit,
|
||||
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
|
||||
Reactions: setting.Reactions,
|
||||
}
|
||||
}
|
||||
|
||||
func convertInstanceMemoRelatedSettingToStore(setting *v1pb.InstanceSetting_MemoRelatedSetting) *storepb.InstanceMemoRelatedSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
return &storepb.InstanceMemoRelatedSetting{
|
||||
DisallowPublicVisibility: setting.DisallowPublicVisibility,
|
||||
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
|
||||
ContentLengthLimit: setting.ContentLengthLimit,
|
||||
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
|
||||
Reactions: setting.Reactions,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetInstanceAdmin(ctx context.Context) (*v1pb.User, error) {
|
||||
adminUserType := store.RoleAdmin
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Role: &adminUserType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to find admin")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
136
server/router/api/v1/memo_attachment_service.go
Normal file
136
server/router/api/v1/memo_attachment_service.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
// Delete attachments that are not in the request.
|
||||
for _, attachment := range attachments {
|
||||
found := false
|
||||
for _, requestAttachment := range request.Attachments {
|
||||
requestAttachmentUID, err := ExtractAttachmentUIDFromName(requestAttachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
}
|
||||
if attachment.UID == requestAttachmentUID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if err = s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
|
||||
ID: int32(attachment.ID),
|
||||
MemoID: &memo.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(request.Attachments)
|
||||
// Update attachments' memo_id in the request.
|
||||
for index, attachment := range request.Attachments {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(attachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
}
|
||||
tempAttachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
if tempAttachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID)
|
||||
}
|
||||
updatedTs := time.Now().Unix() + int64(index)
|
||||
if err := s.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
|
||||
ID: tempAttachment.ID,
|
||||
MemoID: &memo.ID,
|
||||
UpdatedTs: &updatedTs,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.ListMemoAttachmentsRequest) (*v1pb.ListMemoAttachmentsResponse, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Check memo visibility.
|
||||
if memo.Visibility != store.Public {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoAttachmentsResponse{
|
||||
Attachments: []*v1pb.Attachment{},
|
||||
}
|
||||
for _, attachment := range attachments {
|
||||
response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
181
server/router/api/v1/memo_relation_service.go
Normal file
181
server/router/api/v1/memo_relation_service.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
referenceType := store.MemoRelationReference
|
||||
// Delete all reference relations first.
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &memo.ID,
|
||||
Type: &referenceType,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
|
||||
}
|
||||
|
||||
for _, relation := range request.Relations {
|
||||
// Ignore reflexive relations.
|
||||
if request.Name == relation.RelatedMemo.Name {
|
||||
continue
|
||||
}
|
||||
// Ignore comment relations as there's no need to update a comment's relation.
|
||||
// Inserting/Deleting a comment is handled elsewhere.
|
||||
if relation.Type == v1pb.MemoRelation_COMMENT {
|
||||
continue
|
||||
}
|
||||
relatedMemoUID, err := ExtractMemoUIDFromName(relation.RelatedMemo.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &relatedMemoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get related memo")
|
||||
}
|
||||
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: convertMemoRelationTypeToStore(relation.Type),
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
var memoFilter string
|
||||
if currentUser == nil {
|
||||
memoFilter = `visibility == "PUBLIC"`
|
||||
} else {
|
||||
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
}
|
||||
relationList := []*v1pb.MemoRelation{}
|
||||
tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memo.ID,
|
||||
MemoFilter: &memoFilter,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memo relations: %v", err)
|
||||
}
|
||||
for _, raw := range tempList {
|
||||
relation, err := s.convertMemoRelationFromStore(ctx, raw)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
|
||||
}
|
||||
relationList = append(relationList, relation)
|
||||
}
|
||||
tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &memo.ID,
|
||||
MemoFilter: &memoFilter,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list related memo relations: %v", err)
|
||||
}
|
||||
for _, raw := range tempList {
|
||||
relation, err := s.convertMemoRelationFromStore(ctx, raw)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
|
||||
}
|
||||
relationList = append(relationList, relation)
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoRelationsResponse{
|
||||
Relations: relationList,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertMemoRelationFromStore(ctx context.Context, memoRelation *store.MemoRelation) (*v1pb.MemoRelation, error) {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.MemoID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
memoSnippet, err := s.getMemoContentSnippet(memo.Content)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo content snippet")
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.RelatedMemoID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get related memo: %v", err)
|
||||
}
|
||||
relatedMemoSnippet, err := s.getMemoContentSnippet(relatedMemo.Content)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get related memo content snippet")
|
||||
}
|
||||
return &v1pb.MemoRelation{
|
||||
Memo: &v1pb.MemoRelation_Memo{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
|
||||
Snippet: memoSnippet,
|
||||
},
|
||||
RelatedMemo: &v1pb.MemoRelation_Memo{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
|
||||
Snippet: relatedMemoSnippet,
|
||||
},
|
||||
Type: convertMemoRelationTypeFromStore(memoRelation.Type),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertMemoRelationTypeFromStore(relationType store.MemoRelationType) v1pb.MemoRelation_Type {
|
||||
switch relationType {
|
||||
case store.MemoRelationReference:
|
||||
return v1pb.MemoRelation_REFERENCE
|
||||
case store.MemoRelationComment:
|
||||
return v1pb.MemoRelation_COMMENT
|
||||
default:
|
||||
return v1pb.MemoRelation_TYPE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertMemoRelationTypeToStore(relationType v1pb.MemoRelation_Type) store.MemoRelationType {
|
||||
switch relationType {
|
||||
case v1pb.MemoRelation_COMMENT:
|
||||
return store.MemoRelationComment
|
||||
default:
|
||||
return store.MemoRelationReference
|
||||
}
|
||||
}
|
||||
866
server/router/api/v1/memo_service.go
Normal file
866
server/router/api/v1/memo_service.go
Normal file
@@ -0,0 +1,866 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/usememos/memos/internal/base"
|
||||
"github.com/usememos/memos/plugin/webhook"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/runner/memopayload"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Use custom memo_id if provided, otherwise generate a new UUID
|
||||
memoUID := strings.TrimSpace(request.MemoId)
|
||||
if memoUID == "" {
|
||||
memoUID = shortuuid.New()
|
||||
} else if !base.UIDMatcher.MatchString(memoUID) {
|
||||
// Validate custom memo ID format
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo_id format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
|
||||
}
|
||||
|
||||
create := &store.Memo{
|
||||
UID: memoUID,
|
||||
CreatorID: user.ID,
|
||||
Content: request.Memo.Content,
|
||||
Visibility: convertVisibilityToStore(request.Memo.Visibility),
|
||||
}
|
||||
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
|
||||
}
|
||||
|
||||
// Handle display_time first: if provided, use it to set the appropriate timestamp
|
||||
// based on the instance setting (similar to UpdateMemo logic)
|
||||
// Note: explicit create_time/update_time below will override this if provided
|
||||
if request.Memo.DisplayTime != nil && request.Memo.DisplayTime.IsValid() {
|
||||
displayTs := request.Memo.DisplayTime.AsTime().Unix()
|
||||
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
create.UpdatedTs = displayTs
|
||||
} else {
|
||||
create.CreatedTs = displayTs
|
||||
}
|
||||
}
|
||||
|
||||
// Set custom timestamps if provided in the request
|
||||
// These take precedence over display_time
|
||||
if request.Memo.CreateTime != nil && request.Memo.CreateTime.IsValid() {
|
||||
createdTs := request.Memo.CreateTime.AsTime().Unix()
|
||||
create.CreatedTs = createdTs
|
||||
}
|
||||
if request.Memo.UpdateTime != nil && request.Memo.UpdateTime.IsValid() {
|
||||
updatedTs := request.Memo.UpdateTime.AsTime().Unix()
|
||||
create.UpdatedTs = updatedTs
|
||||
}
|
||||
|
||||
if instanceMemoRelatedSetting.DisallowPublicVisibility && create.Visibility == store.Public {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
|
||||
}
|
||||
contentLengthLimit, err := s.getContentLengthLimit(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
|
||||
}
|
||||
if len(create.Content) > contentLengthLimit {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
|
||||
}
|
||||
if err := memopayload.RebuildMemoPayload(create, s.MarkdownService); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
|
||||
}
|
||||
if request.Memo.Location != nil {
|
||||
create.Payload.Location = convertLocationToStore(request.Memo.Location)
|
||||
}
|
||||
|
||||
memo, err := s.Store.CreateMemo(ctx, create)
|
||||
if err != nil {
|
||||
// Check for unique constraint violation (AIP-133 compliance)
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "UNIQUE constraint failed") ||
|
||||
strings.Contains(errMsg, "duplicate key") ||
|
||||
strings.Contains(errMsg, "Duplicate entry") {
|
||||
return nil, status.Errorf(codes.AlreadyExists, "memo with ID %q already exists", memoUID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attachments := []*store.Attachment{}
|
||||
|
||||
if len(request.Memo.Attachments) > 0 {
|
||||
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
|
||||
Attachments: request.Memo.Attachments,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo attachments")
|
||||
}
|
||||
|
||||
a, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo attachments")
|
||||
}
|
||||
attachments = a
|
||||
}
|
||||
if len(request.Memo.Relations) > 0 {
|
||||
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
|
||||
Relations: request.Memo.Relations,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo relations")
|
||||
}
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
// Try to dispatch webhook when memo is created.
|
||||
if err := s.DispatchMemoCreatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosRequest) (*v1pb.ListMemosResponse, error) {
|
||||
memoFind := &store.FindMemo{
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
}
|
||||
if request.State == v1pb.State_ARCHIVED {
|
||||
state := store.Archived
|
||||
memoFind.RowStatus = &state
|
||||
} else {
|
||||
state := store.Normal
|
||||
memoFind.RowStatus = &state
|
||||
}
|
||||
|
||||
// Parse order_by field (replaces the old sort and direction fields)
|
||||
if request.OrderBy != "" {
|
||||
if err := s.parseMemoOrderBy(request.OrderBy, memoFind); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid order_by: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Default ordering by display_time desc
|
||||
memoFind.OrderByTimeAsc = false
|
||||
}
|
||||
|
||||
if request.Filter != "" {
|
||||
if err := s.validateFilter(ctx, request.Filter); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
memoFind.Filters = append(memoFind.Filters, request.Filter)
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if currentUser == nil {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
if memoFind.CreatorID == nil {
|
||||
filter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
memoFind.Filters = append(memoFind.Filters, filter)
|
||||
} else if *memoFind.CreatorID != currentUser.ID {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
|
||||
}
|
||||
}
|
||||
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
|
||||
}
|
||||
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
memoFind.OrderByUpdatedTs = true
|
||||
}
|
||||
|
||||
var limit, offset int
|
||||
if request.PageToken != "" {
|
||||
var pageToken v1pb.PageToken
|
||||
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
|
||||
}
|
||||
limit = int(pageToken.Limit)
|
||||
offset = int(pageToken.Offset)
|
||||
} else {
|
||||
limit = int(request.PageSize)
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = DefaultPageSize
|
||||
}
|
||||
limitPlusOne := limit + 1
|
||||
memoFind.Limit = &limitPlusOne
|
||||
memoFind.Offset = &offset
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
memoMessages := []*v1pb.Memo{}
|
||||
nextPageToken := ""
|
||||
if len(memos) == limitPlusOne {
|
||||
memos = memos[:limit]
|
||||
nextPageToken, err = getPageToken(limit, offset+limit)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(memos) == 0 {
|
||||
response := &v1pb.ListMemosResponse{
|
||||
Memos: memoMessages,
|
||||
NextPageToken: nextPageToken,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
reactionMap := make(map[string][]*store.Reaction)
|
||||
contentIDs := make([]string, 0, len(memos))
|
||||
|
||||
attachmentMap := make(map[int32][]*store.Attachment)
|
||||
memoIDs := make([]int32, 0, len(memos))
|
||||
|
||||
for _, m := range memos {
|
||||
contentIDs = append(contentIDs, fmt.Sprintf("%s%s", MemoNamePrefix, m.UID))
|
||||
memoIDs = append(memoIDs, m.ID)
|
||||
}
|
||||
|
||||
// REACTIONS
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
for _, reaction := range reactions {
|
||||
reactionMap[reaction.ContentID] = append(reactionMap[reaction.ContentID], reaction)
|
||||
}
|
||||
|
||||
// ATTACHMENTS
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDs})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
for _, attachment := range attachments {
|
||||
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
|
||||
}
|
||||
|
||||
for _, memo := range memos {
|
||||
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
|
||||
reactions := reactionMap[memoName]
|
||||
attachments := attachmentMap[memo.ID]
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
||||
memoMessages = append(memoMessages, memoMessage)
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemosResponse{
|
||||
Memos: memoMessages,
|
||||
NextPageToken: nextPageToken,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest) (*v1pb.Memo, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
UID: &memoUID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
if memo.Visibility != store.Public {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoRequest) (*v1pb.Memo, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Memo.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
// Only the creator or admin can update the memo.
|
||||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
update := &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
}
|
||||
for _, path := range request.UpdateMask.Paths {
|
||||
if path == "content" {
|
||||
contentLengthLimit, err := s.getContentLengthLimit(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
|
||||
}
|
||||
if len(request.Memo.Content) > contentLengthLimit {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
|
||||
}
|
||||
memo.Content = request.Memo.Content
|
||||
if err := memopayload.RebuildMemoPayload(memo, s.MarkdownService); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
|
||||
}
|
||||
update.Content = &memo.Content
|
||||
update.Payload = memo.Payload
|
||||
} else if path == "visibility" {
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
|
||||
}
|
||||
visibility := convertVisibilityToStore(request.Memo.Visibility)
|
||||
if instanceMemoRelatedSetting.DisallowPublicVisibility && visibility == store.Public {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
|
||||
}
|
||||
update.Visibility = &visibility
|
||||
} else if path == "pinned" {
|
||||
update.Pinned = &request.Memo.Pinned
|
||||
} else if path == "state" {
|
||||
rowStatus := convertStateToStore(request.Memo.State)
|
||||
update.RowStatus = &rowStatus
|
||||
} else if path == "create_time" {
|
||||
createdTs := request.Memo.CreateTime.AsTime().Unix()
|
||||
update.CreatedTs = &createdTs
|
||||
} else if path == "update_time" {
|
||||
updatedTs := time.Now().Unix()
|
||||
if request.Memo.UpdateTime != nil {
|
||||
updatedTs = request.Memo.UpdateTime.AsTime().Unix()
|
||||
}
|
||||
update.UpdatedTs = &updatedTs
|
||||
} else if path == "display_time" {
|
||||
displayTs := request.Memo.DisplayTime.AsTime().Unix()
|
||||
memoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
|
||||
}
|
||||
if memoRelatedSetting.DisplayWithUpdateTime {
|
||||
update.UpdatedTs = &displayTs
|
||||
} else {
|
||||
update.CreatedTs = &displayTs
|
||||
}
|
||||
} else if path == "location" {
|
||||
payload := memo.Payload
|
||||
payload.Location = convertLocationToStore(request.Memo.Location)
|
||||
update.Payload = payload
|
||||
} else if path == "attachments" {
|
||||
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
|
||||
Name: request.Memo.Name,
|
||||
Attachments: request.Memo.Attachments,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo attachments")
|
||||
}
|
||||
} else if path == "relations" {
|
||||
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
|
||||
Name: request.Memo.Name,
|
||||
Relations: request.Memo.Relations,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo relations")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = s.Store.UpdateMemo(ctx, update); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update memo")
|
||||
}
|
||||
|
||||
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Memo.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
// Try to dispatch webhook when memo is updated.
|
||||
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoRequest) (*emptypb.Empty, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
UID: &memoUID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
// Only the creator or admin can update the memo.
|
||||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments); err == nil {
|
||||
// Try to dispatch webhook when memo is deleted.
|
||||
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Delete memo comments first (store.DeleteMemo handles their relations and attachments)
|
||||
commentType := store.MemoRelationComment
|
||||
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{RelatedMemoID: &memo.ID, Type: &commentType})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memo comments")
|
||||
}
|
||||
for _, relation := range relations {
|
||||
if err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: relation.MemoID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo comment")
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the memo (store.DeleteMemo handles relation and attachment cleanup)
|
||||
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo")
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.CreateMemoCommentRequest) (*v1pb.Memo, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if relatedMemo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Check memo visibility before allowing comment.
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if relatedMemo.Visibility == store.Private && relatedMemo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Create the memo comment first.
|
||||
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{
|
||||
Memo: request.Comment,
|
||||
MemoId: request.CommentId,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create memo")
|
||||
}
|
||||
memoUID, err = ExtractMemoUIDFromName(memoComment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
// Build the relation between the comment memo and the original memo.
|
||||
_, err = s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
|
||||
}
|
||||
creatorID, err := ExtractUserIDFromName(memoComment.Creator)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo creator")
|
||||
}
|
||||
if memoComment.Visibility != v1pb.Visibility_PRIVATE && creatorID != relatedMemo.CreatorID {
|
||||
activity, err := s.Store.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: creatorID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{
|
||||
MemoComment: &storepb.ActivityMemoCommentPayload{
|
||||
MemoId: memo.ID,
|
||||
RelatedMemoId: relatedMemo.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create activity")
|
||||
}
|
||||
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: creatorID,
|
||||
ReceiverID: relatedMemo.CreatorID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
ActivityId: &activity.ID,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create inbox")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.DispatchMemoCommentCreatedWebhook(ctx, memoComment, relatedMemo.CreatorID); err != nil {
|
||||
slog.Warn("Failed to dispatch memo comment created webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
return memoComment, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListMemoCommentsRequest) (*v1pb.ListMemoCommentsResponse, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
var memoFilter string
|
||||
if currentUser == nil {
|
||||
memoFilter = `visibility == "PUBLIC"`
|
||||
} else {
|
||||
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
}
|
||||
memoRelationComment := store.MemoRelationComment
|
||||
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &memo.ID,
|
||||
Type: &memoRelationComment,
|
||||
MemoFilter: &memoFilter,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
|
||||
}
|
||||
|
||||
if len(memoRelations) == 0 {
|
||||
response := &v1pb.ListMemoCommentsResponse{
|
||||
Memos: []*v1pb.Memo{},
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
memoRelationIDs := make([]int32, 0, len(memoRelations))
|
||||
for _, m := range memoRelations {
|
||||
memoRelationIDs = append(memoRelationIDs, m.MemoID)
|
||||
}
|
||||
memos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: memoRelationIDs})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos")
|
||||
}
|
||||
|
||||
memoIDToNameMap := make(map[int32]string)
|
||||
contentIDs := make([]string, 0, len(memos))
|
||||
memoIDsForAttachments := make([]int32, 0, len(memos))
|
||||
|
||||
for _, memo := range memos {
|
||||
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
|
||||
memoIDToNameMap[memo.ID] = memoName
|
||||
contentIDs = append(contentIDs, memoName)
|
||||
memoIDsForAttachments = append(memoIDsForAttachments, memo.ID)
|
||||
}
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
|
||||
memoReactionsMap := make(map[string][]*store.Reaction)
|
||||
for _, reaction := range reactions {
|
||||
memoReactionsMap[reaction.ContentID] = append(memoReactionsMap[reaction.ContentID], reaction)
|
||||
}
|
||||
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDsForAttachments})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
attachmentMap := make(map[int32][]*store.Attachment)
|
||||
for _, attachment := range attachments {
|
||||
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
|
||||
}
|
||||
|
||||
var memosResponse []*v1pb.Memo
|
||||
for _, m := range memos {
|
||||
memoName := memoIDToNameMap[m.ID]
|
||||
reactions := memoReactionsMap[memoName]
|
||||
attachments := attachmentMap[m.ID]
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
memosResponse = append(memosResponse, memoMessage)
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoCommentsResponse{
|
||||
Memos: memosResponse,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) getContentLengthLimit(ctx context.Context) (int, error) {
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return 0, status.Errorf(codes.Internal, "failed to get instance memo related setting")
|
||||
}
|
||||
return int(instanceMemoRelatedSetting.ContentLengthLimit), nil
|
||||
}
|
||||
|
||||
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
|
||||
func (s *APIV1Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")
|
||||
}
|
||||
|
||||
// DispatchMemoUpdatedWebhook dispatches webhook when memo is updated.
|
||||
func (s *APIV1Service) DispatchMemoUpdatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.updated")
|
||||
}
|
||||
|
||||
// DispatchMemoDeletedWebhook dispatches webhook when memo is deleted.
|
||||
func (s *APIV1Service) DispatchMemoDeletedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.deleted")
|
||||
}
|
||||
|
||||
// DispatchMemoCommentCreatedWebhook dispatches webhook to the related memo owner when a comment is created.
|
||||
func (s *APIV1Service) DispatchMemoCommentCreatedWebhook(ctx context.Context, commentMemo *v1pb.Memo, relatedMemoCreatorID int32) error {
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, relatedMemoCreatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, hook := range webhooks {
|
||||
payload, err := convertMemoToWebhookPayload(commentMemo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to convert memo to webhook payload")
|
||||
}
|
||||
payload.ActivityType = "memos.memo.comment.created"
|
||||
payload.URL = hook.Url
|
||||
webhook.PostAsync(payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1pb.Memo, activityType string) error {
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid memo creator")
|
||||
}
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, creatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, hook := range webhooks {
|
||||
payload, err := convertMemoToWebhookPayload(memo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to convert memo to webhook payload")
|
||||
}
|
||||
payload.ActivityType = activityType
|
||||
payload.URL = hook.Url
|
||||
|
||||
// Use asynchronous webhook dispatch
|
||||
webhook.PostAsync(payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookRequestPayload, error) {
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid memo creator")
|
||||
}
|
||||
return &webhook.WebhookRequestPayload{
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creatorID),
|
||||
Memo: memo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) getMemoContentSnippet(content string) (string, error) {
|
||||
// Use goldmark service for snippet generation
|
||||
snippet, err := s.MarkdownService.GenerateSnippet([]byte(content), 64)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to generate snippet")
|
||||
}
|
||||
return snippet, nil
|
||||
}
|
||||
|
||||
// parseMemoOrderBy parses the order_by field and sets the appropriate ordering in memoFind.
|
||||
// Follows AIP-132: supports comma-separated list of fields with optional "desc" suffix.
|
||||
// Example: "pinned desc, display_time desc" or "create_time asc".
|
||||
func (*APIV1Service) parseMemoOrderBy(orderBy string, memoFind *store.FindMemo) error {
|
||||
if strings.TrimSpace(orderBy) == "" {
|
||||
return errors.New("empty order_by")
|
||||
}
|
||||
|
||||
// Split by comma to support multiple sort fields per AIP-132.
|
||||
fields := strings.Split(orderBy, ",")
|
||||
|
||||
// Track if we've seen pinned field.
|
||||
hasPinned := false
|
||||
|
||||
for _, field := range fields {
|
||||
parts := strings.Fields(strings.TrimSpace(field))
|
||||
if len(parts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := parts[0]
|
||||
fieldDirection := "desc" // default per AIP-132 (we use desc as default for time fields)
|
||||
if len(parts) > 1 {
|
||||
fieldDirection = strings.ToLower(parts[1])
|
||||
if fieldDirection != "asc" && fieldDirection != "desc" {
|
||||
return errors.Errorf("invalid order direction: %s, must be 'asc' or 'desc'", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
switch fieldName {
|
||||
case "pinned":
|
||||
hasPinned = true
|
||||
memoFind.OrderByPinned = true
|
||||
// Note: pinned is always DESC (true first) regardless of direction specified.
|
||||
case "display_time", "create_time", "name":
|
||||
// Only set if this is the first time field we encounter.
|
||||
if !memoFind.OrderByUpdatedTs {
|
||||
memoFind.OrderByTimeAsc = fieldDirection == "asc"
|
||||
}
|
||||
case "update_time":
|
||||
memoFind.OrderByUpdatedTs = true
|
||||
memoFind.OrderByTimeAsc = fieldDirection == "asc"
|
||||
default:
|
||||
return errors.Errorf("unsupported order field: %s, supported fields are: pinned, display_time, create_time, update_time, name", fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// If only pinned was specified, still need to set a default time ordering.
|
||||
if hasPinned && !memoFind.OrderByUpdatedTs && len(fields) == 1 {
|
||||
memoFind.OrderByTimeAsc = false // default to desc
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
134
server/router/api/v1/memo_service_converter.go
Normal file
134
server/router/api/v1/memo_service_converter.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment) (*v1pb.Memo, error) {
|
||||
displayTs := memo.CreatedTs
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get instance memo related setting")
|
||||
}
|
||||
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
displayTs = memo.UpdatedTs
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
|
||||
memoMessage := &v1pb.Memo{
|
||||
Name: name,
|
||||
State: convertStateFromStore(memo.RowStatus),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, memo.CreatorID),
|
||||
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
|
||||
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
|
||||
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
|
||||
Content: memo.Content,
|
||||
Visibility: convertVisibilityFromStore(memo.Visibility),
|
||||
Pinned: memo.Pinned,
|
||||
}
|
||||
if memo.Payload != nil {
|
||||
memoMessage.Tags = memo.Payload.Tags
|
||||
memoMessage.Property = convertMemoPropertyFromStore(memo.Payload.Property)
|
||||
memoMessage.Location = convertLocationFromStore(memo.Payload.Location)
|
||||
}
|
||||
|
||||
if memo.ParentUID != nil {
|
||||
parentName := fmt.Sprintf("%s%s", MemoNamePrefix, *memo.ParentUID)
|
||||
memoMessage.Parent = &parentName
|
||||
}
|
||||
|
||||
memoMessage.Reactions = []*v1pb.Reaction{}
|
||||
|
||||
for _, reaction := range reactions {
|
||||
reactionResponse := convertReactionFromStore(reaction)
|
||||
memoMessage.Reactions = append(memoMessage.Reactions, reactionResponse)
|
||||
}
|
||||
|
||||
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo relations")
|
||||
}
|
||||
memoMessage.Relations = listMemoRelationsResponse.Relations
|
||||
|
||||
memoMessage.Attachments = []*v1pb.Attachment{}
|
||||
|
||||
for _, attachment := range attachments {
|
||||
attachmentResponse := convertAttachmentFromStore(attachment)
|
||||
memoMessage.Attachments = append(memoMessage.Attachments, attachmentResponse)
|
||||
}
|
||||
|
||||
snippet, err := s.getMemoContentSnippet(memo.Content)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo content snippet")
|
||||
}
|
||||
memoMessage.Snippet = snippet
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property {
|
||||
if property == nil {
|
||||
return nil
|
||||
}
|
||||
return &v1pb.Memo_Property{
|
||||
HasLink: property.HasLink,
|
||||
HasTaskList: property.HasTaskList,
|
||||
HasCode: property.HasCode,
|
||||
HasIncompleteTasks: property.HasIncompleteTasks,
|
||||
}
|
||||
}
|
||||
|
||||
func convertLocationFromStore(location *storepb.MemoPayload_Location) *v1pb.Location {
|
||||
if location == nil {
|
||||
return nil
|
||||
}
|
||||
return &v1pb.Location{
|
||||
Placeholder: location.Placeholder,
|
||||
Latitude: location.Latitude,
|
||||
Longitude: location.Longitude,
|
||||
}
|
||||
}
|
||||
|
||||
func convertLocationToStore(location *v1pb.Location) *storepb.MemoPayload_Location {
|
||||
if location == nil {
|
||||
return nil
|
||||
}
|
||||
return &storepb.MemoPayload_Location{
|
||||
Placeholder: location.Placeholder,
|
||||
Latitude: location.Latitude,
|
||||
Longitude: location.Longitude,
|
||||
}
|
||||
}
|
||||
|
||||
func convertVisibilityFromStore(visibility store.Visibility) v1pb.Visibility {
|
||||
switch visibility {
|
||||
case store.Private:
|
||||
return v1pb.Visibility_PRIVATE
|
||||
case store.Protected:
|
||||
return v1pb.Visibility_PROTECTED
|
||||
case store.Public:
|
||||
return v1pb.Visibility_PUBLIC
|
||||
default:
|
||||
return v1pb.Visibility_VISIBILITY_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertVisibilityToStore(visibility v1pb.Visibility) store.Visibility {
|
||||
switch visibility {
|
||||
case v1pb.Visibility_PROTECTED:
|
||||
return store.Protected
|
||||
case v1pb.Visibility_PUBLIC:
|
||||
return store.Public
|
||||
default:
|
||||
return store.Private
|
||||
}
|
||||
}
|
||||
1
server/router/api/v1/memo_service_filter.go
Normal file
1
server/router/api/v1/memo_service_filter.go
Normal file
@@ -0,0 +1 @@
|
||||
package v1
|
||||
153
server/router/api/v1/reaction_service.go
Normal file
153
server/router/api/v1/reaction_service.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
|
||||
// Extract memo UID and check visibility.
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Check memo visibility.
|
||||
if memo.Visibility != store.Public {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoReactionsResponse{
|
||||
Reactions: []*v1pb.Reaction{},
|
||||
}
|
||||
for _, reaction := range reactions {
|
||||
reactionMessage := convertReactionFromStore(reaction)
|
||||
response.Reactions = append(response.Reactions, reactionMessage)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.UpsertMemoReactionRequest) (*v1pb.Reaction, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Extract memo UID and check visibility before allowing reaction.
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Reaction.ContentId)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Check memo visibility.
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: request.Reaction.ContentId,
|
||||
ReactionType: request.Reaction.ReactionType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
|
||||
}
|
||||
|
||||
reactionMessage := convertReactionFromStore(reaction)
|
||||
|
||||
return reactionMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
_, reactionID, err := ExtractMemoReactionIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
|
||||
}
|
||||
|
||||
// Get reaction and check ownership.
|
||||
reaction, err := s.Store.GetReaction(ctx, &store.FindReaction{
|
||||
ID: &reactionID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get reaction")
|
||||
}
|
||||
if reaction == nil {
|
||||
// Return permission denied to avoid revealing if reaction exists.
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if reaction.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
|
||||
ID: reactionID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertReactionFromStore(reaction *store.Reaction) *v1pb.Reaction {
|
||||
reactionUID := fmt.Sprintf("%d", reaction.ID)
|
||||
// Generate nested resource name: memos/{memo}/reactions/{reaction}
|
||||
// reaction.ContentID already contains "memos/{memo}"
|
||||
return &v1pb.Reaction{
|
||||
Name: fmt.Sprintf("%s/%s%s", reaction.ContentID, ReactionNamePrefix, reactionUID),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, reaction.CreatorID),
|
||||
ContentId: reaction.ContentID,
|
||||
ReactionType: reaction.ReactionType,
|
||||
CreateTime: timestamppb.New(time.Unix(reaction.CreatedTs, 0)),
|
||||
}
|
||||
}
|
||||
158
server/router/api/v1/resource_name.go
Normal file
158
server/router/api/v1/resource_name.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
InstanceSettingNamePrefix = "instance/settings/"
|
||||
UserNamePrefix = "users/"
|
||||
MemoNamePrefix = "memos/"
|
||||
AttachmentNamePrefix = "attachments/"
|
||||
ReactionNamePrefix = "reactions/"
|
||||
InboxNamePrefix = "inboxes/"
|
||||
IdentityProviderNamePrefix = "identity-providers/"
|
||||
ActivityNamePrefix = "activities/"
|
||||
WebhookNamePrefix = "webhooks/"
|
||||
)
|
||||
|
||||
// GetNameParentTokens returns the tokens from a resource name.
|
||||
func GetNameParentTokens(name string, tokenPrefixes ...string) ([]string, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 2*len(tokenPrefixes) {
|
||||
return nil, errors.Errorf("invalid request %q", name)
|
||||
}
|
||||
|
||||
var tokens []string
|
||||
for i, tokenPrefix := range tokenPrefixes {
|
||||
if fmt.Sprintf("%s/", parts[2*i]) != tokenPrefix {
|
||||
return nil, errors.Errorf("invalid prefix %q in request %q", tokenPrefix, name)
|
||||
}
|
||||
if parts[2*i+1] == "" {
|
||||
return nil, errors.Errorf("invalid request %q with empty prefix %q", name, tokenPrefix)
|
||||
}
|
||||
tokens = append(tokens, parts[2*i+1])
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func ExtractInstanceSettingKeyFromName(name string) (string, error) {
|
||||
const prefix = "instance/settings/"
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return "", errors.Errorf("invalid instance setting name: expected prefix %q, got %q", prefix, name)
|
||||
}
|
||||
|
||||
settingKey := strings.TrimPrefix(name, prefix)
|
||||
if settingKey == "" {
|
||||
return "", errors.Errorf("invalid instance setting name: empty setting key in %q", name)
|
||||
}
|
||||
|
||||
// Ensure there are no additional path segments
|
||||
if strings.Contains(settingKey, "/") {
|
||||
return "", errors.Errorf("invalid instance setting name: setting key cannot contain '/' in %q", name)
|
||||
}
|
||||
|
||||
return settingKey, nil
|
||||
}
|
||||
|
||||
// ExtractUserIDFromName returns the uid from a resource name.
|
||||
func ExtractUserIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, UserNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid user ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// extractUserIdentifierFromName extracts the identifier (ID or username) from a user resource name.
|
||||
// Supports: "users/101" or "users/steven"
|
||||
// Returns the identifier string (e.g., "101" or "steven").
|
||||
func extractUserIdentifierFromName(name string) string {
|
||||
tokens, err := GetNameParentTokens(name, UserNamePrefix)
|
||||
if err != nil || len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
return tokens[0]
|
||||
}
|
||||
|
||||
// ExtractMemoUIDFromName returns the memo UID from a resource name.
|
||||
// e.g., "memos/uuid" -> "uuid".
|
||||
func ExtractMemoUIDFromName(name string) (string, error) {
|
||||
tokens, err := GetNameParentTokens(name, MemoNamePrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
id := tokens[0]
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractAttachmentUIDFromName returns the attachment UID from a resource name.
|
||||
func ExtractAttachmentUIDFromName(name string) (string, error) {
|
||||
tokens, err := GetNameParentTokens(name, AttachmentNamePrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
id := tokens[0]
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractMemoReactionIDFromName returns the memo UID and reaction ID from a resource name.
|
||||
// e.g., "memos/abc/reactions/123" -> ("abc", 123).
|
||||
func ExtractMemoReactionIDFromName(name string) (string, int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, MemoNamePrefix, ReactionNamePrefix)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
memoUID := tokens[0]
|
||||
reactionID, err := util.ConvertStringToInt32(tokens[1])
|
||||
if err != nil {
|
||||
return "", 0, errors.Errorf("invalid reaction ID %q", tokens[1])
|
||||
}
|
||||
return memoUID, reactionID, nil
|
||||
}
|
||||
|
||||
// ExtractInboxIDFromName returns the inbox ID from a resource name.
|
||||
func ExtractInboxIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, InboxNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid inbox ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func ExtractActivityIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, ActivityNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid activity ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
346
server/router/api/v1/shortcut_service.go
Normal file
346
server/router/api/v1/shortcut_service.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// Helper function to extract user ID and shortcut ID from shortcut resource name.
|
||||
// Format: users/{user}/shortcuts/{shortcut}.
|
||||
func extractUserAndShortcutIDFromName(name string) (int32, string, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "shortcuts" {
|
||||
return 0, "", errors.Errorf("invalid shortcut name format: %s", name)
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(parts[1])
|
||||
if err != nil {
|
||||
return 0, "", errors.Errorf("invalid user ID %q", parts[1])
|
||||
}
|
||||
|
||||
shortcutID := parts[3]
|
||||
if shortcutID == "" {
|
||||
return 0, "", errors.Errorf("empty shortcut ID in name: %s", name)
|
||||
}
|
||||
|
||||
return userID, shortcutID, nil
|
||||
}
|
||||
|
||||
// Helper function to construct shortcut resource name.
|
||||
func constructShortcutName(userID int32, shortcutID string) string {
|
||||
return fmt.Sprintf("users/%d/shortcuts/%s", userID, shortcutID)
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListShortcuts(ctx context.Context, request *v1pb.ListShortcutsRequest) (*v1pb.ListShortcutsResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user setting: %v", err)
|
||||
}
|
||||
if userSetting == nil {
|
||||
return &v1pb.ListShortcutsResponse{
|
||||
Shortcuts: []*v1pb.Shortcut{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
shortcuts := []*v1pb.Shortcut{}
|
||||
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
|
||||
shortcuts = append(shortcuts, &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, shortcut.GetId()),
|
||||
Title: shortcut.GetTitle(),
|
||||
Filter: shortcut.GetFilter(),
|
||||
})
|
||||
}
|
||||
|
||||
return &v1pb.ListShortcutsResponse{
|
||||
Shortcuts: shortcuts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetShortcut(ctx context.Context, request *v1pb.GetShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
|
||||
if shortcut.GetId() == shortcutID {
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, shortcut.GetId()),
|
||||
Title: shortcut.GetTitle(),
|
||||
Filter: shortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
newShortcut := &storepb.ShortcutsUserSetting_Shortcut{
|
||||
Id: util.GenUUID(),
|
||||
Title: request.Shortcut.GetTitle(),
|
||||
Filter: request.Shortcut.GetFilter(),
|
||||
}
|
||||
if newShortcut.Title == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "title is required")
|
||||
}
|
||||
if err := s.validateFilter(ctx, newShortcut.Filter); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
if request.ValidateOnly {
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, newShortcut.GetId()),
|
||||
Title: newShortcut.GetTitle(),
|
||||
Filter: newShortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
userSetting = &storepb.UserSetting{
|
||||
UserId: userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{
|
||||
Shortcuts: &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
shortcuts := shortcutsUserSetting.GetShortcuts()
|
||||
shortcuts = append(shortcuts, newShortcut)
|
||||
shortcutsUserSetting.Shortcuts = shortcuts
|
||||
|
||||
userSetting.Value = &storepb.UserSetting_Shortcuts{
|
||||
Shortcuts: shortcutsUserSetting,
|
||||
}
|
||||
|
||||
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, newShortcut.GetId()),
|
||||
Title: newShortcut.GetTitle(),
|
||||
Filter: newShortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Shortcut.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
shortcuts := shortcutsUserSetting.GetShortcuts()
|
||||
var foundShortcut *storepb.ShortcutsUserSetting_Shortcut
|
||||
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
|
||||
for _, shortcut := range shortcuts {
|
||||
if shortcut.GetId() == shortcutID {
|
||||
foundShortcut = shortcut
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "title" {
|
||||
if request.Shortcut.GetTitle() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "title is required")
|
||||
}
|
||||
shortcut.Title = request.Shortcut.GetTitle()
|
||||
} else if field == "filter" {
|
||||
if err := s.validateFilter(ctx, request.Shortcut.GetFilter()); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
shortcut.Filter = request.Shortcut.GetFilter()
|
||||
}
|
||||
}
|
||||
}
|
||||
newShortcuts = append(newShortcuts, shortcut)
|
||||
}
|
||||
|
||||
if foundShortcut == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
shortcutsUserSetting.Shortcuts = newShortcuts
|
||||
userSetting.Value = &storepb.UserSetting_Shortcuts{
|
||||
Shortcuts: shortcutsUserSetting,
|
||||
}
|
||||
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, foundShortcut.GetId()),
|
||||
Title: foundShortcut.GetTitle(),
|
||||
Filter: foundShortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteShortcutRequest) (*emptypb.Empty, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
shortcuts := shortcutsUserSetting.GetShortcuts()
|
||||
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
|
||||
found := false
|
||||
for _, shortcut := range shortcuts {
|
||||
if shortcut.GetId() != shortcutID {
|
||||
newShortcuts = append(newShortcuts, shortcut)
|
||||
} else {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
shortcutsUserSetting.Shortcuts = newShortcuts
|
||||
userSetting.Value = &storepb.UserSetting_Shortcuts{
|
||||
Shortcuts: shortcutsUserSetting,
|
||||
}
|
||||
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) validateFilter(ctx context.Context, filterStr string) error {
|
||||
if filterStr == "" {
|
||||
return errors.New("filter cannot be empty")
|
||||
}
|
||||
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var dialect filter.DialectName
|
||||
switch s.Profile.Driver {
|
||||
case "mysql":
|
||||
dialect = filter.DialectMySQL
|
||||
case "postgres":
|
||||
dialect = filter.DialectPostgres
|
||||
default:
|
||||
dialect = filter.DialectSQLite
|
||||
}
|
||||
|
||||
if _, err := engine.CompileToStatement(ctx, filterStr, filter.RenderOptions{Dialect: dialect}); err != nil {
|
||||
return errors.Wrap(err, "failed to compile filter")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
263
server/router/api/v1/test/activity_deleted_memo_test.go
Normal file
263
server/router/api/v1/test/activity_deleted_memo_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
v1 "github.com/usememos/memos/server/router/api/v1" //nolint:revive
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// TestListActivitiesWithDeletedMemos verifies that ListActivities gracefully handles
|
||||
// activities that reference deleted memos instead of crashing the entire request.
|
||||
func TestListActivitiesWithDeletedMemos(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users - one to create memo, one to comment
|
||||
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
|
||||
require.NoError(t, err)
|
||||
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
|
||||
|
||||
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
|
||||
require.NoError(t, err)
|
||||
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
|
||||
|
||||
// Create a memo by userOne
|
||||
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Original memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo1)
|
||||
|
||||
// Create a comment on the memo by userTwo (this will create an activity for userOne)
|
||||
comment, err := ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo1.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "This is a comment",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, comment)
|
||||
|
||||
// Verify activity was created for the comment (check from userOne's perspective - they receive the notification)
|
||||
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
initialActivityCount := len(activities.Activities)
|
||||
require.Greater(t, initialActivityCount, 0, "Should have at least one activity")
|
||||
|
||||
// Delete the original memo (this deletes the comment too)
|
||||
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
|
||||
Name: memo1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List activities again - should succeed even though the memo is deleted
|
||||
activities, err = ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
// Activities list should be empty or not contain the deleted memo activity
|
||||
for _, activity := range activities.Activities {
|
||||
if activity.Payload != nil && activity.Payload.GetMemoComment() != nil {
|
||||
require.NotEqual(t, memo1.Name, activity.Payload.GetMemoComment().Memo,
|
||||
"Activity should not reference deleted memo")
|
||||
}
|
||||
}
|
||||
// After deletion, there should be fewer activities
|
||||
require.LessOrEqual(t, len(activities.Activities), initialActivityCount-1,
|
||||
"Should have filtered out the activity for the deleted memo")
|
||||
}
|
||||
|
||||
// TestGetActivityWithDeletedMemo verifies that GetActivity returns a proper error
|
||||
// when trying to fetch an activity that references a deleted memo.
|
||||
func TestGetActivityWithDeletedMemo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
|
||||
require.NoError(t, err)
|
||||
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
|
||||
|
||||
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
|
||||
require.NoError(t, err)
|
||||
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
|
||||
|
||||
// Create a memo by userOne
|
||||
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Original memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo1)
|
||||
|
||||
// Create a comment to trigger activity creation by userTwo
|
||||
comment, err := ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo1.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "Comment",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, comment)
|
||||
|
||||
// Get the activity ID by listing activities from userOne's perspective
|
||||
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(activities.Activities), 0)
|
||||
|
||||
activityName := activities.Activities[0].Name
|
||||
|
||||
// Delete the memo
|
||||
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
|
||||
Name: memo1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get the specific activity - should return NotFound error
|
||||
_, err = ts.Service.GetActivity(userOneCtx, &apiv1.GetActivityRequest{
|
||||
Name: activityName,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "activity references deleted content")
|
||||
}
|
||||
|
||||
// TestActivitiesWithPartiallyDeletedMemos verifies that when some memos are deleted,
|
||||
// other valid activities are still returned.
|
||||
func TestActivitiesWithPartiallyDeletedMemos(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
|
||||
require.NoError(t, err)
|
||||
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
|
||||
|
||||
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
|
||||
require.NoError(t, err)
|
||||
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
|
||||
|
||||
// Create two memos by userOne
|
||||
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "First memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memo2, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Second memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create comments on both by userTwo (creates activities for userOne)
|
||||
_, err = ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo1.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "Comment on first",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo2.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "Comment on second",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have 2 activities from userOne's perspective
|
||||
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(activities.Activities))
|
||||
|
||||
// Delete first memo
|
||||
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
|
||||
Name: memo1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List activities - should still work and return only the second memo's activity
|
||||
activities, err = ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(activities.Activities), "Should have 1 activity remaining")
|
||||
|
||||
// Verify the remaining activity relates to a valid memo
|
||||
require.NotNil(t, activities.Activities[0].Payload.GetMemoComment())
|
||||
require.Contains(t, activities.Activities[0].Payload.GetMemoComment().RelatedMemo, "memos/")
|
||||
}
|
||||
|
||||
// TestActivityStoreDirectDeletion tests the scenario where a memo is deleted directly
|
||||
// from the store (simulating database-level deletion or migration).
|
||||
func TestActivityStoreDirectDeletion(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "test-user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a memo
|
||||
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a comment
|
||||
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: memo1.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "Test comment",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract memo UID from the comment name
|
||||
commentMemoUID, err := v1.ExtractMemoUIDFromName(comment.Name)
|
||||
require.NoError(t, err)
|
||||
|
||||
commentMemo, err := ts.Store.GetMemo(ctx, &store.FindMemo{
|
||||
UID: &commentMemoUID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, commentMemo)
|
||||
|
||||
// Delete the comment memo directly from store (simulating orphaned activity)
|
||||
err = ts.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: commentMemo.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List activities should still succeed even with orphaned activity
|
||||
activities, err := ts.Service.ListActivities(userCtx, &apiv1.ListActivitiesRequest{})
|
||||
require.NoError(t, err)
|
||||
// Activities should be empty or not include the orphaned one
|
||||
for _, activity := range activities.Activities {
|
||||
if activity.Payload != nil && activity.Payload.GetMemoComment() != nil {
|
||||
require.NotEqual(t, comment.Name, activity.Payload.GetMemoComment().Memo,
|
||||
"Should not return activity with deleted memo")
|
||||
}
|
||||
}
|
||||
}
|
||||
1
server/router/api/v1/test/assets/1772542534_test.png
Normal file
1
server/router/api/v1/test/assets/1772542534_test.png
Normal file
@@ -0,0 +1 @@
|
||||
fake png content
|
||||
1
server/router/api/v1/test/assets/1772542535_test.png
Normal file
1
server/router/api/v1/test/assets/1772542535_test.png
Normal file
@@ -0,0 +1 @@
|
||||
fake png content
|
||||
2
server/router/api/v1/test/assets/1772542535_test.unknown
Normal file
2
server/router/api/v1/test/assets/1772542535_test.unknown
Normal file
@@ -0,0 +1,2 @@
|
||||
‰PNG
|
||||
|
||||
|
After Width: | Height: | Size: 8 B |
BIN
server/router/api/v1/test/assets/1772542536_test.data
Normal file
BIN
server/router/api/v1/test/assets/1772542536_test.data
Normal file
Binary file not shown.
59
server/router/api/v1/test/attachment_service_test.go
Normal file
59
server/router/api/v1/test/attachment_service_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestCreateAttachment(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "test_user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test case 1: Create attachment with empty type but known extension
|
||||
t.Run("EmptyType_KnownExtension", func(t *testing.T) {
|
||||
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "test.png",
|
||||
Content: []byte("fake png content"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "image/png", attachment.Type)
|
||||
})
|
||||
|
||||
// Test case 2: Create attachment with empty type and unknown extension, but detectable content
|
||||
t.Run("EmptyType_UnknownExtension_ContentSniffing", func(t *testing.T) {
|
||||
// PNG magic header: 89 50 4E 47 0D 0A 1A 0A
|
||||
pngContent := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
|
||||
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "test.unknown",
|
||||
Content: pngContent,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "image/png", attachment.Type)
|
||||
})
|
||||
|
||||
// Test case 3: Empty type, unknown extension, random content -> fallback to application/octet-stream
|
||||
t.Run("EmptyType_Fallback", func(t *testing.T) {
|
||||
randomContent := []byte{0x00, 0x01, 0x02, 0x03}
|
||||
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "test.data",
|
||||
Content: randomContent,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "application/octet-stream", attachment.Type)
|
||||
})
|
||||
}
|
||||
655
server/router/api/v1/test/auth_test.go
Normal file
655
server/router/api/v1/test/auth_test.go
Normal file
@@ -0,0 +1,655 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestAuthenticatorAccessTokenV2(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid access token v2", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate access token v2
|
||||
token, _, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte(ts.Secret),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
claims, err := authenticator.AuthenticateByAccessTokenV2(token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, claims)
|
||||
assert.Equal(t, user.ID, claims.UserID)
|
||||
assert.Equal(t, user.Username, claims.Username)
|
||||
assert.Equal(t, string(user.Role), claims.Role)
|
||||
assert.Equal(t, string(user.RowStatus), claims.Status)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, err := authenticator.AuthenticateByAccessTokenV2("invalid-token")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with wrong secret", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate token with one secret
|
||||
token, _, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte("secret-1"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate with different secret
|
||||
authenticator := auth.NewAuthenticator(ts.Store, "secret-2")
|
||||
_, err = authenticator.AuthenticateByAccessTokenV2(token)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthenticatorRefreshToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create refresh token record in store
|
||||
tokenID := util.GenUUID()
|
||||
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate refresh token JWT
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, returnedTokenID, err := authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.Equal(t, user.ID, authenticatedUser.ID)
|
||||
assert.Equal(t, tokenID, returnedTokenID)
|
||||
})
|
||||
|
||||
t.Run("fails with revoked token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
// Generate refresh token JWT but don't store it in database (simulates revocation)
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "revoked")
|
||||
})
|
||||
|
||||
t.Run("fails with expired token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create expired refresh token record in store
|
||||
tokenID := util.GenUUID()
|
||||
expiredToken := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, expiredToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate refresh token JWT (JWT itself isn't expired yet)
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired")
|
||||
})
|
||||
|
||||
t.Run("fails with archived user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create valid refresh token
|
||||
tokenID := util.GenUUID()
|
||||
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the user
|
||||
archivedStatus := store.Archived
|
||||
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
RowStatus: &archivedStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "archived")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthenticatorPAT(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
// Store PAT in database
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.NotNil(t, pat)
|
||||
assert.Equal(t, user.ID, authenticatedUser.ID)
|
||||
assert.Equal(t, tokenID, pat.TokenId)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid PAT format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err := authenticator.AuthenticateByPAT(ctx, "invalid-token-without-prefix")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid PAT format")
|
||||
})
|
||||
|
||||
t.Run("fails with non-existent PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Generate a PAT but don't store it
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with expired PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store expired PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
expiredPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Expired PAT",
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, expiredPAT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired")
|
||||
})
|
||||
|
||||
t.Run("succeeds with non-expiring PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store PAT without expiration
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Never-expiring PAT",
|
||||
ExpiresAt: nil, // No expiration
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.NotNil(t, pat)
|
||||
assert.Nil(t, pat.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("fails with archived user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the user
|
||||
archivedStatus := store.Archived
|
||||
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
RowStatus: &archivedStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "archived")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRefreshTokenMethods(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("adds and retrieves refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve tokens
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 1)
|
||||
assert.Equal(t, tokenID, tokens[0].TokenId)
|
||||
})
|
||||
|
||||
t.Run("retrieves specific refresh token by ID", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve specific token
|
||||
retrievedToken, err := ts.Store.GetUserRefreshTokenByID(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, retrievedToken)
|
||||
assert.Equal(t, tokenID, retrievedToken.TokenId)
|
||||
})
|
||||
|
||||
t.Run("removes refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove token
|
||||
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify removal
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 0)
|
||||
})
|
||||
|
||||
t.Run("handles multiple refresh tokens", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add multiple tokens
|
||||
tokenID1 := util.GenUUID()
|
||||
tokenID2 := util.GenUUID()
|
||||
|
||||
token1 := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID1,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
token2 := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID2,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token1)
|
||||
require.NoError(t, err)
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve all tokens
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 2)
|
||||
|
||||
// Remove one token
|
||||
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one token remains
|
||||
tokens, err = ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 1)
|
||||
assert.Equal(t, tokenID2, tokens[0].TokenId)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorePersonalAccessTokenMethods(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("adds and retrieves PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve PATs
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.Equal(t, tokenID, pats[0].TokenId)
|
||||
assert.Equal(t, tokenHash, pats[0].TokenHash)
|
||||
})
|
||||
|
||||
t.Run("removes PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove PAT
|
||||
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify removal
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
})
|
||||
|
||||
t.Run("updates PAT last used time", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update last used time
|
||||
lastUsed := timestamppb.Now()
|
||||
err = ts.Store.UpdatePATLastUsed(ctx, user.ID, tokenID, lastUsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify update
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.NotNil(t, pats[0].LastUsedAt)
|
||||
})
|
||||
|
||||
t.Run("handles multiple PATs", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add multiple PATs
|
||||
token1 := auth.GeneratePersonalAccessToken()
|
||||
tokenHash1 := auth.HashPersonalAccessToken(token1)
|
||||
tokenID1 := util.GenUUID()
|
||||
|
||||
token2 := auth.GeneratePersonalAccessToken()
|
||||
tokenHash2 := auth.HashPersonalAccessToken(token2)
|
||||
tokenID2 := util.GenUUID()
|
||||
|
||||
pat1 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID1,
|
||||
TokenHash: tokenHash1,
|
||||
Description: "PAT 1",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
pat2 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID2,
|
||||
TokenHash: tokenHash2,
|
||||
Description: "PAT 2",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat1)
|
||||
require.NoError(t, err)
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve all PATs
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 2)
|
||||
|
||||
// Remove one PAT
|
||||
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one PAT remains
|
||||
pats, err = ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.Equal(t, tokenID2, pats[0].TokenId)
|
||||
})
|
||||
|
||||
t.Run("finds user by PAT hash", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find user by PAT hash
|
||||
result, err := ts.Store.GetUserByPATHash(ctx, tokenHash)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, user.ID, result.UserID)
|
||||
assert.NotNil(t, result.User)
|
||||
assert.Equal(t, user.Username, result.User.Username)
|
||||
assert.NotNil(t, result.PAT)
|
||||
assert.Equal(t, tokenID, result.PAT.TokenId)
|
||||
})
|
||||
}
|
||||
552
server/router/api/v1/test/idp_service_test.go
Normal file
552
server/router/api/v1/test/idp_service_test.go
Normal file
@@ -0,0 +1,552 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestCreateIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
ctx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create OAuth2 identity provider
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test OAuth2 Provider",
|
||||
IdentifierFilter: "",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthUrl: "https://example.com/oauth/authorize",
|
||||
TokenUrl: "https://example.com/oauth/token",
|
||||
UserInfoUrl: "https://example.com/oauth/userinfo",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
AvatarUrl: "avatar_url",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "Test OAuth2 Provider", resp.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||
require.Contains(t, resp.Name, "identity-providers/")
|
||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client-id", resp.Config.GetOauth2Config().ClientId)
|
||||
})
|
||||
|
||||
t.Run("CreateIdentityProvider permission denied for non-host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
ctx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("CreateIdentityProvider unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "user not authenticated")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListIdentityProviders(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ListIdentityProviders empty", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.ListIdentityProvidersRequest{}
|
||||
resp, err := ts.Service.ListIdentityProviders(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.IdentityProviders)
|
||||
})
|
||||
|
||||
t.Run("ListIdentityProviders with providers", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a couple of identity providers
|
||||
createReq1 := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider 1",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client1",
|
||||
AuthUrl: "https://example1.com/auth",
|
||||
TokenUrl: "https://example1.com/token",
|
||||
UserInfoUrl: "https://example1.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
createReq2 := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider 2",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client2",
|
||||
AuthUrl: "https://example2.com/auth",
|
||||
TokenUrl: "https://example2.com/token",
|
||||
UserInfoUrl: "https://example2.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq1)
|
||||
require.NoError(t, err)
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List providers
|
||||
listReq := &v1pb.ListIdentityProvidersRequest{}
|
||||
resp, err := ts.Service.ListIdentityProviders(ctx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.IdentityProviders, 2)
|
||||
|
||||
// Verify response contains expected providers
|
||||
titles := []string{resp.IdentityProviders[0].Title, resp.IdentityProviders[1].Title}
|
||||
require.Contains(t, titles, "Provider 1")
|
||||
require.Contains(t, titles, "Provider 2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
AuthUrl: "https://example.com/auth",
|
||||
TokenUrl: "https://example.com/token",
|
||||
UserInfoUrl: "https://example.com/user",
|
||||
Scopes: []string{"openid", "profile"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get identity provider
|
||||
getReq := &v1pb.GetIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
// Test unauthenticated, should not contain client secret
|
||||
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, created.Name, resp.Name)
|
||||
require.Equal(t, "Test Provider", resp.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)
|
||||
|
||||
// Test as host user, should contain client secret
|
||||
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respHostUser)
|
||||
require.Equal(t, created.Name, respHostUser.Name)
|
||||
require.Equal(t, "Test Provider", respHostUser.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
|
||||
require.NotNil(t, respHostUser.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
|
||||
})
|
||||
|
||||
t.Run("GetIdentityProvider not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.GetIdentityProviderRequest{
|
||||
Name: "identity-providers/999",
|
||||
}
|
||||
|
||||
_, err := ts.Service.GetIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("GetIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.GetIdentityProviderRequest{
|
||||
Name: "invalid-name",
|
||||
}
|
||||
|
||||
_, err := ts.Service.GetIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Original Provider",
|
||||
IdentifierFilter: "",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "original-client",
|
||||
AuthUrl: "https://original.com/auth",
|
||||
TokenUrl: "https://original.com/token",
|
||||
UserInfoUrl: "https://original.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update identity provider
|
||||
updateReq := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: created.Name,
|
||||
Title: "Updated Provider",
|
||||
IdentifierFilter: "test@example.com",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "updated-client",
|
||||
ClientSecret: "updated-secret",
|
||||
AuthUrl: "https://updated.com/auth",
|
||||
TokenUrl: "https://updated.com/token",
|
||||
UserInfoUrl: "https://updated.com/user",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "given_name",
|
||||
Email: "email",
|
||||
AvatarUrl: "picture",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "identifier_filter", "config"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateIdentityProvider(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, "Updated Provider", updated.Title)
|
||||
require.Equal(t, "test@example.com", updated.IdentifierFilter)
|
||||
require.Equal(t, "updated-client", updated.Config.GetOauth2Config().ClientId)
|
||||
})
|
||||
|
||||
t.Run("UpdateIdentityProvider missing update mask", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: "identity-providers/1",
|
||||
Title: "Updated Provider",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update_mask is required")
|
||||
})
|
||||
|
||||
t.Run("UpdateIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: "invalid-name",
|
||||
Title: "Updated Provider",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider to Delete",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client-to-delete",
|
||||
AuthUrl: "https://example.com/auth",
|
||||
TokenUrl: "https://example.com/token",
|
||||
UserInfoUrl: "https://example.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete identity provider
|
||||
deleteReq := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
getReq := &v1pb.GetIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetIdentityProvider(ctx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("DeleteIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: "invalid-name",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
|
||||
t.Run("DeleteIdentityProvider not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: "identity-providers/999",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
// Note: Delete might succeed even if item doesn't exist, depending on store implementation
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviderPermissions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Only host users can create identity providers", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("Authentication required", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "user not authenticated")
|
||||
})
|
||||
}
|
||||
54
server/router/api/v1/test/instance_admin_cache_test.go
Normal file
54
server/router/api/v1/test/instance_admin_cache_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestInstanceAdminRetrieval(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Instance becomes initialized after first admin user is created", func(t *testing.T) {
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Verify instance is not initialized initially
|
||||
profile1, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, profile1.Admin, "Instance should not be initialized before first admin user")
|
||||
|
||||
// Create the first admin user
|
||||
user, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
// Verify instance is now initialized
|
||||
profile2, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, profile2.Admin, "Instance should be initialized after first admin user is created")
|
||||
require.Equal(t, user.Username, profile2.Admin.Username)
|
||||
})
|
||||
|
||||
t.Run("Admin retrieval is cached by Store layer", func(t *testing.T) {
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create admin user
|
||||
user, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Multiple calls should return consistent admin user (from cache)
|
||||
for i := 0; i < 5; i++ {
|
||||
profile, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, profile.Admin)
|
||||
require.Equal(t, user.Username, profile.Admin.Username)
|
||||
}
|
||||
})
|
||||
}
|
||||
204
server/router/api/v1/test/instance_service_test.go
Normal file
204
server/router/api/v1/test/instance_service_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestGetInstanceProfile(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetInstanceProfile returns instance profile", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetInstanceProfile directly
|
||||
req := &v1pb.GetInstanceProfileRequest{}
|
||||
resp, err := ts.Service.GetInstanceProfile(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify the response contains expected data
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
|
||||
// Instance should not be initialized since no admin users are created
|
||||
require.Nil(t, resp.Admin)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceProfile with initialized instance", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user in the store
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, hostUser)
|
||||
|
||||
// Call GetInstanceProfile directly
|
||||
req := &v1pb.GetInstanceProfileRequest{}
|
||||
resp, err := ts.Service.GetInstanceProfile(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify the response contains expected data with initialized flag
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
|
||||
// Instance should be initialized since an admin user exists
|
||||
require.NotNil(t, resp.Admin)
|
||||
require.Equal(t, hostUser.Username, resp.Admin.Username)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetInstanceProfile_Concurrency(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Concurrent access to service", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user
|
||||
_, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make concurrent requests
|
||||
numGoroutines := 10
|
||||
results := make(chan *v1pb.InstanceProfile, numGoroutines)
|
||||
errors := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
req := &v1pb.GetInstanceProfileRequest{}
|
||||
resp, err := ts.Service.GetInstanceProfile(ctx, req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
results <- resp
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect all results
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
select {
|
||||
case err := <-errors:
|
||||
t.Fatalf("Goroutine returned error: %v", err)
|
||||
case resp := <-results:
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
require.NotNil(t, resp.Admin)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetInstanceSetting(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetInstanceSetting - general setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetInstanceSetting for general setting
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "instance/settings/GENERAL",
|
||||
}
|
||||
resp, err := ts.Service.GetInstanceSetting(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "instance/settings/GENERAL", resp.Name)
|
||||
|
||||
// The general setting should have a general_setting field
|
||||
generalSetting := resp.GetGeneralSetting()
|
||||
require.NotNil(t, generalSetting)
|
||||
|
||||
// General setting should have default values
|
||||
require.False(t, generalSetting.DisallowUserRegistration)
|
||||
require.False(t, generalSetting.DisallowPasswordAuth)
|
||||
require.Empty(t, generalSetting.AdditionalScript)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceSetting - storage setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user for storage setting access
|
||||
hostUser, err := ts.CreateHostUser(ctx, "testhost")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add user to context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Call GetInstanceSetting for storage setting
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "instance/settings/STORAGE",
|
||||
}
|
||||
resp, err := ts.Service.GetInstanceSetting(userCtx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "instance/settings/STORAGE", resp.Name)
|
||||
|
||||
// The storage setting should have a storage_setting field
|
||||
storageSetting := resp.GetStorageSetting()
|
||||
require.NotNil(t, storageSetting)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceSetting - memo related setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetInstanceSetting for memo related setting
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "instance/settings/MEMO_RELATED",
|
||||
}
|
||||
resp, err := ts.Service.GetInstanceSetting(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "instance/settings/MEMO_RELATED", resp.Name)
|
||||
|
||||
// The memo related setting should have a memo_related_setting field
|
||||
memoRelatedSetting := resp.GetMemoRelatedSetting()
|
||||
require.NotNil(t, memoRelatedSetting)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceSetting - invalid setting name", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetInstanceSetting with invalid name
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "invalid/setting/name",
|
||||
}
|
||||
_, err := ts.Service.GetInstanceSetting(ctx, req)
|
||||
|
||||
// Should return an error
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid instance setting name")
|
||||
})
|
||||
}
|
||||
166
server/router/api/v1/test/memo_attachment_service_test.go
Normal file
166
server/router/api/v1/test/memo_attachment_service_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestSetMemoAttachments(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMemoAttachments success by memo owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create attachment
|
||||
attachment, err := ts.Service.CreateAttachment(userCtx, &apiv1.CreateAttachmentRequest{
|
||||
Attachment: &apiv1.Attachment{
|
||||
Filename: "test.txt",
|
||||
Size: 5,
|
||||
Type: "text/plain",
|
||||
Content: []byte("hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, attachment)
|
||||
|
||||
// Set memo attachments - should succeed
|
||||
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*apiv1.Attachment{
|
||||
{Name: attachment.Name},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetMemoAttachments success by host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create memo by regular user
|
||||
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Host user can modify attachments - should succeed
|
||||
_, err = ts.Service.SetMemoAttachments(hostCtx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*apiv1.Attachment{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetMemoAttachments permission denied for non-owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user1
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
// Create user2
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
|
||||
// Create memo by user1
|
||||
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// User2 tries to modify attachments - should fail
|
||||
_, err = ts.Service.SetMemoAttachments(user2Ctx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*apiv1.Attachment{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("SetMemoAttachments unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Unauthenticated user tries to modify attachments - should fail
|
||||
_, err = ts.Service.SetMemoAttachments(ctx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*apiv1.Attachment{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not authenticated")
|
||||
})
|
||||
|
||||
t.Run("SetMemoAttachments memo not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Try to set attachments on non-existent memo - should fail
|
||||
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
|
||||
Name: "memos/nonexistent-uid-12345",
|
||||
Attachments: []*apiv1.Attachment{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
169
server/router/api/v1/test/memo_relation_service_test.go
Normal file
169
server/router/api/v1/test/memo_relation_service_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestSetMemoRelations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMemoRelations success by memo owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo1
|
||||
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo 1",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo1)
|
||||
|
||||
// Create memo2
|
||||
memo2, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo 2",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo2)
|
||||
|
||||
// Set memo relations - should succeed
|
||||
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: memo1.Name,
|
||||
Relations: []*apiv1.MemoRelation{
|
||||
{
|
||||
RelatedMemo: &apiv1.MemoRelation_Memo{
|
||||
Name: memo2.Name,
|
||||
},
|
||||
Type: apiv1.MemoRelation_REFERENCE,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetMemoRelations success by host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create memo by regular user
|
||||
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Host user can modify relations - should succeed
|
||||
_, err = ts.Service.SetMemoRelations(hostCtx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: memo.Name,
|
||||
Relations: []*apiv1.MemoRelation{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetMemoRelations permission denied for non-owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user1
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
// Create user2
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
|
||||
// Create memo by user1
|
||||
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// User2 tries to modify relations - should fail
|
||||
_, err = ts.Service.SetMemoRelations(user2Ctx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: memo.Name,
|
||||
Relations: []*apiv1.MemoRelation{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("SetMemoRelations unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Unauthenticated user tries to modify relations - should fail
|
||||
_, err = ts.Service.SetMemoRelations(ctx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: memo.Name,
|
||||
Relations: []*apiv1.MemoRelation{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not authenticated")
|
||||
})
|
||||
|
||||
t.Run("SetMemoRelations memo not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Try to set relations on non-existent memo - should fail
|
||||
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
|
||||
Name: "memos/nonexistent-uid-12345",
|
||||
Relations: []*apiv1.MemoRelation{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
369
server/router/api/v1/test/memo_service_test.go
Normal file
369
server/router/api/v1/test/memo_service_test.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestListMemos(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create userOne
|
||||
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, userOne)
|
||||
|
||||
// Create userOne context
|
||||
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
|
||||
|
||||
// Create userTwo
|
||||
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, userTwo)
|
||||
|
||||
// Create userTwo context
|
||||
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
|
||||
|
||||
// Create attachmentOne by userOne
|
||||
attachmentOne, err := ts.Service.CreateAttachment(userOneCtx, &apiv1.CreateAttachmentRequest{
|
||||
Attachment: &apiv1.Attachment{
|
||||
Name: "",
|
||||
Filename: "hello.txt",
|
||||
Size: 5,
|
||||
Type: "text/plain",
|
||||
Content: []byte{
|
||||
104, 101, 108, 108, 111,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, attachmentOne)
|
||||
|
||||
// Create attachmentTwo by userOne
|
||||
attachmentTwo, err := ts.Service.CreateAttachment(userOneCtx, &apiv1.CreateAttachmentRequest{
|
||||
Attachment: &apiv1.Attachment{
|
||||
Name: "",
|
||||
Filename: "world.txt",
|
||||
Size: 5,
|
||||
Type: "text/plain",
|
||||
Content: []byte{
|
||||
119, 111, 114, 108, 100,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, attachmentTwo)
|
||||
|
||||
// Create memoOne with two attachments by userOne
|
||||
memoOne, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Hellooo, any words after this sentence won't be in the snippet. This is the next sentence. And I also have two attachments.",
|
||||
Visibility: apiv1.Visibility_PROTECTED,
|
||||
Attachments: []*apiv1.Attachment{
|
||||
&apiv1.Attachment{
|
||||
Name: attachmentOne.Name,
|
||||
},
|
||||
&apiv1.Attachment{
|
||||
Name: attachmentTwo.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoOne)
|
||||
|
||||
// Create memoTwo by userTwo referencing memoOne
|
||||
memoTwo, err := ts.Service.CreateMemo(userTwoCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This is a memo reminding you to check the attachment attached to memoOne. I have referenced the memo below.⬇️",
|
||||
Visibility: apiv1.Visibility_PROTECTED,
|
||||
Relations: []*apiv1.MemoRelation{
|
||||
&apiv1.MemoRelation{
|
||||
RelatedMemo: &apiv1.MemoRelation_Memo{
|
||||
Name: memoOne.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoTwo)
|
||||
|
||||
// Create memoThree by userOne
|
||||
memoThree, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This is a very popular memo. I have 2 reactions!",
|
||||
Visibility: apiv1.Visibility_PROTECTED,
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoThree)
|
||||
|
||||
// Create reaction from userOne on memoThree
|
||||
reactionOne, err := ts.Service.UpsertMemoReaction(userOneCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memoThree.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memoThree.Name,
|
||||
ReactionType: "❤️",
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reactionOne)
|
||||
|
||||
// Create reaction from userTwo on memoThree
|
||||
reactionTwo, err := ts.Service.UpsertMemoReaction(userTwoCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memoThree.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memoThree.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reactionTwo)
|
||||
|
||||
memos, err := ts.Service.ListMemos(userOneCtx, &apiv1.ListMemosRequest{PageSize: 10})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memos)
|
||||
require.Equal(t, 3, len(memos.Memos))
|
||||
|
||||
// ///////////////
|
||||
// VERIFY MEMO ONE
|
||||
// ///////////////
|
||||
memoOneResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoOne.GetName() })
|
||||
require.NotEqual(t, memoOneResIdx, -1)
|
||||
|
||||
memoOneRes := memos.Memos[memoOneResIdx]
|
||||
require.NotNil(t, memoOneRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoOneRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoOneRes.GetVisibility())
|
||||
require.Equal(t, memoOne.Content, memoOneRes.GetContent())
|
||||
require.Equal(t, memoOne.Content[:64]+"...", memoOneRes.GetSnippet(), "memoOne's content is snipped past the 64 char limit")
|
||||
require.Len(t, memoOneRes.Attachments, 2)
|
||||
require.Len(t, memoOneRes.Relations, 1)
|
||||
require.Empty(t, memoOneRes.Reactions)
|
||||
|
||||
// verify memoOne's attachments
|
||||
// attachment one
|
||||
attachmentOneResIdx := slices.IndexFunc(memoOneRes.Attachments, func(a *apiv1.Attachment) bool { return a.GetName() == attachmentOne.GetName() })
|
||||
require.NotEqual(t, attachmentOneResIdx, -1)
|
||||
|
||||
attachmentOneRes := memoOneRes.Attachments[attachmentOneResIdx]
|
||||
require.NotNil(t, attachmentOneRes)
|
||||
|
||||
require.Equal(t, attachmentOne.GetName(), attachmentOneRes.GetName())
|
||||
require.Equal(t, attachmentOne.GetContent(), attachmentOneRes.GetContent())
|
||||
|
||||
// attachment two
|
||||
attachmentTwoResIdx := slices.IndexFunc(memoOneRes.Attachments, func(a *apiv1.Attachment) bool { return a.GetName() == attachmentTwo.GetName() })
|
||||
require.NotEqual(t, attachmentTwoResIdx, -1)
|
||||
|
||||
attachmentTwoRes := memoOneRes.Attachments[attachmentTwoResIdx]
|
||||
require.NotNil(t, attachmentTwoRes)
|
||||
require.Equal(t, attachmentTwo.GetName(), attachmentTwoRes.GetName())
|
||||
|
||||
require.Equal(t, attachmentTwo.GetName(), attachmentTwoRes.GetName())
|
||||
require.Equal(t, attachmentTwo.GetContent(), attachmentTwoRes.GetContent())
|
||||
|
||||
// verify memoOne's relations
|
||||
require.Len(t, memoOneRes.Relations, 1)
|
||||
memoOneExpectedRelation := &apiv1.MemoRelation{
|
||||
Memo: &apiv1.MemoRelation_Memo{Name: memoTwo.GetName()},
|
||||
RelatedMemo: &apiv1.MemoRelation_Memo{Name: memoOne.GetName()},
|
||||
}
|
||||
require.Equal(t, memoOneExpectedRelation.Memo.GetName(), memoOneRes.Relations[0].Memo.GetName())
|
||||
require.Equal(t, memoOneExpectedRelation.RelatedMemo.GetName(), memoOneRes.Relations[0].RelatedMemo.GetName())
|
||||
|
||||
// ///////////////
|
||||
// VERIFY MEMO TWO
|
||||
// ///////////////
|
||||
memoTwoResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoTwo.GetName() })
|
||||
require.NotEqual(t, memoTwoResIdx, -1)
|
||||
|
||||
memoTwoRes := memos.Memos[memoTwoResIdx]
|
||||
require.NotNil(t, memoTwoRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userTwo.ID), memoTwoRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoTwoRes.GetVisibility())
|
||||
require.Equal(t, memoTwo.Content, memoTwoRes.GetContent())
|
||||
require.Empty(t, memoTwoRes.Attachments)
|
||||
require.Len(t, memoTwoRes.Relations, 1)
|
||||
require.Empty(t, memoTwoRes.Reactions)
|
||||
|
||||
// verify memoTwo's relations
|
||||
require.Len(t, memoTwoRes.Relations, 1)
|
||||
memoTwoExpectedRelation := &apiv1.MemoRelation{
|
||||
Memo: &apiv1.MemoRelation_Memo{Name: memoTwo.GetName()},
|
||||
RelatedMemo: &apiv1.MemoRelation_Memo{Name: memoOne.GetName()},
|
||||
}
|
||||
require.Equal(t, memoTwoExpectedRelation.Memo.GetName(), memoTwoRes.Relations[0].Memo.GetName())
|
||||
require.Equal(t, memoTwoExpectedRelation.RelatedMemo.GetName(), memoTwoRes.Relations[0].RelatedMemo.GetName())
|
||||
|
||||
// ///////////////
|
||||
// VERIFY MEMO THREE
|
||||
// ///////////////
|
||||
memoThreeResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoThree.GetName() })
|
||||
require.NotEqual(t, memoThreeResIdx, -1)
|
||||
|
||||
memoThreeRes := memos.Memos[memoThreeResIdx]
|
||||
require.NotNil(t, memoThreeRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoThreeRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoThreeRes.GetVisibility())
|
||||
require.Equal(t, memoThree.Content, memoThreeRes.GetContent())
|
||||
require.Empty(t, memoThreeRes.Attachments)
|
||||
require.Empty(t, memoThreeRes.Relations)
|
||||
require.Len(t, memoThreeRes.Reactions, 2)
|
||||
|
||||
// verify memoThree's reactions
|
||||
require.Len(t, memoThreeRes.Reactions, 2)
|
||||
// userOne's reaction
|
||||
userOneReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userOne.ID) })
|
||||
require.NotEqual(t, userOneReactionIdx, -1)
|
||||
|
||||
userOneReaction := memoThreeRes.Reactions[userOneReactionIdx]
|
||||
require.NotNil(t, userOneReaction)
|
||||
require.Equal(t, "❤️", userOneReaction.ReactionType)
|
||||
|
||||
// userTwo's reaction
|
||||
userTwoReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userTwo.ID) })
|
||||
require.NotEqual(t, userTwoReactionIdx, -1)
|
||||
|
||||
userTwoReaction := memoThreeRes.Reactions[userTwoReactionIdx]
|
||||
require.NotNil(t, userTwoReaction)
|
||||
require.Equal(t, "👍", userTwoReaction.ReactionType)
|
||||
}
|
||||
|
||||
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
|
||||
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483
|
||||
func TestCreateMemoWithCustomTimestamps(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "test-user-timestamps")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Define custom timestamps (January 1, 2020)
|
||||
customCreateTime := time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
customUpdateTime := time.Date(2020, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
customDisplayTime := time.Date(2020, 1, 3, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Test 1: Create a memo with custom create_time
|
||||
memoWithCreateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom creation time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCreateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithCreateTime)
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithCreateTime.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
|
||||
|
||||
// Test 2: Create a memo with custom update_time
|
||||
memoWithUpdateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom update time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
UpdateTime: timestamppb.New(customUpdateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithUpdateTime)
|
||||
require.Equal(t, customUpdateTime.Unix(), memoWithUpdateTime.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
|
||||
|
||||
// Test 3: Create a memo with custom display_time
|
||||
// Note: display_time is computed from either created_ts or updated_ts based on instance setting
|
||||
// Since DisplayWithUpdateTime defaults to false, display_time maps to created_ts
|
||||
memoWithDisplayTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom display time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
DisplayTime: timestamppb.New(customDisplayTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithDisplayTime)
|
||||
// Since DisplayWithUpdateTime is false by default, display_time sets created_ts
|
||||
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.DisplayTime.AsTime().Unix(), "display_time should match the custom timestamp")
|
||||
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.CreateTime.AsTime().Unix(), "create_time should also match since display_time maps to created_ts")
|
||||
|
||||
// Test 4: Create a memo with all custom timestamps
|
||||
// When both display_time and create_time are provided, create_time takes precedence
|
||||
memoWithAllTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has all custom timestamps",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCreateTime),
|
||||
UpdateTime: timestamppb.New(customUpdateTime),
|
||||
DisplayTime: timestamppb.New(customDisplayTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithAllTimestamps)
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
|
||||
require.Equal(t, customUpdateTime.Unix(), memoWithAllTimestamps.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
|
||||
// display_time is computed from created_ts when DisplayWithUpdateTime is false
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.DisplayTime.AsTime().Unix(), "display_time should be derived from create_time")
|
||||
|
||||
// Test 5: Create a comment (memo relation) with custom timestamps
|
||||
parentMemo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This is the parent memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parentMemo)
|
||||
|
||||
customCommentCreateTime := time.Date(2021, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: parentMemo.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "This is a comment with custom create time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCommentCreateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, comment)
|
||||
require.Equal(t, customCommentCreateTime.Unix(), comment.CreateTime.AsTime().Unix(), "comment create_time should match the custom timestamp")
|
||||
|
||||
// Test 6: Verify that memos without custom timestamps still get auto-generated ones
|
||||
memoWithoutTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has auto-generated timestamps",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithoutTimestamps)
|
||||
require.NotNil(t, memoWithoutTimestamps.CreateTime, "create_time should be auto-generated")
|
||||
require.NotNil(t, memoWithoutTimestamps.UpdateTime, "update_time should be auto-generated")
|
||||
require.True(t, time.Now().Unix()-memoWithoutTimestamps.CreateTime.AsTime().Unix() < 5, "create_time should be recent (within 5 seconds)")
|
||||
}
|
||||
194
server/router/api/v1/test/reaction_service_test.go
Normal file
194
server/router/api/v1/test/reaction_service_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestDeleteMemoReaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteMemoReaction success by reaction owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create reaction
|
||||
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
|
||||
// Delete reaction - should succeed
|
||||
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("DeleteMemoReaction success by host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create memo by regular user
|
||||
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create reaction by regular user
|
||||
reaction, err := ts.Service.UpsertMemoReaction(regularUserCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
|
||||
// Host user can delete reaction - should succeed
|
||||
_, err = ts.Service.DeleteMemoReaction(hostCtx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("DeleteMemoReaction permission denied for non-owner", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user1
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
// Create user2
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
|
||||
// Create memo by user1
|
||||
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create reaction by user1
|
||||
reaction, err := ts.Service.UpsertMemoReaction(user1Ctx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
|
||||
// User2 tries to delete reaction - should fail with permission denied
|
||||
_, err = ts.Service.DeleteMemoReaction(user2Ctx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("DeleteMemoReaction unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create memo
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "Test memo",
|
||||
Visibility: apiv1.Visibility_PUBLIC,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Create reaction
|
||||
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &apiv1.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
|
||||
// Unauthenticated user tries to delete reaction - should fail
|
||||
_, err = ts.Service.DeleteMemoReaction(ctx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not authenticated")
|
||||
})
|
||||
|
||||
t.Run("DeleteMemoReaction not found returns permission denied", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Try to delete non-existent reaction - should fail with permission denied
|
||||
// (not "not found" to avoid information disclosure)
|
||||
// Use new nested resource format: memos/{memo}/reactions/{reaction}
|
||||
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
|
||||
Name: "memos/nonexistent/reactions/99999",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
require.NotContains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
819
server/router/api/v1/test/shortcut_service_test.go
Normal file
819
server/router/api/v1/test/shortcut_service_test.go
Normal file
@@ -0,0 +1,819 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestListShortcuts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ListShortcuts success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// List shortcuts (should be empty initially)
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
resp, err := ts.Service.ListShortcuts(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.Shortcuts)
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to list user2's shortcuts
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListShortcuts(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts invalid parent format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: "invalid-parent-format",
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListShortcuts(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: "users/1",
|
||||
}
|
||||
|
||||
_, err := ts.Service.ListShortcuts(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// First create a shortcut
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now get the shortcut
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
resp, err := ts.Service.GetShortcut(userCtx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, created.Name, resp.Name)
|
||||
require.Equal(t, "Test Shortcut", resp.Title)
|
||||
require.Equal(t, "tag in [\"test\"]", resp.Filter)
|
||||
})
|
||||
|
||||
t.Run("GetShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(user2Ctx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("GetShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: "invalid-shortcut-name",
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("GetShortcut not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "My Shortcut",
|
||||
Filter: "tag in [\"important\"]",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ts.Service.CreateShortcut(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "My Shortcut", resp.Title)
|
||||
require.Equal(t, "tag in [\"important\"]", resp.Filter)
|
||||
require.Contains(t, resp.Name, fmt.Sprintf("users/%d/shortcuts/", user.ID))
|
||||
|
||||
// Verify the shortcut was created by listing
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listResp.Shortcuts, 1)
|
||||
require.Equal(t, "My Shortcut", listResp.Shortcuts[0].Title)
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to create shortcut for user2
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Forbidden Shortcut",
|
||||
Filter: "tag in [\"forbidden\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut invalid parent format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: "invalid-parent",
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut invalid filter", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Invalid Filter Shortcut",
|
||||
Filter: "invalid||filter))syntax",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid filter")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut missing title", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "title is required")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Original Title",
|
||||
Filter: "tag in [\"original\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the shortcut
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Title: "Updated Title",
|
||||
Filter: "tag in [\"updated\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, "Updated Title", updated.Title)
|
||||
require.Equal(t, "tag in [\"updated\"]", updated.Filter)
|
||||
require.Equal(t, created.Name, updated.Name)
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to update shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Title: "Hacked Title",
|
||||
Filter: "tag in [\"hacked\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(user2Ctx, updateReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut missing update mask", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user and context for authentication
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: fmt.Sprintf("users/%d/shortcuts/test", user.ID),
|
||||
Title: "Updated Title",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update mask is required")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: "invalid-shortcut-name",
|
||||
Title: "Updated Title",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.UpdateShortcut(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut invalid filter", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to update with invalid filter
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Filter: "invalid||filter))syntax",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"filter"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid filter")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Shortcut to Delete",
|
||||
Filter: "tag in [\"delete\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the shortcut
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion by listing shortcuts
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, listResp.Shortcuts)
|
||||
|
||||
// Also verify by trying to get the deleted shortcut
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to delete shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(user2Ctx, deleteReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.DeleteShortcutRequest{
|
||||
Name: "invalid-shortcut-name",
|
||||
}
|
||||
|
||||
_, err := ts.Service.DeleteShortcut(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestShortcutFiltering(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateShortcut with valid filters", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test various valid filter formats
|
||||
validFilters := []string{
|
||||
"tag in [\"work\"]",
|
||||
"content.contains(\"meeting\")",
|
||||
"tag in [\"work\"] && content.contains(\"meeting\")",
|
||||
"tag in [\"work\"] || tag in [\"personal\"]",
|
||||
"creator_id == 1",
|
||||
"visibility == \"PUBLIC\"",
|
||||
"has_task_list == true",
|
||||
"has_task_list == false",
|
||||
}
|
||||
|
||||
for i, filter := range validFilters {
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Valid Filter " + string(rune(i)),
|
||||
Filter: filter,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.NoError(t, err, "Filter should be valid: %s", filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut with invalid filters", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test various invalid filter formats
|
||||
invalidFilters := []string{
|
||||
"tag in ", // incomplete expression
|
||||
"invalid_field @in [\"value\"]", // unknown field
|
||||
"tag in [\"work\"] &&", // incomplete expression
|
||||
"tag in [\"work\"] || || tag in [\"test\"]", // double operator
|
||||
"((tag in [\"work\"]", // unmatched parentheses
|
||||
"tag in [\"work\"] && )", // mismatched parentheses
|
||||
"tag == \"work\"", // wrong operator (== not supported for tags)
|
||||
"tag in work", // missing brackets
|
||||
}
|
||||
|
||||
for _, filter := range invalidFilters {
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Invalid Filter Test",
|
||||
Filter: filter,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err, "Filter should be invalid: %s", filter)
|
||||
require.Contains(t, err.Error(), "invalid filter", "Error should mention invalid filter for: %s", filter)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShortcutCRUDComplete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Complete CRUD lifecycle", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// 1. Create multiple shortcuts
|
||||
shortcut1Req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Work Notes",
|
||||
Filter: "tag in [\"work\"]",
|
||||
},
|
||||
}
|
||||
|
||||
shortcut2Req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Personal Notes",
|
||||
Filter: "tag in [\"personal\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created1, err := ts.Service.CreateShortcut(userCtx, shortcut1Req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work Notes", created1.Title)
|
||||
|
||||
created2, err := ts.Service.CreateShortcut(userCtx, shortcut2Req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Personal Notes", created2.Title)
|
||||
|
||||
// 2. List shortcuts and verify both exist
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listResp.Shortcuts, 2)
|
||||
|
||||
// 3. Get individual shortcuts
|
||||
getReq1 := &v1pb.GetShortcutRequest{Name: created1.Name}
|
||||
getResp1, err := ts.Service.GetShortcut(userCtx, getReq1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, created1.Name, getResp1.Name)
|
||||
require.Equal(t, "Work Notes", getResp1.Title)
|
||||
|
||||
getReq2 := &v1pb.GetShortcutRequest{Name: created2.Name}
|
||||
getResp2, err := ts.Service.GetShortcut(userCtx, getReq2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, created2.Name, getResp2.Name)
|
||||
require.Equal(t, "Personal Notes", getResp2.Title)
|
||||
|
||||
// 4. Update one shortcut
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created1.Name,
|
||||
Title: "Work & Meeting Notes",
|
||||
Filter: "tag in [\"work\"] || tag in [\"meeting\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work & Meeting Notes", updated.Title)
|
||||
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", updated.Filter)
|
||||
|
||||
// 5. Verify update by getting it again
|
||||
getUpdatedReq := &v1pb.GetShortcutRequest{Name: created1.Name}
|
||||
getUpdatedResp, err := ts.Service.GetShortcut(userCtx, getUpdatedReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work & Meeting Notes", getUpdatedResp.Title)
|
||||
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", getUpdatedResp.Filter)
|
||||
|
||||
// 6. Delete one shortcut
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created2.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 7. Verify deletion by listing (should only have 1 left)
|
||||
finalListResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, finalListResp.Shortcuts, 1)
|
||||
require.Equal(t, "Work & Meeting Notes", finalListResp.Shortcuts[0].Title)
|
||||
|
||||
// 8. Verify deleted shortcut can't be accessed
|
||||
getDeletedReq := &v1pb.GetShortcutRequest{Name: created2.Name}
|
||||
_, err = ts.Service.GetShortcut(userCtx, getDeletedReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
86
server/router/api/v1/test/test_helper.go
Normal file
86
server/router/api/v1/test/test_helper.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/plugin/markdown"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
teststore "github.com/usememos/memos/store/test"
|
||||
)
|
||||
|
||||
// TestService holds the test service setup for API v1 services.
|
||||
type TestService struct {
|
||||
Service *apiv1.APIV1Service
|
||||
Store *store.Store
|
||||
Profile *profile.Profile
|
||||
Secret string
|
||||
}
|
||||
|
||||
// NewTestService creates a new test service with SQLite database.
|
||||
func NewTestService(t *testing.T) *TestService {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a test store with SQLite
|
||||
testStore := teststore.NewTestingStore(ctx, t)
|
||||
|
||||
// Create a test profile
|
||||
testProfile := &profile.Profile{
|
||||
Demo: true,
|
||||
Version: "test-1.0.0",
|
||||
InstanceURL: "http://localhost:8080",
|
||||
Driver: "sqlite",
|
||||
DSN: ":memory:",
|
||||
}
|
||||
|
||||
// Create APIV1Service with nil grpcServer since we're testing direct calls
|
||||
secret := "test-secret"
|
||||
markdownService := markdown.NewService(
|
||||
markdown.WithTagExtension(),
|
||||
)
|
||||
service := &apiv1.APIV1Service{
|
||||
Secret: secret,
|
||||
Profile: testProfile,
|
||||
Store: testStore,
|
||||
MarkdownService: markdownService,
|
||||
}
|
||||
|
||||
return &TestService{
|
||||
Service: service,
|
||||
Store: testStore,
|
||||
Profile: testProfile,
|
||||
Secret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup closes resources after test.
|
||||
func (ts *TestService) Cleanup() {
|
||||
ts.Store.Close()
|
||||
}
|
||||
|
||||
// CreateHostUser creates an admin user for testing.
|
||||
func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) {
|
||||
return ts.Store.CreateUser(ctx, &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleAdmin,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateRegularUser creates a regular user for testing.
|
||||
func (ts *TestService) CreateRegularUser(ctx context.Context, username string) (*store.User, error) {
|
||||
return ts.Store.CreateUser(ctx, &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleUser,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateUserContext creates a context with the given user's ID for authentication.
|
||||
func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context {
|
||||
// Use the context key from the auth package
|
||||
return context.WithValue(ctx, auth.UserIDContextKey, userID)
|
||||
}
|
||||
173
server/router/api/v1/test/user_service_registration_test.go
Normal file
173
server/router/api/v1/test/user_service_registration_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
func TestCreateUserRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateUser success when registration enabled", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// User registration is enabled by default, no need to set it explicitly
|
||||
|
||||
// Create user without authentication - should succeed
|
||||
_, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateUser blocked when registration disabled", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user first so we're not in first-user setup mode
|
||||
_, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Disable user registration
|
||||
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
DisallowUserRegistration: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to create user without authentication - should fail
|
||||
_, err = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not allowed")
|
||||
})
|
||||
|
||||
t.Run("CreateUser succeeds for superuser even when registration disabled", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Disable user registration
|
||||
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
DisallowUserRegistration: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Host user can create users even when registration is disabled - should succeed
|
||||
_, err = ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateUser regular user cannot create users when registration disabled", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
|
||||
require.NoError(t, err)
|
||||
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Disable user registration
|
||||
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
DisallowUserRegistration: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Regular user tries to create user when registration is disabled - should fail
|
||||
_, err = ts.Service.CreateUser(regularUserCtx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not allowed")
|
||||
})
|
||||
|
||||
t.Run("CreateUser host can assign roles", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Host user can create user with specific role - should succeed
|
||||
createdUser, err := ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newadmin",
|
||||
Email: "newadmin@example.com",
|
||||
Password: "password123",
|
||||
Role: apiv1.User_ADMIN,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdUser)
|
||||
require.Equal(t, apiv1.User_ADMIN, createdUser.Role)
|
||||
})
|
||||
|
||||
t.Run("CreateUser unauthenticated user can only create regular user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user first so we're not in first-user setup mode
|
||||
_, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// User registration is enabled by default
|
||||
|
||||
// Unauthenticated user tries to create admin user - role should be ignored
|
||||
createdUser, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "wannabeadmin",
|
||||
Email: "wannabeadmin@example.com",
|
||||
Password: "password123",
|
||||
Role: apiv1.User_ADMIN, // This should be ignored
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdUser)
|
||||
require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role")
|
||||
})
|
||||
}
|
||||
105
server/router/api/v1/test/user_service_stats_test.go
Normal file
105
server/router/api/v1/test/user_service_stats_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestGetUserStats_TagCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test host user
|
||||
user, err := ts.CreateHostUser(ctx, "test_user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user context for authentication
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a memo with a single tag
|
||||
memo, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "This is a test memo with #test tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"test"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Test GetUserStats
|
||||
userName := fmt.Sprintf("users/%d", user.ID)
|
||||
response, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
// Check that the tag count is exactly 1, not 2
|
||||
require.Contains(t, response.TagCount, "test")
|
||||
require.Equal(t, int32(1), response.TagCount["test"], "Tag count should be 1 for a single occurrence")
|
||||
|
||||
// Create another memo with the same tag
|
||||
memo2, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "Another memo with #test tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"test"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo2)
|
||||
|
||||
// Test GetUserStats again
|
||||
response2, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response2)
|
||||
|
||||
// Check that the tag count is exactly 2, not 3
|
||||
require.Contains(t, response2.TagCount, "test")
|
||||
require.Equal(t, int32(2), response2.TagCount["test"], "Tag count should be 2 for two occurrences")
|
||||
|
||||
// Test with a new unique tag
|
||||
memo3, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-3",
|
||||
CreatorID: user.ID,
|
||||
Content: "Memo with #unique tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"unique"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo3)
|
||||
|
||||
// Test GetUserStats for the new tag
|
||||
response3, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response3)
|
||||
|
||||
// Check that the unique tag count is exactly 1
|
||||
require.Contains(t, response3.TagCount, "unique")
|
||||
require.Equal(t, int32(1), response3.TagCount["unique"], "New tag count should be 1 for first occurrence")
|
||||
|
||||
// The original test tag should still be 2
|
||||
require.Contains(t, response3.TagCount, "test")
|
||||
require.Equal(t, int32(2), response3.TagCount["test"], "Original tag count should remain 2")
|
||||
}
|
||||
1451
server/router/api/v1/user_service.go
Normal file
1451
server/router/api/v1/user_service.go
Normal file
File diff suppressed because it is too large
Load Diff
236
server/router/api/v1/user_service_stats.go
Normal file
236
server/router/api/v1/user_service_stats.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUserStatsRequest) (*v1pb.ListAllUserStatsResponse, error) {
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get instance memo related setting")
|
||||
}
|
||||
|
||||
normalStatus := store.Normal
|
||||
memoFind := &store.FindMemo{
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
ExcludeContent: true,
|
||||
RowStatus: &normalStatus,
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
if memoFind.CreatorID == nil {
|
||||
filter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
memoFind.Filters = append(memoFind.Filters, filter)
|
||||
} else if *memoFind.CreatorID != currentUser.ID {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
|
||||
}
|
||||
}
|
||||
|
||||
userMemoStatMap := make(map[int32]*v1pb.UserStats)
|
||||
limit := 1000
|
||||
offset := 0
|
||||
memoFind.Limit = &limit
|
||||
memoFind.Offset = &offset
|
||||
|
||||
for {
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
if len(memos) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, memo := range memos {
|
||||
// Initialize user stats if not exists
|
||||
if _, exists := userMemoStatMap[memo.CreatorID]; !exists {
|
||||
userMemoStatMap[memo.CreatorID] = &v1pb.UserStats{
|
||||
Name: fmt.Sprintf("users/%d/stats", memo.CreatorID),
|
||||
TagCount: make(map[string]int32),
|
||||
MemoDisplayTimestamps: []*timestamppb.Timestamp{},
|
||||
PinnedMemos: []string{},
|
||||
MemoTypeStats: &v1pb.UserStats_MemoTypeStats{
|
||||
LinkCount: 0,
|
||||
CodeCount: 0,
|
||||
TodoCount: 0,
|
||||
UndoCount: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
stats := userMemoStatMap[memo.CreatorID]
|
||||
|
||||
// Add display timestamp
|
||||
displayTs := memo.CreatedTs
|
||||
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
displayTs = memo.UpdatedTs
|
||||
}
|
||||
stats.MemoDisplayTimestamps = append(stats.MemoDisplayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
|
||||
|
||||
// Count memo stats
|
||||
stats.TotalMemoCount++
|
||||
|
||||
// Count tags and other properties
|
||||
if memo.Payload != nil {
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
stats.TagCount[tag]++
|
||||
}
|
||||
if memo.Payload.Property != nil {
|
||||
if memo.Payload.Property.HasLink {
|
||||
stats.MemoTypeStats.LinkCount++
|
||||
}
|
||||
if memo.Payload.Property.HasCode {
|
||||
stats.MemoTypeStats.CodeCount++
|
||||
}
|
||||
if memo.Payload.Property.HasTaskList {
|
||||
stats.MemoTypeStats.TodoCount++
|
||||
}
|
||||
if memo.Payload.Property.HasIncompleteTasks {
|
||||
stats.MemoTypeStats.UndoCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track pinned memos
|
||||
if memo.Pinned {
|
||||
stats.PinnedMemos = append(stats.PinnedMemos, fmt.Sprintf("users/%d/memos/%d", memo.CreatorID, memo.ID))
|
||||
}
|
||||
}
|
||||
|
||||
offset += limit
|
||||
}
|
||||
|
||||
userMemoStats := []*v1pb.UserStats{}
|
||||
for _, userMemoStat := range userMemoStatMap {
|
||||
userMemoStats = append(userMemoStats, userMemoStat)
|
||||
}
|
||||
|
||||
response := &v1pb.ListAllUserStatsResponse{
|
||||
Stats: userMemoStats,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserStatsRequest) (*v1pb.UserStats, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
|
||||
normalStatus := store.Normal
|
||||
memoFind := &store.FindMemo{
|
||||
CreatorID: &userID,
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
ExcludeContent: true,
|
||||
RowStatus: &normalStatus,
|
||||
}
|
||||
|
||||
if currentUser == nil {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public}
|
||||
} else if currentUser.ID != userID {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
|
||||
}
|
||||
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get instance memo related setting")
|
||||
}
|
||||
|
||||
displayTimestamps := []*timestamppb.Timestamp{}
|
||||
tagCount := make(map[string]int32)
|
||||
linkCount := int32(0)
|
||||
codeCount := int32(0)
|
||||
todoCount := int32(0)
|
||||
undoCount := int32(0)
|
||||
pinnedMemos := []string{}
|
||||
totalMemoCount := int32(0)
|
||||
|
||||
limit := 1000
|
||||
offset := 0
|
||||
memoFind.Limit = &limit
|
||||
memoFind.Offset = &offset
|
||||
|
||||
for {
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
if len(memos) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
totalMemoCount += int32(len(memos))
|
||||
|
||||
for _, memo := range memos {
|
||||
displayTs := memo.CreatedTs
|
||||
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
displayTs = memo.UpdatedTs
|
||||
}
|
||||
displayTimestamps = append(displayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
|
||||
// Count different memo types based on content.
|
||||
if memo.Payload != nil {
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
tagCount[tag]++
|
||||
}
|
||||
if memo.Payload.Property != nil {
|
||||
if memo.Payload.Property.HasLink {
|
||||
linkCount++
|
||||
}
|
||||
if memo.Payload.Property.HasCode {
|
||||
codeCount++
|
||||
}
|
||||
if memo.Payload.Property.HasTaskList {
|
||||
todoCount++
|
||||
}
|
||||
if memo.Payload.Property.HasIncompleteTasks {
|
||||
undoCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
if memo.Pinned {
|
||||
pinnedMemos = append(pinnedMemos, fmt.Sprintf("users/%d/memos/%d", userID, memo.ID))
|
||||
}
|
||||
}
|
||||
|
||||
offset += limit
|
||||
}
|
||||
|
||||
userStats := &v1pb.UserStats{
|
||||
Name: fmt.Sprintf("users/%d/stats", userID),
|
||||
MemoDisplayTimestamps: displayTimestamps,
|
||||
TagCount: tagCount,
|
||||
PinnedMemos: pinnedMemos,
|
||||
TotalMemoCount: totalMemoCount,
|
||||
MemoTypeStats: &v1pb.UserStats_MemoTypeStats{
|
||||
LinkCount: linkCount,
|
||||
CodeCount: codeCount,
|
||||
TodoCount: todoCount,
|
||||
UndoCount: undoCount,
|
||||
},
|
||||
}
|
||||
|
||||
return userStats, nil
|
||||
}
|
||||
175
server/router/api/v1/v1.go
Normal file
175
server/router/api/v1/v1.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/labstack/echo/v5/middleware"
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/plugin/markdown"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type APIV1Service struct {
|
||||
v1pb.UnimplementedInstanceServiceServer
|
||||
v1pb.UnimplementedAuthServiceServer
|
||||
v1pb.UnimplementedUserServiceServer
|
||||
v1pb.UnimplementedMemoServiceServer
|
||||
v1pb.UnimplementedAttachmentServiceServer
|
||||
v1pb.UnimplementedShortcutServiceServer
|
||||
v1pb.UnimplementedActivityServiceServer
|
||||
v1pb.UnimplementedIdentityProviderServiceServer
|
||||
|
||||
Secret string
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
MarkdownService markdown.Service
|
||||
|
||||
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
|
||||
thumbnailSemaphore *semaphore.Weighted
|
||||
}
|
||||
|
||||
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service {
|
||||
markdownService := markdown.NewService(
|
||||
markdown.WithTagExtension(),
|
||||
)
|
||||
return &APIV1Service{
|
||||
Secret: secret,
|
||||
Profile: profile,
|
||||
Store: store,
|
||||
MarkdownService: markdownService,
|
||||
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterGateway registers the gRPC-Gateway and Connect handlers with the given Echo instance.
|
||||
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
|
||||
// Auth middleware for gRPC-Gateway - runs after routing, has access to method name.
|
||||
// Uses the same PublicMethods config as the Connect AuthInterceptor.
|
||||
authenticator := auth.NewAuthenticator(s.Store, s.Secret)
|
||||
gatewayAuthMiddleware := func(next runtime.HandlerFunc) runtime.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get the RPC method name from context (set by grpc-gateway after routing)
|
||||
rpcMethod, ok := runtime.RPCMethod(ctx)
|
||||
|
||||
// Extract credentials from HTTP headers
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
|
||||
result := authenticator.Authenticate(ctx, authHeader)
|
||||
|
||||
// Enforce authentication for non-public methods
|
||||
// If rpcMethod cannot be determined, allow through, service layer will handle visibility checks
|
||||
if result == nil && ok && !IsPublicMethod(rpcMethod) {
|
||||
http.Error(w, `{"code": 16, "message": "authentication required"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply auth result to context (no-op when result is nil for public endpoints)
|
||||
if result != nil {
|
||||
ctx = auth.ApplyToContext(ctx, result)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
next(w, r, pathParams)
|
||||
}
|
||||
}
|
||||
|
||||
// Create gRPC-Gateway mux with auth middleware.
|
||||
gwMux := runtime.NewServeMux(
|
||||
runtime.WithMiddlewares(gatewayAuthMiddleware),
|
||||
)
|
||||
if err := v1pb.RegisterInstanceServiceHandlerServer(ctx, gwMux, s); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterAuthServiceHandlerServer(ctx, gwMux, s); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterUserServiceHandlerServer(ctx, gwMux, s); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterMemoServiceHandlerServer(ctx, gwMux, s); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterAttachmentServiceHandlerServer(ctx, gwMux, s); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterShortcutServiceHandlerServer(ctx, gwMux, s); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterActivityServiceHandlerServer(ctx, gwMux, s); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterIdentityProviderServiceHandlerServer(ctx, gwMux, s); err != nil {
|
||||
return err
|
||||
}
|
||||
gwGroup := echoServer.Group("")
|
||||
gwGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
}))
|
||||
handler := echo.WrapHandler(gwMux)
|
||||
|
||||
gwGroup.Any("/api/v1/*", handler)
|
||||
gwGroup.Any("/file/*", handler)
|
||||
|
||||
// Connect handlers for browser clients (replaces grpc-web).
|
||||
logStacktraces := s.Profile.Demo
|
||||
connectInterceptors := connect.WithInterceptors(
|
||||
NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first
|
||||
NewLoggingInterceptor(logStacktraces),
|
||||
NewRecoveryInterceptor(logStacktraces),
|
||||
NewAuthInterceptor(s.Store, s.Secret),
|
||||
)
|
||||
connectMux := http.NewServeMux()
|
||||
connectHandler := NewConnectServiceHandler(s)
|
||||
connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors)
|
||||
|
||||
// Wrap with CORS for browser access
|
||||
corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{
|
||||
UnsafeAllowOriginFunc: func(_ *echo.Context, origin string) (string, bool, error) {
|
||||
return origin, true, nil
|
||||
},
|
||||
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions},
|
||||
AllowHeaders: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
})
|
||||
connectGroup := echoServer.Group("", corsHandler)
|
||||
connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(connectMux))
|
||||
|
||||
// Register AI REST endpoints (direct HTTP, no Connect/gRPC required)
|
||||
// Apply auth middleware so user context is populated (tries Bearer token then cookie)
|
||||
aiAuthenticator := auth.NewAuthenticator(s.Store, s.Secret)
|
||||
aiGroup := echoServer.Group("", func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
cookieHeader := c.Request().Header.Get("Cookie")
|
||||
result := aiAuthenticator.Authenticate(ctx, authHeader)
|
||||
if result == nil && cookieHeader != "" {
|
||||
// Try cookie-based auth (refresh token)
|
||||
user, err := aiAuthenticator.AuthenticateToUser(ctx, authHeader, cookieHeader)
|
||||
if err == nil && user != nil {
|
||||
ctx = auth.SetUserInContext(ctx, user, "")
|
||||
c.SetRequest(c.Request().WithContext(ctx))
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
if result != nil {
|
||||
ctx = auth.ApplyToContext(ctx, result)
|
||||
c.SetRequest(c.Request().WithContext(ctx))
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
s.RegisterAIHTTPHandlers(aiGroup)
|
||||
|
||||
return nil
|
||||
}
|
||||
307
server/router/fileserver/README.md
Normal file
307
server/router/fileserver/README.md
Normal file
@@ -0,0 +1,307 @@
|
||||
# Fileserver Package
|
||||
|
||||
## Overview
|
||||
|
||||
The `fileserver` package handles all binary file serving for Memos using native HTTP handlers. It was created to replace gRPC-based binary serving, which had limitations with HTTP range requests (required for Safari video/audio playback).
|
||||
|
||||
## Responsibilities
|
||||
|
||||
- Serve attachment binary files (images, videos, audio, documents)
|
||||
- Serve user avatar images
|
||||
- Handle HTTP range requests for video/audio streaming
|
||||
- Authenticate requests using JWT tokens or Personal Access Tokens
|
||||
- Check permissions for private content
|
||||
- Generate and serve image thumbnails
|
||||
- Prevent XSS attacks on uploaded content
|
||||
- Support S3 external storage
|
||||
|
||||
## Architecture
|
||||
|
||||
### Design Principles
|
||||
|
||||
1. **Separation of Concerns**: Binary files via HTTP, metadata via gRPC
|
||||
2. **DRY**: Imports auth constants from `api/v1` package (single source of truth)
|
||||
3. **Security First**: Authentication, authorization, and XSS prevention
|
||||
4. **Performance**: Native HTTP streaming with proper caching headers
|
||||
|
||||
### Package Structure
|
||||
|
||||
```
|
||||
fileserver/
|
||||
├── fileserver.go # Main service and HTTP handlers
|
||||
├── README.md # This file
|
||||
└── fileserver_test.go # Tests (to be added)
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### 1. Attachment Binary
|
||||
```
|
||||
GET /file/attachments/:uid/:filename[?thumbnail=true]
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `uid` - Attachment unique identifier
|
||||
- `filename` - Original filename
|
||||
- `thumbnail` (optional) - Return thumbnail for images
|
||||
|
||||
**Authentication:** Required for non-public memos
|
||||
|
||||
**Response:**
|
||||
- `200 OK` - File content with proper Content-Type
|
||||
- `206 Partial Content` - For range requests (video/audio)
|
||||
- `401 Unauthorized` - Authentication required
|
||||
- `403 Forbidden` - User not authorized
|
||||
- `404 Not Found` - Attachment not found
|
||||
|
||||
**Headers:**
|
||||
- `Content-Type` - MIME type of the file
|
||||
- `Cache-Control: public, max-age=3600`
|
||||
- `Accept-Ranges: bytes` - For video/audio
|
||||
- `Content-Range` - For partial responses (206)
|
||||
|
||||
### 2. User Avatar
|
||||
```
|
||||
GET /file/users/:identifier/avatar
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `identifier` - User ID (e.g., `1`) or username (e.g., `steven`)
|
||||
|
||||
**Authentication:** Not required (avatars are public)
|
||||
|
||||
**Response:**
|
||||
- `200 OK` - Avatar image (PNG/JPEG)
|
||||
- `404 Not Found` - User not found or no avatar set
|
||||
|
||||
**Headers:**
|
||||
- `Content-Type` - image/png or image/jpeg
|
||||
- `Cache-Control: public, max-age=3600`
|
||||
|
||||
## Authentication
|
||||
|
||||
### Supported Methods
|
||||
|
||||
The fileserver supports the following authentication methods:
|
||||
|
||||
1. **JWT Access Token** (`Authorization: Bearer {token}`)
|
||||
- Short-lived tokens (15 minutes) for API access
|
||||
- Stateless validation using JWT signature
|
||||
- Extracts user ID from token claims
|
||||
|
||||
2. **Personal Access Token (PAT)** (`Authorization: Bearer {pat}`)
|
||||
- Long-lived tokens for programmatic access
|
||||
- Validates against database for revocation
|
||||
- Prefixed with specific identifier
|
||||
|
||||
### Authentication Flow
|
||||
|
||||
```
|
||||
Request → getCurrentUser()
|
||||
├─→ Try Session Cookie
|
||||
│ ├─→ Parse cookie value
|
||||
│ ├─→ Get user from DB
|
||||
│ ├─→ Validate session
|
||||
│ └─→ Return user (if valid)
|
||||
│
|
||||
└─→ Try JWT Token
|
||||
├─→ Parse Authorization header
|
||||
├─→ Verify JWT signature
|
||||
├─→ Get user from DB
|
||||
├─→ Validate token in access tokens list
|
||||
└─→ Return user (if valid)
|
||||
```
|
||||
|
||||
### Permission Model
|
||||
|
||||
**Attachments:**
|
||||
- Unlinked: Public (no auth required)
|
||||
- Public memo: Public (no auth required)
|
||||
- Protected memo: Requires authentication
|
||||
- Private memo: Creator only
|
||||
|
||||
**Avatars:**
|
||||
- Always public (no auth required)
|
||||
|
||||
## Key Functions
|
||||
|
||||
### HTTP Handlers
|
||||
|
||||
#### `serveAttachmentFile(c echo.Context) error`
|
||||
Main handler for attachment binary serving.
|
||||
|
||||
**Flow:**
|
||||
1. Extract UID from URL parameter
|
||||
2. Fetch attachment from database
|
||||
3. Check permissions (memo visibility)
|
||||
4. Get binary blob (local file, S3, or database)
|
||||
5. Handle thumbnail request (if applicable)
|
||||
6. Set security headers (XSS prevention)
|
||||
7. Serve with range request support (video/audio)
|
||||
|
||||
#### `serveUserAvatar(c echo.Context) error`
|
||||
Main handler for user avatar serving.
|
||||
|
||||
**Flow:**
|
||||
1. Extract identifier (ID or username) from URL
|
||||
2. Lookup user in database
|
||||
3. Check if avatar exists
|
||||
4. Decode base64 data URI
|
||||
5. Serve with proper content type and caching
|
||||
|
||||
### Authentication
|
||||
|
||||
#### `getCurrentUser(ctx, c) (*store.User, error)`
|
||||
Authenticates request using session cookie or JWT token.
|
||||
|
||||
#### `authenticateBySession(ctx, cookie) (*store.User, error)`
|
||||
Validates session cookie and returns authenticated user.
|
||||
|
||||
#### `authenticateByJWT(ctx, token) (*store.User, error)`
|
||||
Validates JWT access token and returns authenticated user.
|
||||
|
||||
### Permission Checks
|
||||
|
||||
#### `checkAttachmentPermission(ctx, c, attachment) error`
|
||||
Validates user has permission to access attachment based on memo visibility.
|
||||
|
||||
### File Operations
|
||||
|
||||
#### `getAttachmentBlob(attachment) ([]byte, error)`
|
||||
Retrieves binary content from local storage, S3, or database.
|
||||
|
||||
#### `getOrGenerateThumbnail(ctx, attachment) ([]byte, error)`
|
||||
Returns cached thumbnail or generates new one (with semaphore limiting).
|
||||
|
||||
### Utilities
|
||||
|
||||
#### `getUserByIdentifier(ctx, identifier) (*store.User, error)`
|
||||
Finds user by ID (int) or username (string).
|
||||
|
||||
#### `extractImageInfo(dataURI) (type, base64, error)`
|
||||
Parses data URI to extract MIME type and base64 data.
|
||||
|
||||
## Dependencies
|
||||
|
||||
### External Packages
|
||||
- `github.com/labstack/echo/v4` - HTTP router and middleware
|
||||
- `github.com/golang-jwt/jwt/v5` - JWT parsing and validation
|
||||
- `github.com/disintegration/imaging` - Image thumbnail generation
|
||||
- `golang.org/x/sync/semaphore` - Concurrency control for thumbnails
|
||||
|
||||
### Internal Packages
|
||||
- `server/auth` - Authentication utilities
|
||||
- `store` - Database operations
|
||||
- `internal/profile` - Server configuration
|
||||
- `plugin/storage/s3` - S3 storage client
|
||||
|
||||
## Configuration
|
||||
|
||||
### Constants
|
||||
|
||||
Auth-related constants are imported from `server/auth`:
|
||||
- `auth.RefreshTokenCookieName` - "memos_refresh"
|
||||
- `auth.PersonalAccessTokenPrefix` - PAT identifier prefix
|
||||
|
||||
Package-specific constants:
|
||||
- `ThumbnailCacheFolder` - ".thumbnail_cache"
|
||||
- `thumbnailMaxSize` - 600px
|
||||
- `SupportedThumbnailMimeTypes` - ["image/png", "image/jpeg"]
|
||||
|
||||
## Error Handling
|
||||
|
||||
All handlers return Echo HTTP errors with appropriate status codes:
|
||||
|
||||
```go
|
||||
// Bad request
|
||||
echo.NewHTTPError(http.StatusBadRequest, "message")
|
||||
|
||||
// Unauthorized (no auth)
|
||||
echo.NewHTTPError(http.StatusUnauthorized, "message")
|
||||
|
||||
// Forbidden (auth but no permission)
|
||||
echo.NewHTTPError(http.StatusForbidden, "message")
|
||||
|
||||
// Not found
|
||||
echo.NewHTTPError(http.StatusNotFound, "message")
|
||||
|
||||
// Internal error
|
||||
echo.NewHTTPError(http.StatusInternalServerError, "message").SetInternal(err)
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### 1. XSS Prevention
|
||||
SVG and HTML files are served as `application/octet-stream` to prevent script execution:
|
||||
|
||||
```go
|
||||
if contentType == "image/svg+xml" ||
|
||||
contentType == "text/html" ||
|
||||
contentType == "application/xhtml+xml" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Authentication
|
||||
Private content requires valid JWT access token or Personal Access Token.
|
||||
|
||||
### 3. Authorization
|
||||
Memo visibility rules enforced before serving attachments.
|
||||
|
||||
### 4. Input Validation
|
||||
- Attachment UID validated from database
|
||||
- User identifier validated (ID or username)
|
||||
- Range requests validated before processing
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
### 1. Thumbnail Caching
|
||||
Thumbnails cached on disk to avoid regeneration:
|
||||
- Cache location: `{data_dir}/.thumbnail_cache/`
|
||||
- Filename: `{attachment_id}{extension}`
|
||||
- Semaphore limits concurrent generation (max 3)
|
||||
|
||||
### 2. HTTP Range Requests
|
||||
Video/audio files use `http.ServeContent()` for efficient streaming:
|
||||
- Automatic range parsing
|
||||
- Efficient memory usage (streaming, not loading full file)
|
||||
- Safari-compatible partial content responses
|
||||
|
||||
### 3. Caching Headers
|
||||
All responses include cache headers:
|
||||
```
|
||||
Cache-Control: public, max-age=3600
|
||||
```
|
||||
|
||||
### 4. S3 External Links
|
||||
S3 files served via presigned URLs (no server download).
|
||||
|
||||
## Testing
|
||||
|
||||
### Unit Tests (To Add)
|
||||
See SAFARI_FIX.md for recommended test coverage.
|
||||
|
||||
### Manual Testing
|
||||
```bash
|
||||
# Test attachment
|
||||
curl "http://localhost:8081/file/attachments/{uid}/file.jpg"
|
||||
|
||||
# Test avatar by ID
|
||||
curl "http://localhost:8081/file/users/1/avatar"
|
||||
|
||||
# Test avatar by username
|
||||
curl "http://localhost:8081/file/users/steven/avatar"
|
||||
|
||||
# Test range request
|
||||
curl -H "Range: bytes=0-999" "http://localhost:8081/file/attachments/{uid}/video.mp4"
|
||||
```
|
||||
|
||||
## Future Improvements
|
||||
|
||||
See SAFARI_FIX.md section "Future Improvements" for planned enhancements.
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [SAFARI_FIX.md](../../../SAFARI_FIX.md) - Full migration guide
|
||||
- [server/router/api/v1/auth.go](../api/v1/auth.go) - Auth constants source of truth
|
||||
- [RFC 7233](https://tools.ietf.org/html/rfc7233) - HTTP Range Requests spec
|
||||
587
server/router/fileserver/fileserver.go
Normal file
587
server/router/fileserver/fileserver.go
Normal file
@@ -0,0 +1,587 @@
|
||||
package fileserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/storage/s3"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// Constants for file serving configuration.
|
||||
const (
|
||||
// ThumbnailCacheFolder is the folder name where thumbnail images are stored.
|
||||
ThumbnailCacheFolder = ".thumbnail_cache"
|
||||
|
||||
// thumbnailMaxSize is the maximum dimension (width or height) for thumbnails.
|
||||
thumbnailMaxSize = 600
|
||||
|
||||
// maxConcurrentThumbnails limits concurrent thumbnail generation to prevent memory exhaustion.
|
||||
maxConcurrentThumbnails = 3
|
||||
|
||||
// cacheMaxAge is the max-age value for Cache-Control headers (1 hour).
|
||||
cacheMaxAge = "public, max-age=3600"
|
||||
)
|
||||
|
||||
// xssUnsafeTypes contains MIME types that could execute scripts if served directly.
|
||||
// These are served as application/octet-stream to prevent XSS attacks.
|
||||
var xssUnsafeTypes = map[string]bool{
|
||||
"text/html": true,
|
||||
"text/javascript": true,
|
||||
"application/javascript": true,
|
||||
"application/x-javascript": true,
|
||||
"text/xml": true,
|
||||
"application/xml": true,
|
||||
"application/xhtml+xml": true,
|
||||
"image/svg+xml": true,
|
||||
}
|
||||
|
||||
// thumbnailSupportedTypes contains image MIME types that support thumbnail generation.
|
||||
var thumbnailSupportedTypes = map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
"image/webp": true,
|
||||
}
|
||||
|
||||
// avatarAllowedTypes contains MIME types allowed for user avatars.
|
||||
var avatarAllowedTypes = map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
}
|
||||
|
||||
// SupportedThumbnailMimeTypes is the exported list of thumbnail-supported MIME types.
|
||||
var SupportedThumbnailMimeTypes = []string{
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/heic",
|
||||
"image/heif",
|
||||
"image/webp",
|
||||
}
|
||||
|
||||
// dataURIRegex parses data URI format: data:image/png;base64,iVBORw0KGgo...
|
||||
var dataURIRegex = regexp.MustCompile(`^data:(?P<type>[^;]+);base64,(?P<base64>.+)`)
|
||||
|
||||
// FileServerService handles HTTP file serving with proper range request support.
|
||||
type FileServerService struct {
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
authenticator *auth.Authenticator
|
||||
|
||||
// thumbnailSemaphore limits concurrent thumbnail generation.
|
||||
thumbnailSemaphore *semaphore.Weighted
|
||||
}
|
||||
|
||||
// NewFileServerService creates a new file server service.
|
||||
func NewFileServerService(profile *profile.Profile, store *store.Store, secret string) *FileServerService {
|
||||
return &FileServerService{
|
||||
Profile: profile,
|
||||
Store: store,
|
||||
authenticator: auth.NewAuthenticator(store, secret),
|
||||
thumbnailSemaphore: semaphore.NewWeighted(maxConcurrentThumbnails),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers HTTP file serving routes.
|
||||
func (s *FileServerService) RegisterRoutes(echoServer *echo.Echo) {
|
||||
fileGroup := echoServer.Group("/file")
|
||||
fileGroup.GET("/attachments/:uid/:filename", s.serveAttachmentFile)
|
||||
fileGroup.GET("/users/:identifier/avatar", s.serveUserAvatar)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// HTTP Handlers
|
||||
// =============================================================================
|
||||
|
||||
// serveAttachmentFile serves attachment binary content using native HTTP.
|
||||
func (s *FileServerService) serveAttachmentFile(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
uid := c.Param("uid")
|
||||
wantThumbnail := c.QueryParam("thumbnail") == "true"
|
||||
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
|
||||
UID: &uid,
|
||||
GetBlob: true,
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment").Wrap(err)
|
||||
}
|
||||
if attachment == nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "attachment not found")
|
||||
}
|
||||
|
||||
if err := s.checkAttachmentPermission(ctx, c, attachment); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
contentType := s.sanitizeContentType(attachment.Type)
|
||||
|
||||
// Stream video/audio to avoid loading entire file into memory.
|
||||
if isMediaType(attachment.Type) {
|
||||
return s.serveMediaStream(c, attachment, contentType)
|
||||
}
|
||||
|
||||
return s.serveStaticFile(c, attachment, contentType, wantThumbnail)
|
||||
}
|
||||
|
||||
// serveUserAvatar serves user avatar images.
|
||||
func (s *FileServerService) serveUserAvatar(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
identifier := c.Param("identifier")
|
||||
|
||||
user, err := s.getUserByIdentifier(ctx, identifier)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user").Wrap(err)
|
||||
}
|
||||
if user == nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "user not found")
|
||||
}
|
||||
if user.AvatarURL == "" {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "avatar not found")
|
||||
}
|
||||
|
||||
imageType, imageData, err := s.parseDataURI(user.AvatarURL)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to parse avatar data").Wrap(err)
|
||||
}
|
||||
|
||||
if !avatarAllowedTypes[imageType] {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid avatar image type")
|
||||
}
|
||||
|
||||
setSecurityHeaders(c)
|
||||
c.Response().Header().Set(echo.HeaderContentType, imageType)
|
||||
c.Response().Header().Set(echo.HeaderCacheControl, cacheMaxAge)
|
||||
|
||||
return c.Blob(http.StatusOK, imageType, imageData)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// File Serving Methods
|
||||
// =============================================================================
|
||||
|
||||
// serveMediaStream serves video/audio files using streaming to avoid memory exhaustion.
|
||||
func (s *FileServerService) serveMediaStream(c *echo.Context, attachment *store.Attachment, contentType string) error {
|
||||
setSecurityHeaders(c)
|
||||
setMediaHeaders(c, contentType, attachment.Type)
|
||||
|
||||
switch attachment.StorageType {
|
||||
case storepb.AttachmentStorageType_LOCAL:
|
||||
filePath, err := s.resolveLocalPath(attachment.Reference)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to resolve file path").Wrap(err)
|
||||
}
|
||||
http.ServeFile(c.Response(), c.Request(), filePath)
|
||||
return nil
|
||||
|
||||
case storepb.AttachmentStorageType_S3:
|
||||
presignURL, err := s.getS3PresignedURL(c.Request().Context(), attachment)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to generate presigned URL").Wrap(err)
|
||||
}
|
||||
return c.Redirect(http.StatusTemporaryRedirect, presignURL)
|
||||
|
||||
default:
|
||||
// Database storage fallback.
|
||||
modTime := time.Unix(attachment.UpdatedTs, 0)
|
||||
http.ServeContent(c.Response(), c.Request(), attachment.Filename, modTime, bytes.NewReader(attachment.Blob))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// serveStaticFile serves non-streaming files (images, documents, etc.).
|
||||
func (s *FileServerService) serveStaticFile(c *echo.Context, attachment *store.Attachment, contentType string, wantThumbnail bool) error {
|
||||
blob, err := s.getAttachmentBlob(attachment)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").Wrap(err)
|
||||
}
|
||||
|
||||
// Generate thumbnail for supported image types.
|
||||
if wantThumbnail && thumbnailSupportedTypes[attachment.Type] {
|
||||
if thumbnailBlob, err := s.getOrGenerateThumbnail(c.Request().Context(), attachment); err != nil {
|
||||
slog.Warn("failed to get thumbnail", "error", err)
|
||||
} else {
|
||||
blob = thumbnailBlob
|
||||
}
|
||||
}
|
||||
|
||||
setSecurityHeaders(c)
|
||||
setMediaHeaders(c, contentType, attachment.Type)
|
||||
|
||||
// Force download for non-media files to prevent XSS execution.
|
||||
if !strings.HasPrefix(contentType, "image/") && contentType != "application/pdf" {
|
||||
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%q", attachment.Filename))
|
||||
}
|
||||
|
||||
return c.Blob(http.StatusOK, contentType, blob)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Storage Operations
|
||||
// =============================================================================
|
||||
|
||||
// getAttachmentBlob retrieves the binary content of an attachment from storage.
|
||||
func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
|
||||
switch attachment.StorageType {
|
||||
case storepb.AttachmentStorageType_LOCAL:
|
||||
return s.readLocalFile(attachment.Reference)
|
||||
|
||||
case storepb.AttachmentStorageType_S3:
|
||||
return s.downloadFromS3(attachment)
|
||||
|
||||
default:
|
||||
return attachment.Blob, nil
|
||||
}
|
||||
}
|
||||
|
||||
// getAttachmentReader returns a reader for streaming attachment content.
|
||||
func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) {
|
||||
switch attachment.StorageType {
|
||||
case storepb.AttachmentStorageType_LOCAL:
|
||||
filePath, err := s.resolveLocalPath(attachment.Reference)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "file not found")
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to open file")
|
||||
}
|
||||
return file, nil
|
||||
|
||||
case storepb.AttachmentStorageType_S3:
|
||||
s3Client, s3Object, err := s.createS3Client(attachment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to stream from S3")
|
||||
}
|
||||
return reader, nil
|
||||
|
||||
default:
|
||||
return io.NopCloser(bytes.NewReader(attachment.Blob)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// resolveLocalPath converts a storage reference to an absolute file path.
|
||||
func (s *FileServerService) resolveLocalPath(reference string) (string, error) {
|
||||
filePath := filepath.FromSlash(reference)
|
||||
if !filepath.IsAbs(filePath) {
|
||||
filePath = filepath.Join(s.Profile.Data, filePath)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// readLocalFile reads the entire contents of a local file.
|
||||
func (s *FileServerService) readLocalFile(reference string) ([]byte, error) {
|
||||
filePath, err := s.resolveLocalPath(reference)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "file not found")
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to open file")
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
blob, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read file")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
// createS3Client creates an S3 client from attachment payload.
|
||||
func (*FileServerService) createS3Client(attachment *store.Attachment) (*s3.Client, *storepb.AttachmentPayload_S3Object, error) {
|
||||
if attachment.Payload == nil {
|
||||
return nil, nil, errors.New("attachment payload is missing")
|
||||
}
|
||||
s3Object := attachment.Payload.GetS3Object()
|
||||
if s3Object == nil {
|
||||
return nil, nil, errors.New("S3 object payload is missing")
|
||||
}
|
||||
if s3Object.S3Config == nil {
|
||||
return nil, nil, errors.New("S3 config is missing")
|
||||
}
|
||||
if s3Object.Key == "" {
|
||||
return nil, nil, errors.New("S3 object key is missing")
|
||||
}
|
||||
|
||||
client, err := s3.NewClient(context.Background(), s3Object.S3Config)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create S3 client")
|
||||
}
|
||||
return client, s3Object, nil
|
||||
}
|
||||
|
||||
// downloadFromS3 downloads the entire object from S3.
|
||||
func (s *FileServerService) downloadFromS3(attachment *store.Attachment) ([]byte, error) {
|
||||
client, s3Object, err := s.createS3Client(attachment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := client.GetObject(context.Background(), s3Object.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to download from S3")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
// getS3PresignedURL generates a presigned URL for direct S3 access.
|
||||
func (s *FileServerService) getS3PresignedURL(ctx context.Context, attachment *store.Attachment) (string, error) {
|
||||
client, s3Object, err := s.createS3Client(attachment)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url, err := client.PresignGetObject(ctx, s3Object.Key)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to presign URL")
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Thumbnail Generation
|
||||
// =============================================================================
|
||||
|
||||
// getOrGenerateThumbnail returns the thumbnail image of the attachment.
|
||||
// Uses semaphore to limit concurrent thumbnail generation and prevent memory exhaustion.
|
||||
func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachment *store.Attachment) ([]byte, error) {
|
||||
thumbnailPath, err := s.getThumbnailPath(attachment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fast path: return cached thumbnail if exists.
|
||||
if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil {
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
// Acquire semaphore to limit concurrent generation.
|
||||
if err := s.thumbnailSemaphore.Acquire(ctx, 1); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to acquire semaphore")
|
||||
}
|
||||
defer s.thumbnailSemaphore.Release(1)
|
||||
|
||||
// Double-check after acquiring semaphore (another goroutine may have generated it).
|
||||
if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil {
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
return s.generateThumbnail(attachment, thumbnailPath)
|
||||
}
|
||||
|
||||
// getThumbnailPath returns the file path for a cached thumbnail.
|
||||
func (s *FileServerService) getThumbnailPath(attachment *store.Attachment) (string, error) {
|
||||
cacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder)
|
||||
if err := os.MkdirAll(cacheFolder, os.ModePerm); err != nil {
|
||||
return "", errors.Wrap(err, "failed to create thumbnail cache folder")
|
||||
}
|
||||
filename := fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename))
|
||||
return filepath.Join(cacheFolder, filename), nil
|
||||
}
|
||||
|
||||
// readCachedThumbnail reads a thumbnail from the cache directory.
|
||||
func (*FileServerService) readCachedThumbnail(path string) ([]byte, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
return io.ReadAll(file)
|
||||
}
|
||||
|
||||
// generateThumbnail creates a new thumbnail and saves it to disk.
|
||||
func (s *FileServerService) generateThumbnail(attachment *store.Attachment, thumbnailPath string) ([]byte, error) {
|
||||
reader, err := s.getAttachmentReader(attachment)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get attachment reader")
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
img, err := imaging.Decode(reader, imaging.AutoOrientation(true))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to decode image")
|
||||
}
|
||||
|
||||
width, height := img.Bounds().Dx(), img.Bounds().Dy()
|
||||
thumbnailWidth, thumbnailHeight := calculateThumbnailDimensions(width, height)
|
||||
|
||||
thumbnailImage := imaging.Resize(img, thumbnailWidth, thumbnailHeight, imaging.Lanczos)
|
||||
|
||||
if err := imaging.Save(thumbnailImage, thumbnailPath); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to save thumbnail")
|
||||
}
|
||||
|
||||
return s.readCachedThumbnail(thumbnailPath)
|
||||
}
|
||||
|
||||
// calculateThumbnailDimensions calculates the target dimensions for a thumbnail.
|
||||
// The largest dimension is constrained to thumbnailMaxSize while maintaining aspect ratio.
|
||||
// Small images are not enlarged.
|
||||
func calculateThumbnailDimensions(width, height int) (int, int) {
|
||||
if max(width, height) <= thumbnailMaxSize {
|
||||
return width, height
|
||||
}
|
||||
if width >= height {
|
||||
return thumbnailMaxSize, 0 // Landscape: constrain width.
|
||||
}
|
||||
return 0, thumbnailMaxSize // Portrait: constrain height.
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Authentication & Authorization
|
||||
// =============================================================================
|
||||
|
||||
// checkAttachmentPermission verifies the user has permission to access the attachment.
|
||||
func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c *echo.Context, attachment *store.Attachment) error {
|
||||
// For unlinked attachments, only the creator can access.
|
||||
if attachment.MemoID == nil {
|
||||
user, err := s.getCurrentUser(ctx, c)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").Wrap(err)
|
||||
}
|
||||
if user == nil {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access")
|
||||
}
|
||||
if user.ID != attachment.CreatorID && user.Role != store.RoleAdmin {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "forbidden access")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to find memo").Wrap(err)
|
||||
}
|
||||
if memo == nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "memo not found")
|
||||
}
|
||||
|
||||
if memo.Visibility == store.Public {
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := s.getCurrentUser(ctx, c)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").Wrap(err)
|
||||
}
|
||||
if user == nil {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access")
|
||||
}
|
||||
|
||||
if memo.Visibility == store.Private && user.ID != memo.CreatorID && user.Role != store.RoleAdmin {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "forbidden access")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getCurrentUser retrieves the current authenticated user from the request.
|
||||
// Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie.
|
||||
func (s *FileServerService) getCurrentUser(ctx context.Context, c *echo.Context) (*store.User, error) {
|
||||
authHeader := c.Request().Header.Get(echo.HeaderAuthorization)
|
||||
cookieHeader := c.Request().Header.Get("Cookie")
|
||||
return s.authenticator.AuthenticateToUser(ctx, authHeader, cookieHeader)
|
||||
}
|
||||
|
||||
// getUserByIdentifier finds a user by either ID or username.
|
||||
func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) {
|
||||
if userID, err := util.ConvertStringToInt32(identifier); err == nil {
|
||||
return s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
}
|
||||
return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Helper Functions
|
||||
// =============================================================================
|
||||
|
||||
// sanitizeContentType converts potentially dangerous MIME types to safe alternatives.
|
||||
func (*FileServerService) sanitizeContentType(mimeType string) string {
|
||||
contentType := mimeType
|
||||
if strings.HasPrefix(contentType, "text/") {
|
||||
contentType += "; charset=utf-8"
|
||||
}
|
||||
// Normalize for case-insensitive lookup.
|
||||
if xssUnsafeTypes[strings.ToLower(mimeType)] {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
return contentType
|
||||
}
|
||||
|
||||
// parseDataURI extracts MIME type and decoded data from a data URI.
|
||||
func (*FileServerService) parseDataURI(dataURI string) (string, []byte, error) {
|
||||
matches := dataURIRegex.FindStringSubmatch(dataURI)
|
||||
if len(matches) != 3 {
|
||||
return "", nil, errors.New("invalid data URI format")
|
||||
}
|
||||
|
||||
imageType := matches[1]
|
||||
imageData, err := base64.StdEncoding.DecodeString(matches[2])
|
||||
if err != nil {
|
||||
return "", nil, errors.Wrap(err, "failed to decode base64 data")
|
||||
}
|
||||
|
||||
return imageType, imageData, nil
|
||||
}
|
||||
|
||||
// isMediaType checks if the MIME type is video or audio.
|
||||
func isMediaType(mimeType string) bool {
|
||||
return strings.HasPrefix(mimeType, "video/") || strings.HasPrefix(mimeType, "audio/")
|
||||
}
|
||||
|
||||
// setSecurityHeaders sets common security headers for all responses.
|
||||
func setSecurityHeaders(c *echo.Context) {
|
||||
h := c.Response().Header()
|
||||
h.Set("X-Content-Type-Options", "nosniff")
|
||||
h.Set("X-Frame-Options", "DENY")
|
||||
h.Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';")
|
||||
}
|
||||
|
||||
// setMediaHeaders sets headers for media file responses.
|
||||
func setMediaHeaders(c *echo.Context, contentType, originalType string) {
|
||||
h := c.Response().Header()
|
||||
h.Set(echo.HeaderContentType, contentType)
|
||||
h.Set(echo.HeaderCacheControl, cacheMaxAge)
|
||||
|
||||
// Support HDR/wide color gamut for images and videos.
|
||||
if strings.HasPrefix(originalType, "image/") || strings.HasPrefix(originalType, "video/") {
|
||||
h.Set("Color-Gamut", "srgb, p3, rec2020")
|
||||
}
|
||||
}
|
||||
68
server/router/frontend/frontend.go
Normal file
68
server/router/frontend/frontend.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package frontend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"io/fs"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/labstack/echo/v5/middleware"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
//go:embed dist/*
|
||||
var embeddedFiles embed.FS
|
||||
|
||||
type FrontendService struct {
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
}
|
||||
|
||||
func NewFrontendService(profile *profile.Profile, store *store.Store) *FrontendService {
|
||||
return &FrontendService{
|
||||
Profile: profile,
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (*FrontendService) Serve(_ context.Context, e *echo.Echo) {
|
||||
skipper := func(c *echo.Context) bool {
|
||||
// Skip API routes.
|
||||
if util.HasPrefixes(c.Path(), "/api", "/memos.api.v1") {
|
||||
return true
|
||||
}
|
||||
// For index.html and root path, set no-cache headers to prevent browser caching
|
||||
// This prevents sensitive data from being accessible via browser back button after logout
|
||||
if c.Path() == "/" || c.Path() == "/index.html" {
|
||||
c.Response().Header().Set(echo.HeaderCacheControl, "no-cache, no-store, must-revalidate")
|
||||
c.Response().Header().Set("Pragma", "no-cache")
|
||||
c.Response().Header().Set("Expires", "0")
|
||||
return false
|
||||
}
|
||||
// Set Cache-Control header for static assets.
|
||||
// Since Vite generates content-hashed filenames (e.g., index-BtVjejZf.js),
|
||||
// we can cache aggressively but use immutable to prevent revalidation checks.
|
||||
// For frequently redeployed instances, use shorter max-age (1 hour) to avoid
|
||||
// serving stale assets after redeployment.
|
||||
c.Response().Header().Set(echo.HeaderCacheControl, "public, max-age=3600, immutable") // 1 hour
|
||||
return false
|
||||
}
|
||||
|
||||
// Route to serve the main app with HTML5 fallback for SPA behavior.
|
||||
e.Use(middleware.StaticWithConfig(middleware.StaticConfig{
|
||||
Filesystem: getFileSystem("dist"),
|
||||
HTML5: true, // Enable fallback to index.html
|
||||
Skipper: skipper,
|
||||
}))
|
||||
}
|
||||
|
||||
func getFileSystem(path string) fs.FS {
|
||||
sub, err := fs.Sub(embeddedFiles, path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return sub
|
||||
}
|
||||
66
server/router/mcp/README.md
Normal file
66
server/router/mcp/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# MCP Server
|
||||
|
||||
This package implements a [Model Context Protocol (MCP)](https://modelcontextprotocol.io) server embedded in the Memos HTTP process. It exposes memo operations as MCP tools, making Memos accessible to any MCP-compatible AI client (Claude Desktop, Cursor, Zed, etc.).
|
||||
|
||||
## Endpoint
|
||||
|
||||
```
|
||||
POST /mcp (tool calls, initialize)
|
||||
GET /mcp (optional SSE stream for server-to-client messages)
|
||||
```
|
||||
|
||||
Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26).
|
||||
|
||||
## Authentication
|
||||
|
||||
Every request must include a Personal Access Token (PAT):
|
||||
|
||||
```
|
||||
Authorization: Bearer <your-PAT>
|
||||
```
|
||||
|
||||
PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are not accepted. Requests without a valid PAT receive `HTTP 401`.
|
||||
|
||||
## Tools
|
||||
|
||||
All tools are scoped to the authenticated user's memos.
|
||||
|
||||
| Tool | Description | Required params | Optional params |
|
||||
|---|---|---|---|
|
||||
| `list_memos` | List memos | — | `page_size` (int, max 100), `filter` (CEL expression) |
|
||||
| `get_memo` | Get a single memo | `name` | — |
|
||||
| `search_memos` | Full-text search | `query` | — |
|
||||
| `create_memo` | Create a memo | `content` | `visibility` |
|
||||
| `update_memo` | Update content or visibility | `name` | `content`, `visibility` |
|
||||
| `delete_memo` | Delete a memo | `name` | — |
|
||||
|
||||
**`name`** is the memo resource name, e.g. `memos/abc123`.
|
||||
|
||||
**`visibility`** accepts `PRIVATE` (default), `PROTECTED`, or `PUBLIC`.
|
||||
|
||||
**`filter`** accepts CEL expressions supported by the memo filter engine, e.g.:
|
||||
- `content.contains("keyword")`
|
||||
- `visibility == "PUBLIC"`
|
||||
- `has_task_list`
|
||||
|
||||
## Connecting Claude Code
|
||||
|
||||
```bash
|
||||
claude mcp add --transport http memos http://localhost:5230/mcp \
|
||||
--header "Authorization: Bearer <your-PAT>"
|
||||
```
|
||||
|
||||
Use `--scope user` to make it available across all projects:
|
||||
|
||||
```bash
|
||||
claude mcp add --scope user --transport http memos http://localhost:5230/mcp \
|
||||
--header "Authorization: Bearer <your-PAT>"
|
||||
```
|
||||
|
||||
## Package Structure
|
||||
|
||||
| File | Responsibility |
|
||||
|---|---|
|
||||
| `mcp.go` | `MCPService` struct, constructor, route registration |
|
||||
| `auth_middleware.go` | Echo middleware — validates Bearer token, sets user ID in context |
|
||||
| `tools_memo.go` | Tool registration and six memo tool handlers |
|
||||
56
server/router/mcp/mcp.go
Normal file
56
server/router/mcp/mcp.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/labstack/echo/v5/middleware"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type MCPService struct {
|
||||
store *store.Store
|
||||
authenticator *auth.Authenticator
|
||||
}
|
||||
|
||||
func NewMCPService(store *store.Store, secret string) *MCPService {
|
||||
return &MCPService{
|
||||
store: store,
|
||||
authenticator: auth.NewAuthenticator(store, secret),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
|
||||
mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0",
|
||||
mcpserver.WithToolCapabilities(false),
|
||||
)
|
||||
s.registerMemoTools(mcpSrv)
|
||||
s.registerTagTools(mcpSrv)
|
||||
s.registerMemoResources(mcpSrv)
|
||||
s.registerPrompts(mcpSrv)
|
||||
|
||||
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv)
|
||||
|
||||
mcpGroup := echoServer.Group("")
|
||||
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
}))
|
||||
mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) error {
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
result := s.authenticator.Authenticate(c.Request().Context(), authHeader)
|
||||
if result == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired token"})
|
||||
}
|
||||
ctx := auth.ApplyToContext(c.Request().Context(), result)
|
||||
c.SetRequest(c.Request().WithContext(ctx))
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
|
||||
}
|
||||
84
server/router/mcp/prompts.go
Normal file
84
server/router/mcp/prompts.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
func (s *MCPService) registerPrompts(mcpSrv *mcpserver.MCPServer) {
|
||||
// capture — turns free-form user input into a structured create_memo call.
|
||||
mcpSrv.AddPrompt(
|
||||
mcp.NewPrompt("capture",
|
||||
mcp.WithPromptDescription("Capture a thought, idea, or note as a new memo. "+
|
||||
"Use this prompt when the user wants to quickly save something. "+
|
||||
"The assistant will call create_memo with the provided content."),
|
||||
mcp.WithArgument("content",
|
||||
mcp.ArgumentDescription("The text to save as a memo"),
|
||||
mcp.RequiredArgument(),
|
||||
),
|
||||
mcp.WithArgument("tags",
|
||||
mcp.ArgumentDescription("Comma-separated tags to apply, e.g. \"work,project\""),
|
||||
),
|
||||
),
|
||||
s.handleCapturePrompt,
|
||||
)
|
||||
|
||||
// review — surfaces existing memos on a topic for summarisation.
|
||||
mcpSrv.AddPrompt(
|
||||
mcp.NewPrompt("review",
|
||||
mcp.WithPromptDescription("Search and review memos on a given topic. "+
|
||||
"The assistant will call search_memos and summarise the results."),
|
||||
mcp.WithArgument("topic",
|
||||
mcp.ArgumentDescription("Topic or keyword to search for"),
|
||||
mcp.RequiredArgument(),
|
||||
),
|
||||
),
|
||||
s.handleReviewPrompt,
|
||||
)
|
||||
}
|
||||
|
||||
func (*MCPService) handleCapturePrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
content := req.Params.Arguments["content"]
|
||||
if content == "" {
|
||||
return nil, errors.New("content argument is required")
|
||||
}
|
||||
|
||||
tags := req.Params.Arguments["tags"]
|
||||
instruction := fmt.Sprintf(
|
||||
"Please save the following as a new private memo using the create_memo tool.\n\nContent:\n%s",
|
||||
content,
|
||||
)
|
||||
if tags != "" {
|
||||
instruction += fmt.Sprintf("\n\nAppend these tags inline using #tag syntax: %s", tags)
|
||||
}
|
||||
|
||||
return &mcp.GetPromptResult{
|
||||
Description: "Capture a memo",
|
||||
Messages: []mcp.PromptMessage{
|
||||
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*MCPService) handleReviewPrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
topic := req.Params.Arguments["topic"]
|
||||
if topic == "" {
|
||||
return nil, errors.New("topic argument is required")
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf(
|
||||
"Please use the search_memos tool to find memos about %q, then provide a concise summary of what has been written on this topic, grouped by theme. Include the memo names so the user can reference them.",
|
||||
topic,
|
||||
)
|
||||
|
||||
return &mcp.GetPromptResult{
|
||||
Description: fmt.Sprintf("Review memos about %q", topic),
|
||||
Messages: []mcp.PromptMessage{
|
||||
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
85
server/router/mcp/resources_memo.go
Normal file
85
server/router/mcp/resources_memo.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// Memo resource URI scheme: memo://memos/{uid}
|
||||
// Clients can read any memo they have access to by URI without calling a tool.
|
||||
|
||||
func (s *MCPService) registerMemoResources(mcpSrv *mcpserver.MCPServer) {
|
||||
mcpSrv.AddResourceTemplate(
|
||||
mcp.NewResourceTemplate(
|
||||
"memo://memos/{uid}",
|
||||
"Memo",
|
||||
mcp.WithTemplateDescription("A single Memos note identified by its UID. Returns the memo content as Markdown with a YAML frontmatter header containing metadata."),
|
||||
mcp.WithTemplateMIMEType("text/markdown"),
|
||||
),
|
||||
s.handleReadMemoResource,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *MCPService) handleReadMemoResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
// URI format: memo://memos/{uid}
|
||||
uid := strings.TrimPrefix(req.Params.URI, "memo://memos/")
|
||||
if uid == req.Params.URI || uid == "" {
|
||||
return nil, errors.Errorf("invalid memo URI %q: expected memo://memos/<uid>", req.Params.URI)
|
||||
}
|
||||
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, errors.Errorf("memo not found: %s", uid)
|
||||
}
|
||||
if err := checkMemoAccess(memo, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
j := storeMemoToJSON(memo)
|
||||
text := formatMemoMarkdown(j)
|
||||
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: req.Params.URI,
|
||||
MIMEType: "text/markdown",
|
||||
Text: text,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// formatMemoMarkdown renders a memo as Markdown with a YAML frontmatter header.
|
||||
func formatMemoMarkdown(j memoJSON) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("---\n")
|
||||
fmt.Fprintf(&sb, "name: %s\n", j.Name)
|
||||
fmt.Fprintf(&sb, "creator: %s\n", j.Creator)
|
||||
fmt.Fprintf(&sb, "visibility: %s\n", j.Visibility)
|
||||
fmt.Fprintf(&sb, "state: %s\n", j.State)
|
||||
fmt.Fprintf(&sb, "pinned: %v\n", j.Pinned)
|
||||
if len(j.Tags) > 0 {
|
||||
fmt.Fprintf(&sb, "tags: [%s]\n", strings.Join(j.Tags, ", "))
|
||||
}
|
||||
fmt.Fprintf(&sb, "create_time: %d\n", j.CreateTime)
|
||||
fmt.Fprintf(&sb, "update_time: %d\n", j.UpdateTime)
|
||||
if j.Parent != "" {
|
||||
fmt.Fprintf(&sb, "parent: %s\n", j.Parent)
|
||||
}
|
||||
sb.WriteString("---\n\n")
|
||||
sb.WriteString(j.Content)
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
599
server/router/mcp/tools_memo.go
Normal file
599
server/router/mcp/tools_memo.go
Normal file
@@ -0,0 +1,599 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// tagRegexp matches #tag patterns in memo content.
|
||||
// A tag must start with a letter and contain no whitespace or # characters.
|
||||
var tagRegexp = regexp.MustCompile(`(?:^|\s)#([A-Za-z][^\s#]*)`)
|
||||
|
||||
// extractTags does a best-effort extraction of #tags from raw markdown content.
|
||||
// It is used when creating or updating memos via MCP to pre-populate Payload.Tags.
|
||||
// The full markdown service may later rebuild a more accurate payload.
|
||||
func extractTags(content string) []string {
|
||||
matches := tagRegexp.FindAllStringSubmatch(content, -1)
|
||||
seen := make(map[string]struct{}, len(matches))
|
||||
tags := make([]string, 0, len(matches))
|
||||
for _, m := range matches {
|
||||
tag := m[1]
|
||||
if _, ok := seen[tag]; !ok {
|
||||
seen[tag] = struct{}{}
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
return tags
|
||||
}
|
||||
|
||||
// buildPayload constructs a MemoPayload with tags extracted from content.
|
||||
// Returns nil when no tags are found so the store omits the payload entirely.
|
||||
func buildPayload(content string) *storepb.MemoPayload {
|
||||
tags := extractTags(content)
|
||||
if len(tags) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &storepb.MemoPayload{Tags: tags}
|
||||
}
|
||||
|
||||
// propertyJSON is the serialisable form of MemoPayload.Property.
|
||||
type propertyJSON struct {
|
||||
HasLink bool `json:"has_link"`
|
||||
HasTaskList bool `json:"has_task_list"`
|
||||
HasCode bool `json:"has_code"`
|
||||
HasIncompleteTasks bool `json:"has_incomplete_tasks"`
|
||||
}
|
||||
|
||||
// memoJSON is the canonical response shape for all MCP memo results.
|
||||
// It serialises correctly with standard encoding/json (no proto marshalling needed).
|
||||
type memoJSON struct {
|
||||
Name string `json:"name"`
|
||||
Creator string `json:"creator"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
UpdateTime int64 `json:"update_time"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Visibility string `json:"visibility"`
|
||||
Tags []string `json:"tags"`
|
||||
Pinned bool `json:"pinned"`
|
||||
State string `json:"state"`
|
||||
Property *propertyJSON `json:"property,omitempty"`
|
||||
Parent string `json:"parent,omitempty"`
|
||||
}
|
||||
|
||||
func storeMemoToJSON(m *store.Memo) memoJSON {
|
||||
j := memoJSON{
|
||||
Name: "memos/" + m.UID,
|
||||
Creator: fmt.Sprintf("users/%d", m.CreatorID),
|
||||
CreateTime: m.CreatedTs,
|
||||
UpdateTime: m.UpdatedTs,
|
||||
Content: m.Content,
|
||||
Visibility: string(m.Visibility),
|
||||
Pinned: m.Pinned,
|
||||
State: string(m.RowStatus),
|
||||
Tags: []string{},
|
||||
}
|
||||
if m.Payload != nil {
|
||||
if len(m.Payload.Tags) > 0 {
|
||||
j.Tags = m.Payload.Tags
|
||||
}
|
||||
if p := m.Payload.Property; p != nil && (p.HasLink || p.HasTaskList || p.HasCode || p.HasIncompleteTasks) {
|
||||
j.Property = &propertyJSON{
|
||||
HasLink: p.HasLink,
|
||||
HasTaskList: p.HasTaskList,
|
||||
HasCode: p.HasCode,
|
||||
HasIncompleteTasks: p.HasIncompleteTasks,
|
||||
}
|
||||
}
|
||||
}
|
||||
if m.ParentUID != nil {
|
||||
j.Parent = "memos/" + *m.ParentUID
|
||||
}
|
||||
return j
|
||||
}
|
||||
|
||||
// checkMemoAccess returns an error if the caller cannot read memo.
|
||||
// userID == 0 means anonymous.
|
||||
func checkMemoAccess(memo *store.Memo, userID int32) error {
|
||||
switch memo.Visibility {
|
||||
case store.Protected:
|
||||
if userID == 0 {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
case store.Private:
|
||||
if memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
default:
|
||||
// store.Public and any unknown visibility: allow
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyVisibilityFilter restricts find to memos the caller may see.
|
||||
func applyVisibilityFilter(find *store.FindMemo, userID int32) {
|
||||
if userID == 0 {
|
||||
find.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID))
|
||||
}
|
||||
}
|
||||
|
||||
// parseMemoUID extracts the UID from a "memos/<uid>" resource name.
|
||||
func parseMemoUID(name string) (string, error) {
|
||||
uid, ok := strings.CutPrefix(name, "memos/")
|
||||
if !ok || uid == "" {
|
||||
return "", errors.Errorf(`memo name must be in the format "memos/<uid>", got %q`, name)
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// parseVisibility validates a visibility string and returns the store constant.
|
||||
func parseVisibility(s string) (store.Visibility, error) {
|
||||
switch v := store.Visibility(s); v {
|
||||
case store.Public, store.Protected, store.Private:
|
||||
return v, nil
|
||||
default:
|
||||
return "", errors.Errorf("visibility must be PRIVATE, PROTECTED, or PUBLIC; got %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
// parseRowStatus validates a state string and returns the store constant.
|
||||
func parseRowStatus(s string) (store.RowStatus, error) {
|
||||
switch rs := store.RowStatus(s); rs {
|
||||
case store.Normal, store.Archived:
|
||||
return rs, nil
|
||||
default:
|
||||
return "", errors.Errorf("state must be NORMAL or ARCHIVED; got %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func extractUserID(ctx context.Context) (int32, error) {
|
||||
id := auth.GetUserID(ctx)
|
||||
if id == 0 {
|
||||
return 0, errors.New("unauthenticated: a personal access token is required")
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func marshalJSON(v any) (string, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
|
||||
mcpSrv.AddTool(mcp.NewTool("list_memos",
|
||||
mcp.WithDescription("List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos."),
|
||||
mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")),
|
||||
mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")),
|
||||
mcp.WithString("state",
|
||||
mcp.Enum("NORMAL", "ARCHIVED"),
|
||||
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
|
||||
),
|
||||
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
|
||||
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
|
||||
), s.handleListMemos)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("get_memo",
|
||||
mcp.WithDescription("Get a single memo by resource name. Public memos are accessible without authentication."),
|
||||
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
|
||||
), s.handleGetMemo)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("create_memo",
|
||||
mcp.WithDescription("Create a new memo. Requires authentication."),
|
||||
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")),
|
||||
mcp.WithString("visibility",
|
||||
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
|
||||
mcp.Description("Visibility (default: PRIVATE)"),
|
||||
),
|
||||
), s.handleCreateMemo)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("update_memo",
|
||||
mcp.WithDescription("Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged."),
|
||||
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
|
||||
mcp.WithString("content", mcp.Description("New Markdown content")),
|
||||
mcp.WithString("visibility",
|
||||
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
|
||||
mcp.Description("New visibility"),
|
||||
),
|
||||
mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")),
|
||||
mcp.WithString("state",
|
||||
mcp.Enum("NORMAL", "ARCHIVED"),
|
||||
mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"),
|
||||
),
|
||||
), s.handleUpdateMemo)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("delete_memo",
|
||||
mcp.WithDescription("Permanently delete a memo. Requires authentication and ownership."),
|
||||
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
|
||||
), s.handleDeleteMemo)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("search_memos",
|
||||
mcp.WithDescription("Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only."),
|
||||
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")),
|
||||
), s.handleSearchMemos)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("list_memo_comments",
|
||||
mcp.WithDescription("List comments on a memo. Visibility rules for comments match those of the parent memo."),
|
||||
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
|
||||
), s.handleListMemoComments)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("create_memo_comment",
|
||||
mcp.WithDescription("Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication."),
|
||||
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)),
|
||||
mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")),
|
||||
), s.handleCreateMemoComment)
|
||||
}
|
||||
|
||||
func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
pageSize := req.GetInt("page_size", 20)
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
page := req.GetInt("page", 0)
|
||||
if page < 0 {
|
||||
page = 0
|
||||
}
|
||||
|
||||
var rowStatus *store.RowStatus
|
||||
if state := req.GetString("state", "NORMAL"); state != "" {
|
||||
rs, err := parseRowStatus(state)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
rowStatus = &rs
|
||||
}
|
||||
|
||||
limit := pageSize + 1
|
||||
offset := page * pageSize
|
||||
find := &store.FindMemo{
|
||||
ExcludeComments: true,
|
||||
RowStatus: rowStatus,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
OrderByPinned: req.GetBool("order_by_pinned", false),
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
if filter := req.GetString("filter", ""); filter != "" {
|
||||
find.Filters = append(find.Filters, filter)
|
||||
}
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
|
||||
}
|
||||
|
||||
hasMore := len(memos) > pageSize
|
||||
if hasMore {
|
||||
memos = memos[:pageSize]
|
||||
}
|
||||
|
||||
results := make([]memoJSON, len(memos))
|
||||
for i, m := range memos {
|
||||
results[i] = storeMemoToJSON(m)
|
||||
}
|
||||
|
||||
type listResponse struct {
|
||||
Memos []memoJSON `json:"memos"`
|
||||
HasMore bool `json:"has_more"`
|
||||
}
|
||||
out, err := marshalJSON(listResponse{Memos: results, HasMore: hasMore})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
|
||||
func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
uid, err := parseMemoUID(req.GetString("name", ""))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
|
||||
}
|
||||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeMemoToJSON(memo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
|
||||
func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID, err := extractUserID(ctx)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
content := req.GetString("content", "")
|
||||
if content == "" {
|
||||
return mcp.NewToolResultError("content is required"), nil
|
||||
}
|
||||
visibility, err := parseVisibility(req.GetString("visibility", "PRIVATE"))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
memo, err := s.store.CreateMemo(ctx, &store.Memo{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: userID,
|
||||
Content: content,
|
||||
Visibility: visibility,
|
||||
Payload: buildPayload(content),
|
||||
})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeMemoToJSON(memo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
|
||||
func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID, err := extractUserID(ctx)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
uid, err := parseMemoUID(req.GetString("name", ""))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
|
||||
}
|
||||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if memo.CreatorID != userID {
|
||||
return mcp.NewToolResultError("permission denied"), nil
|
||||
}
|
||||
|
||||
update := &store.UpdateMemo{ID: memo.ID}
|
||||
args := req.GetArguments()
|
||||
|
||||
if v := req.GetString("content", ""); v != "" {
|
||||
update.Content = &v
|
||||
update.Payload = buildPayload(v)
|
||||
}
|
||||
if v := req.GetString("visibility", ""); v != "" {
|
||||
vis, err := parseVisibility(v)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
update.Visibility = &vis
|
||||
}
|
||||
if v := req.GetString("state", ""); v != "" {
|
||||
rs, err := parseRowStatus(v)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
update.RowStatus = &rs
|
||||
}
|
||||
if _, ok := args["pinned"]; ok {
|
||||
pinned := req.GetBool("pinned", false)
|
||||
update.Pinned = &pinned
|
||||
}
|
||||
|
||||
if err := s.store.UpdateMemo(ctx, update); err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil
|
||||
}
|
||||
|
||||
updated, err := s.store.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeMemoToJSON(updated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
|
||||
func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID, err := extractUserID(ctx)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
uid, err := parseMemoUID(req.GetString("name", ""))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
|
||||
}
|
||||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if memo.CreatorID != userID {
|
||||
return mcp.NewToolResultError("permission denied"), nil
|
||||
}
|
||||
|
||||
if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil
|
||||
}
|
||||
return mcp.NewToolResultText(`{"deleted":true}`), nil
|
||||
}
|
||||
|
||||
func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
query := req.GetString("query", "")
|
||||
if query == "" {
|
||||
return mcp.NewToolResultError("query is required"), nil
|
||||
}
|
||||
|
||||
limit := 50
|
||||
zero := 0
|
||||
rowStatus := store.Normal
|
||||
find := &store.FindMemo{
|
||||
ExcludeComments: true,
|
||||
RowStatus: &rowStatus,
|
||||
Limit: &limit,
|
||||
Offset: &zero,
|
||||
Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)},
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to search memos: %v", err)), nil
|
||||
}
|
||||
|
||||
results := make([]memoJSON, len(memos))
|
||||
for i, m := range memos {
|
||||
results[i] = storeMemoToJSON(m)
|
||||
}
|
||||
out, err := marshalJSON(results)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
|
||||
func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
uid, err := parseMemoUID(req.GetString("name", ""))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
|
||||
}
|
||||
if parent == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(parent, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
relationType := store.MemoRelationComment
|
||||
relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &parent.ID,
|
||||
Type: &relationType,
|
||||
})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil
|
||||
}
|
||||
if len(relations) == 0 {
|
||||
out, _ := marshalJSON([]memoJSON{})
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
|
||||
commentIDs := make([]int32, len(relations))
|
||||
for i, r := range relations {
|
||||
commentIDs[i] = r.MemoID
|
||||
}
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: commentIDs})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to list comments: %v", err)), nil
|
||||
}
|
||||
|
||||
results := make([]memoJSON, 0, len(memos))
|
||||
for _, m := range memos {
|
||||
if checkMemoAccess(m, userID) == nil {
|
||||
results = append(results, storeMemoToJSON(m))
|
||||
}
|
||||
}
|
||||
out, err := marshalJSON(results)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
|
||||
func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID, err := extractUserID(ctx)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
uid, err := parseMemoUID(req.GetString("name", ""))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
content := req.GetString("content", "")
|
||||
if content == "" {
|
||||
return mcp.NewToolResultError("content is required"), nil
|
||||
}
|
||||
|
||||
parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
|
||||
}
|
||||
if parent == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(parent, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
comment, err := s.store.CreateMemo(ctx, &store.Memo{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: userID,
|
||||
Content: content,
|
||||
Visibility: parent.Visibility,
|
||||
Payload: buildPayload(content),
|
||||
ParentUID: &parent.UID,
|
||||
})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil
|
||||
}
|
||||
|
||||
if _, err = s.store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: comment.ID,
|
||||
RelatedMemoID: parent.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
}); err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeMemoToJSON(comment))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
68
server/router/mcp/tools_tag.go
Normal file
68
server/router/mcp/tools_tag.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
|
||||
mcpSrv.AddTool(mcp.NewTool("list_tags",
|
||||
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."),
|
||||
), s.handleListTags)
|
||||
}
|
||||
|
||||
type tagEntry struct {
|
||||
Tag string `json:"tag"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
rowStatus := store.Normal
|
||||
find := &store.FindMemo{
|
||||
ExcludeComments: true,
|
||||
ExcludeContent: true,
|
||||
RowStatus: &rowStatus,
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
|
||||
}
|
||||
|
||||
counts := make(map[string]int)
|
||||
for _, m := range memos {
|
||||
if m.Payload == nil {
|
||||
continue
|
||||
}
|
||||
for _, tag := range m.Payload.Tags {
|
||||
counts[tag]++
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]tagEntry, 0, len(counts))
|
||||
for tag, count := range counts {
|
||||
entries = append(entries, tagEntry{Tag: tag, Count: count})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
if entries[i].Count != entries[j].Count {
|
||||
return entries[i].Count > entries[j].Count
|
||||
}
|
||||
return entries[i].Tag < entries[j].Tag
|
||||
})
|
||||
|
||||
out, err := marshalJSON(entries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
423
server/router/rss/rss.go
Normal file
423
server/router/rss/rss.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package rss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/feeds"
|
||||
"github.com/labstack/echo/v5"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/plugin/markdown"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
maxRSSItemCount = 100
|
||||
defaultCacheDuration = 1 * time.Hour
|
||||
maxCacheSize = 50 // Maximum number of cached feeds
|
||||
)
|
||||
|
||||
var (
|
||||
// Regex to match markdown headings at the start of a line.
|
||||
markdownHeadingRegex = regexp.MustCompile(`^#{1,6}\s*`)
|
||||
)
|
||||
|
||||
// cacheEntry represents a cached RSS feed with expiration.
|
||||
type cacheEntry struct {
|
||||
content string
|
||||
etag string
|
||||
lastModified time.Time
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
type RSSService struct {
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
MarkdownService markdown.Service
|
||||
|
||||
// Cache for RSS feeds
|
||||
cache map[string]*cacheEntry
|
||||
cacheMutex sync.RWMutex
|
||||
}
|
||||
|
||||
type RSSHeading struct {
|
||||
Title string
|
||||
Description string
|
||||
Language string
|
||||
}
|
||||
|
||||
func NewRSSService(profile *profile.Profile, store *store.Store, markdownService markdown.Service) *RSSService {
|
||||
return &RSSService{
|
||||
Profile: profile,
|
||||
Store: store,
|
||||
MarkdownService: markdownService,
|
||||
cache: make(map[string]*cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RSSService) RegisterRoutes(g *echo.Group) {
|
||||
g.GET("/explore/rss.xml", s.GetExploreRSS)
|
||||
g.GET("/u/:username/rss.xml", s.GetUserRSS)
|
||||
}
|
||||
|
||||
func (s *RSSService) GetExploreRSS(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
cacheKey := "explore"
|
||||
|
||||
// Check cache first
|
||||
if cached := s.getFromCache(cacheKey); cached != nil {
|
||||
// Check ETag for conditional request
|
||||
if c.Request().Header.Get("If-None-Match") == cached.etag {
|
||||
return c.NoContent(http.StatusNotModified)
|
||||
}
|
||||
s.setRSSHeaders(c, cached.etag, cached.lastModified)
|
||||
return c.String(http.StatusOK, cached.content)
|
||||
}
|
||||
|
||||
normalStatus := store.Normal
|
||||
limit := maxRSSItemCount
|
||||
memoFind := store.FindMemo{
|
||||
RowStatus: &normalStatus,
|
||||
VisibilityList: []store.Visibility{store.Public},
|
||||
Limit: &limit,
|
||||
}
|
||||
memoList, err := s.Store.ListMemos(ctx, &memoFind)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").Wrap(err)
|
||||
}
|
||||
|
||||
baseURL := c.Scheme() + "://" + c.Request().Host
|
||||
rss, lastModified, err := s.generateRSSFromMemoList(ctx, memoList, baseURL, nil)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate rss").Wrap(err)
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
etag := s.putInCache(cacheKey, rss, lastModified)
|
||||
s.setRSSHeaders(c, etag, lastModified)
|
||||
return c.String(http.StatusOK, rss)
|
||||
}
|
||||
|
||||
func (s *RSSService) GetUserRSS(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
username := c.Param("username")
|
||||
cacheKey := "user:" + username
|
||||
|
||||
// Check cache first
|
||||
if cached := s.getFromCache(cacheKey); cached != nil {
|
||||
// Check ETag for conditional request
|
||||
if c.Request().Header.Get("If-None-Match") == cached.etag {
|
||||
return c.NoContent(http.StatusNotModified)
|
||||
}
|
||||
s.setRSSHeaders(c, cached.etag, cached.lastModified)
|
||||
return c.String(http.StatusOK, cached.content)
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").Wrap(err)
|
||||
}
|
||||
if user == nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "User not found")
|
||||
}
|
||||
|
||||
normalStatus := store.Normal
|
||||
limit := maxRSSItemCount
|
||||
memoFind := store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
RowStatus: &normalStatus,
|
||||
VisibilityList: []store.Visibility{store.Public},
|
||||
Limit: &limit,
|
||||
}
|
||||
memoList, err := s.Store.ListMemos(ctx, &memoFind)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").Wrap(err)
|
||||
}
|
||||
|
||||
baseURL := c.Scheme() + "://" + c.Request().Host
|
||||
rss, lastModified, err := s.generateRSSFromMemoList(ctx, memoList, baseURL, user)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate rss").Wrap(err)
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
etag := s.putInCache(cacheKey, rss, lastModified)
|
||||
s.setRSSHeaders(c, etag, lastModified)
|
||||
return c.String(http.StatusOK, rss)
|
||||
}
|
||||
|
||||
func (s *RSSService) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string, user *store.User) (string, time.Time, error) {
|
||||
rssHeading, err := getRSSHeading(ctx, s.Store)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
feed := &feeds.Feed{
|
||||
Title: rssHeading.Title,
|
||||
Link: &feeds.Link{Href: baseURL},
|
||||
Description: rssHeading.Description,
|
||||
Created: time.Now(),
|
||||
}
|
||||
|
||||
var itemCountLimit = min(len(memoList), maxRSSItemCount)
|
||||
if itemCountLimit == 0 {
|
||||
// Return empty feed if no memos
|
||||
rss, err := feed.ToRss()
|
||||
return rss, time.Time{}, err
|
||||
}
|
||||
|
||||
// Track the most recent update time for Last-Modified header
|
||||
var lastModified time.Time
|
||||
if len(memoList) > 0 {
|
||||
lastModified = time.Unix(memoList[0].UpdatedTs, 0)
|
||||
}
|
||||
|
||||
// Batch load all attachments for all memos to avoid N+1 query problem
|
||||
memoIDs := make([]int32, itemCountLimit)
|
||||
for i := 0; i < itemCountLimit; i++ {
|
||||
memoIDs[i] = memoList[i].ID
|
||||
}
|
||||
|
||||
allAttachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoIDList: memoIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return "", lastModified, err
|
||||
}
|
||||
|
||||
// Group attachments by memo ID for quick lookup
|
||||
attachmentsByMemoID := make(map[int32][]*store.Attachment)
|
||||
for _, attachment := range allAttachments {
|
||||
if attachment.MemoID != nil {
|
||||
attachmentsByMemoID[*attachment.MemoID] = append(attachmentsByMemoID[*attachment.MemoID], attachment)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch load all memo creators
|
||||
creatorMap := make(map[int32]*store.User)
|
||||
if user != nil {
|
||||
// Single user feed - reuse the user object
|
||||
creatorMap[user.ID] = user
|
||||
} else {
|
||||
// Multi-user feed - batch load all unique creators
|
||||
creatorIDs := make(map[int32]bool)
|
||||
for _, memo := range memoList[:itemCountLimit] {
|
||||
creatorIDs[memo.CreatorID] = true
|
||||
}
|
||||
|
||||
// Batch load all users with a single query by getting all users and filtering
|
||||
// Note: This is more efficient than N separate queries
|
||||
for creatorID := range creatorIDs {
|
||||
creator, err := s.Store.GetUser(ctx, &store.FindUser{ID: &creatorID})
|
||||
if err == nil && creator != nil {
|
||||
creatorMap[creatorID] = creator
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate feed items
|
||||
feed.Items = make([]*feeds.Item, itemCountLimit)
|
||||
for i := 0; i < itemCountLimit; i++ {
|
||||
memo := memoList[i]
|
||||
|
||||
// Generate item title from memo content
|
||||
title := s.generateItemTitle(memo.Content)
|
||||
|
||||
// Render content as HTML
|
||||
htmlContent, err := s.getRSSItemDescription(memo.Content)
|
||||
if err != nil {
|
||||
return "", lastModified, err
|
||||
}
|
||||
|
||||
link := &feeds.Link{Href: baseURL + "/memos/" + memo.UID}
|
||||
|
||||
item := &feeds.Item{
|
||||
Title: title,
|
||||
Link: link,
|
||||
Description: htmlContent, // Summary/excerpt
|
||||
Content: htmlContent, // Full content in content:encoded
|
||||
Created: time.Unix(memo.CreatedTs, 0),
|
||||
Updated: time.Unix(memo.UpdatedTs, 0),
|
||||
Id: link.Href,
|
||||
}
|
||||
|
||||
// Add author information
|
||||
if creator, ok := creatorMap[memo.CreatorID]; ok {
|
||||
authorName := creator.Nickname
|
||||
if authorName == "" {
|
||||
authorName = creator.Username
|
||||
}
|
||||
item.Author = &feeds.Author{
|
||||
Name: authorName,
|
||||
Email: creator.Email,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: gorilla/feeds doesn't support categories in RSS items
|
||||
// Tags could be added to the description or content if needed
|
||||
|
||||
// Add first attachment as enclosure
|
||||
if attachments, ok := attachmentsByMemoID[memo.ID]; ok && len(attachments) > 0 {
|
||||
attachment := attachments[0]
|
||||
enclosure := feeds.Enclosure{}
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
|
||||
enclosure.Url = attachment.Reference
|
||||
} else {
|
||||
enclosure.Url = fmt.Sprintf("%s/file/attachments/%s/%s", baseURL, attachment.UID, attachment.Filename)
|
||||
}
|
||||
enclosure.Length = strconv.Itoa(int(attachment.Size))
|
||||
enclosure.Type = attachment.Type
|
||||
item.Enclosure = &enclosure
|
||||
}
|
||||
|
||||
feed.Items[i] = item
|
||||
}
|
||||
|
||||
rss, err := feed.ToRss()
|
||||
if err != nil {
|
||||
return "", lastModified, err
|
||||
}
|
||||
return rss, lastModified, nil
|
||||
}
|
||||
|
||||
func (*RSSService) generateItemTitle(content string) string {
|
||||
// Extract first line as title
|
||||
lines := strings.Split(content, "\n")
|
||||
title := strings.TrimSpace(lines[0])
|
||||
|
||||
// Remove markdown heading syntax using regex (handles # to ###### with optional spaces)
|
||||
title = markdownHeadingRegex.ReplaceAllString(title, "")
|
||||
title = strings.TrimSpace(title)
|
||||
|
||||
// Limit title length
|
||||
const maxTitleLength = 100
|
||||
if len(title) > maxTitleLength {
|
||||
// Find last space before limit to avoid cutting words
|
||||
cutoff := maxTitleLength
|
||||
for i := min(maxTitleLength-1, len(title)-1); i > 0; i-- {
|
||||
if title[i] == ' ' {
|
||||
cutoff = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if cutoff < maxTitleLength {
|
||||
title = title[:cutoff] + "..."
|
||||
} else {
|
||||
// No space found, just truncate
|
||||
title = title[:maxTitleLength] + "..."
|
||||
}
|
||||
}
|
||||
|
||||
// If title is empty, use a default
|
||||
if title == "" {
|
||||
title = "Memo"
|
||||
}
|
||||
|
||||
return title
|
||||
}
|
||||
|
||||
func (s *RSSService) getRSSItemDescription(content string) (string, error) {
|
||||
html, err := s.MarkdownService.RenderHTML([]byte(content))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return html, nil
|
||||
}
|
||||
|
||||
// getFromCache retrieves a cached feed entry if it exists and is not expired.
|
||||
func (s *RSSService) getFromCache(key string) *cacheEntry {
|
||||
s.cacheMutex.RLock()
|
||||
entry, exists := s.cache[key]
|
||||
s.cacheMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if cache entry is still valid
|
||||
if time.Since(entry.createdAt) > defaultCacheDuration {
|
||||
// Entry is expired, remove it
|
||||
s.cacheMutex.Lock()
|
||||
delete(s.cache, key)
|
||||
s.cacheMutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
// putInCache stores a feed in the cache and returns its ETag.
|
||||
func (s *RSSService) putInCache(key, content string, lastModified time.Time) string {
|
||||
s.cacheMutex.Lock()
|
||||
defer s.cacheMutex.Unlock()
|
||||
|
||||
// Generate ETag from content hash
|
||||
hash := sha256.Sum256([]byte(content))
|
||||
etag := fmt.Sprintf(`"%x"`, hash[:8])
|
||||
|
||||
// Implement simple LRU: if cache is too large, remove oldest entries
|
||||
if len(s.cache) >= maxCacheSize {
|
||||
var oldestKey string
|
||||
var oldestTime time.Time
|
||||
for k, v := range s.cache {
|
||||
if oldestKey == "" || v.createdAt.Before(oldestTime) {
|
||||
oldestKey = k
|
||||
oldestTime = v.createdAt
|
||||
}
|
||||
}
|
||||
if oldestKey != "" {
|
||||
delete(s.cache, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
s.cache[key] = &cacheEntry{
|
||||
content: content,
|
||||
etag: etag,
|
||||
lastModified: lastModified,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
return etag
|
||||
}
|
||||
|
||||
// setRSSHeaders sets appropriate HTTP headers for RSS responses.
|
||||
func (*RSSService) setRSSHeaders(c *echo.Context, etag string, lastModified time.Time) {
|
||||
c.Response().Header().Set(echo.HeaderContentType, "application/rss+xml; charset=utf-8")
|
||||
c.Response().Header().Set(echo.HeaderCacheControl, fmt.Sprintf("public, max-age=%d", int(defaultCacheDuration.Seconds())))
|
||||
c.Response().Header().Set("ETag", etag)
|
||||
if !lastModified.IsZero() {
|
||||
c.Response().Header().Set("Last-Modified", lastModified.UTC().Format(http.TimeFormat))
|
||||
}
|
||||
}
|
||||
|
||||
func getRSSHeading(ctx context.Context, stores *store.Store) (RSSHeading, error) {
|
||||
settings, err := stores.GetInstanceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return RSSHeading{}, err
|
||||
}
|
||||
if settings == nil || settings.CustomProfile == nil {
|
||||
return RSSHeading{
|
||||
Title: "Memos",
|
||||
Description: "An open source, lightweight note-taking service. Easily capture and share your great thoughts.",
|
||||
Language: "en-us",
|
||||
}, nil
|
||||
}
|
||||
customProfile := settings.CustomProfile
|
||||
|
||||
return RSSHeading{
|
||||
Title: customProfile.Title,
|
||||
Description: customProfile.Description,
|
||||
Language: "en-us",
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user