first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled

This commit is contained in:
2026-03-04 06:30:47 +00:00
commit bb402d4ccc
777 changed files with 135661 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
package v1
// PublicMethods defines API endpoints that don't require authentication.
// All other endpoints require a valid session or access token.
//
// This is the SINGLE SOURCE OF TRUTH for public endpoints.
// Both Connect interceptor and gRPC-Gateway interceptor use this map.
//
// Format: Full gRPC procedure path as returned by req.Spec().Procedure (Connect)
// or info.FullMethod (gRPC interceptor).
var PublicMethods = map[string]struct{}{
// Auth Service - login/token endpoints must be accessible without auth
"/memos.api.v1.AuthService/SignIn": {},
"/memos.api.v1.AuthService/RefreshToken": {}, // Token refresh uses cookie, must be accessible when access token expired
// Instance Service - needed before login to show instance info
"/memos.api.v1.InstanceService/GetInstanceProfile": {},
"/memos.api.v1.InstanceService/GetInstanceSetting": {},
// User Service - public user profiles and stats
"/memos.api.v1.UserService/CreateUser": {}, // Allow first user registration
"/memos.api.v1.UserService/GetUser": {},
"/memos.api.v1.UserService/GetUserAvatar": {},
"/memos.api.v1.UserService/GetUserStats": {},
"/memos.api.v1.UserService/ListAllUserStats": {},
"/memos.api.v1.UserService/SearchUsers": {},
// Identity Provider Service - SSO buttons on login page
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": {},
// Memo Service - public memos (visibility filtering done in service layer)
"/memos.api.v1.MemoService/GetMemo": {},
"/memos.api.v1.MemoService/ListMemos": {},
"/memos.api.v1.MemoService/ListMemoComments": {},
}
// IsPublicMethod checks if a procedure path is public (no authentication required).
// Returns true for public methods, false for protected methods.
func IsPublicMethod(procedure string) bool {
_, ok := PublicMethods[procedure]
return ok
}

View File

@@ -0,0 +1,88 @@
package v1
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestPublicMethodsArePublic verifies that methods in PublicMethods are recognized as public.
func TestPublicMethodsArePublic(t *testing.T) {
publicMethods := []string{
// Auth Service
"/memos.api.v1.AuthService/SignIn",
"/memos.api.v1.AuthService/RefreshToken",
// Instance Service
"/memos.api.v1.InstanceService/GetInstanceProfile",
"/memos.api.v1.InstanceService/GetInstanceSetting",
// User Service
"/memos.api.v1.UserService/CreateUser",
"/memos.api.v1.UserService/GetUser",
"/memos.api.v1.UserService/GetUserAvatar",
"/memos.api.v1.UserService/GetUserStats",
"/memos.api.v1.UserService/ListAllUserStats",
"/memos.api.v1.UserService/SearchUsers",
// Identity Provider Service
"/memos.api.v1.IdentityProviderService/ListIdentityProviders",
// Memo Service
"/memos.api.v1.MemoService/GetMemo",
"/memos.api.v1.MemoService/ListMemos",
}
for _, method := range publicMethods {
t.Run(method, func(t *testing.T) {
assert.True(t, IsPublicMethod(method), "Expected %s to be public", method)
})
}
}
// TestProtectedMethodsRequireAuth verifies that non-public methods are recognized as protected.
func TestProtectedMethodsRequireAuth(t *testing.T) {
protectedMethods := []string{
// Auth Service - logout and get current user require auth
"/memos.api.v1.AuthService/SignOut",
"/memos.api.v1.AuthService/GetCurrentUser",
// Instance Service - admin operations
"/memos.api.v1.InstanceService/UpdateInstanceSetting",
// User Service - modification operations
"/memos.api.v1.UserService/ListUsers",
"/memos.api.v1.UserService/UpdateUser",
"/memos.api.v1.UserService/DeleteUser",
// Memo Service - write operations
"/memos.api.v1.MemoService/CreateMemo",
"/memos.api.v1.MemoService/UpdateMemo",
"/memos.api.v1.MemoService/DeleteMemo",
// Attachment Service - write operations
"/memos.api.v1.AttachmentService/CreateAttachment",
"/memos.api.v1.AttachmentService/DeleteAttachment",
// Shortcut Service
"/memos.api.v1.ShortcutService/CreateShortcut",
"/memos.api.v1.ShortcutService/ListShortcuts",
"/memos.api.v1.ShortcutService/UpdateShortcut",
"/memos.api.v1.ShortcutService/DeleteShortcut",
// Activity Service
"/memos.api.v1.ActivityService/GetActivity",
}
for _, method := range protectedMethods {
t.Run(method, func(t *testing.T) {
assert.False(t, IsPublicMethod(method), "Expected %s to require auth", method)
})
}
}
// TestUnknownMethodsRequireAuth verifies that unknown methods default to requiring auth.
func TestUnknownMethodsRequireAuth(t *testing.T) {
unknownMethods := []string{
"/unknown.Service/Method",
"/memos.api.v1.UnknownService/Method",
"",
"invalid",
}
for _, method := range unknownMethods {
t.Run(method, func(t *testing.T) {
assert.False(t, IsPublicMethod(method), "Unknown method %q should require auth", method)
})
}
}

View File

@@ -0,0 +1,155 @@
package v1
import (
"context"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListActivities(ctx context.Context, request *v1pb.ListActivitiesRequest) (*v1pb.ListActivitiesResponse, error) {
// Set default page size if not specified
pageSize := request.PageSize
if pageSize <= 0 || pageSize > 1000 {
pageSize = 100
}
// TODO: Implement pagination with page_token and use pageSize for limiting
// For now, we'll fetch all activities and the pageSize will be used in future pagination implementation
_ = pageSize // Acknowledge pageSize variable to avoid linter warning
activities, err := s.Store.ListActivities(ctx, &store.FindActivity{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list activities: %v", err)
}
var activityMessages []*v1pb.Activity
for _, activity := range activities {
activityMessage, err := s.convertActivityFromStore(ctx, activity)
if err != nil {
// Skip activities that reference deleted memos instead of failing the entire list
continue
}
if activityMessage != nil {
activityMessages = append(activityMessages, activityMessage)
}
}
return &v1pb.ListActivitiesResponse{
Activities: activityMessages,
// TODO: Implement next_page_token for pagination
}, nil
}
func (s *APIV1Service) GetActivity(ctx context.Context, request *v1pb.GetActivityRequest) (*v1pb.Activity, error) {
activityID, err := ExtractActivityIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid activity name: %v", err)
}
activity, err := s.Store.GetActivity(ctx, &store.FindActivity{
ID: &activityID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get activity: %v", err)
}
activityMessage, err := s.convertActivityFromStore(ctx, activity)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
}
if activityMessage == nil {
return nil, status.Errorf(codes.NotFound, "activity references deleted content")
}
return activityMessage, nil
}
// convertActivityFromStore converts a storage-layer activity to an API activity.
// This handles the mapping between internal activity representation and the public API,
// including proper type and level conversions.
// Returns nil if the activity references deleted content (to allow graceful skipping).
func (s *APIV1Service) convertActivityFromStore(ctx context.Context, activity *store.Activity) (*v1pb.Activity, error) {
payload, err := s.convertActivityPayloadFromStore(ctx, activity.Payload)
if err != nil {
return nil, err
}
// Skip activities that reference deleted memos
if payload == nil {
return nil, nil
}
// Convert store activity type to proto enum
var activityType v1pb.Activity_Type
switch activity.Type {
case store.ActivityTypeMemoComment:
activityType = v1pb.Activity_MEMO_COMMENT
default:
activityType = v1pb.Activity_TYPE_UNSPECIFIED
}
// Convert store activity level to proto enum
var activityLevel v1pb.Activity_Level
switch activity.Level {
case store.ActivityLevelInfo:
activityLevel = v1pb.Activity_INFO
default:
activityLevel = v1pb.Activity_LEVEL_UNSPECIFIED
}
return &v1pb.Activity{
Name: fmt.Sprintf("%s%d", ActivityNamePrefix, activity.ID),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, activity.CreatorID),
Type: activityType,
Level: activityLevel,
CreateTime: timestamppb.New(time.Unix(activity.CreatedTs, 0)),
Payload: payload,
}, nil
}
// convertActivityPayloadFromStore converts a storage-layer activity payload to an API payload.
// This resolves references (e.g., memo IDs) to resource names for the API.
// Returns nil if the activity references deleted content (to allow graceful skipping).
func (s *APIV1Service) convertActivityPayloadFromStore(ctx context.Context, payload *storepb.ActivityPayload) (*v1pb.ActivityPayload, error) {
v2Payload := &v1pb.ActivityPayload{}
if payload.MemoComment != nil {
// Fetch the comment memo
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &payload.MemoComment.MemoId,
ExcludeContent: true,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
// If the comment memo was deleted, skip this activity gracefully
if memo == nil {
return nil, nil
}
// Fetch the related memo (the one being commented on)
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &payload.MemoComment.RelatedMemoId,
ExcludeContent: true,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get related memo: %v", err)
}
// If the related memo was deleted, skip this activity gracefully
if relatedMemo == nil {
return nil, nil
}
v2Payload.Payload = &v1pb.ActivityPayload_MemoComment{
MemoComment: &v1pb.ActivityMemoCommentPayload{
Memo: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
RelatedMemo: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
},
}
}
return v2Payload, nil
}

View File

@@ -0,0 +1,397 @@
package v1
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/labstack/echo/v5"
groqai "github.com/usememos/memos/plugin/ai/groq"
ollamaai "github.com/usememos/memos/plugin/ai/ollama"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
const aiSettingsStoreKey = "ai_settings_v1"
// RegisterAIHTTPHandlers registers plain HTTP JSON handlers for the AI service.
// Uses direct raw DB storage to avoid proto generation dependency.
func (s *APIV1Service) RegisterAIHTTPHandlers(g *echo.Group) {
g.POST("/api/ai/settings", s.handleGetAISettings)
g.POST("/api/ai/settings/update", s.handleUpdateAISettings)
g.POST("/api/ai/complete", s.handleGenerateCompletion)
g.POST("/api/ai/test", s.handleTestAIProvider)
g.POST("/api/ai/models/ollama", s.handleListOllamaModels)
}
type aiSettingsJSON struct {
Groq *groqSettingsJSON `json:"groq,omitempty"`
Ollama *ollamaSettingsJSON `json:"ollama,omitempty"`
}
type groqSettingsJSON struct {
Enabled bool `json:"enabled"`
APIKey string `json:"apiKey"`
DefaultModel string `json:"defaultModel"`
}
type ollamaSettingsJSON struct {
Enabled bool `json:"enabled"`
Host string `json:"host"`
DefaultModel string `json:"defaultModel"`
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
}
func defaultAISettings() *aiSettingsJSON {
return &aiSettingsJSON{
Groq: &groqSettingsJSON{DefaultModel: "llama-3.1-8b-instant"},
Ollama: &ollamaSettingsJSON{Host: "http://localhost:11434", DefaultModel: "llama3", TimeoutSeconds: 120},
}
}
func (s *APIV1Service) loadAISettings(ctx context.Context) (*aiSettingsJSON, error) {
settings := defaultAISettings()
list, err := s.Store.GetDriver().ListInstanceSettings(ctx, &store.FindInstanceSetting{Name: aiSettingsStoreKey})
if err != nil || len(list) == 0 {
return settings, nil
}
_ = json.Unmarshal([]byte(list[0].Value), settings)
return settings, nil
}
func (s *APIV1Service) saveAISettings(ctx context.Context, settings *aiSettingsJSON) error {
b, err := json.Marshal(settings)
if err != nil {
return err
}
_, err = s.Store.GetDriver().UpsertInstanceSetting(ctx, &store.InstanceSetting{
Name: aiSettingsStoreKey,
Value: string(b),
})
return err
}
func (s *APIV1Service) handleGetAISettings(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
settings, err := s.loadAISettings(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, settings)
}
func (s *APIV1Service) handleUpdateAISettings(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
var req aiSettingsJSON
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
}
if err := s.saveAISettings(ctx, &req); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, req)
}
type generateCompletionRequest struct {
Provider int `json:"provider"`
Prompt string `json:"prompt"`
Model string `json:"model"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"maxTokens"`
AutoTag bool `json:"autoTag,omitempty"`
SpellCheck bool `json:"spellCheck,omitempty"`
}
type generateCompletionResponse struct {
Text string `json:"text"`
ModelUsed string `json:"modelUsed"`
PromptTokens int `json:"promptTokens"`
CompletionTokens int `json:"completionTokens"`
TotalTokens int `json:"totalTokens"`
}
func (s *APIV1Service) handleGenerateCompletion(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
var req generateCompletionRequest
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
}
if req.Prompt == "" {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "prompt is required"})
}
settings, err := s.loadAISettings(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
temp := float32(0.7)
if req.Temperature > 0 {
temp = float32(req.Temperature)
}
maxTokens := 1000
if req.MaxTokens > 0 {
maxTokens = req.MaxTokens
}
switch req.Provider {
case 1: // Groq
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Groq is not enabled or API key is missing. Configure it in Settings > AI."})
}
model := req.Model
if model == "" {
model = settings.Groq.DefaultModel
}
if model == "" {
model = "llama-3.1-8b-instant"
}
client := groqai.NewGroqClient(groqai.GroqConfig{
APIKey: settings.Groq.APIKey,
DefaultModel: model,
})
resp, err := client.GenerateCompletion(ctx, groqai.CompletionRequest{
Model: model,
Prompt: req.Prompt,
Temperature: temp,
MaxTokens: maxTokens,
})
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
// Apply post-processing if requested
text := resp.Text
if req.SpellCheck {
text = s.correctSpelling(ctx, text)
}
if req.AutoTag {
text = s.addAutoTags(ctx, text)
}
return c.JSON(http.StatusOK, generateCompletionResponse{
Text: text, ModelUsed: resp.Model,
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
})
case 2: // Ollama
if settings.Ollama == nil || !settings.Ollama.Enabled {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
}
model := req.Model
if model == "" {
model = settings.Ollama.DefaultModel
}
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
Host: settings.Ollama.Host,
DefaultModel: model,
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
})
resp, err := client.GenerateCompletion(ctx, ollamaai.CompletionRequest{
Model: model,
Prompt: req.Prompt,
Temperature: temp,
MaxTokens: maxTokens,
})
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
// Apply post-processing if requested
text := resp.Text
if req.SpellCheck {
text = s.correctSpelling(ctx, text)
}
if req.AutoTag {
text = s.addAutoTags(ctx, text)
}
return c.JSON(http.StatusOK, generateCompletionResponse{
Text: text, ModelUsed: resp.Model,
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
})
default:
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider. Use 1 for Groq, 2 for Ollama"})
}
}
func (s *APIV1Service) handleTestAIProvider(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
var req struct {
Provider int `json:"provider"`
}
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
}
settings, err := s.loadAISettings(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
switch req.Provider {
case 1:
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Groq not enabled or API key missing"})
}
client := groqai.NewGroqClient(groqai.GroqConfig{APIKey: settings.Groq.APIKey, DefaultModel: settings.Groq.DefaultModel})
if err := client.TestConnection(ctx); err != nil {
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
}
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Groq connected successfully"})
case 2:
if settings.Ollama == nil || !settings.Ollama.Enabled {
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Ollama not enabled"})
}
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
Host: settings.Ollama.Host,
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
})
if err := client.TestConnection(ctx); err != nil {
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
}
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Ollama connected successfully"})
default:
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider"})
}
}
// correctSpelling applies basic spelling corrections to the text
func (s *APIV1Service) correctSpelling(ctx context.Context, text string) string {
// TODO: Implement actual spell checking using a spell checker library
// For now, return the text as-is
return text
}
// addAutoTags analyzes the text and adds relevant #tags
func (s *APIV1Service) addAutoTags(ctx context.Context, text string) string {
// Common programming languages and technologies
tagMap := map[string][]string{
"javascript": {"javascript", "js"},
"python": {"python", "py"},
"java": {"java"},
"go": {"go", "golang"},
"rust": {"rust"},
"dart": {"dart"},
"flutter": {"flutter"},
"react": {"react", "reactjs"},
"vue": {"vue", "vuejs"},
"angular": {"angular"},
"node": {"nodejs", "node"},
"express": {"express"},
"mongodb": {"mongodb"},
"postgresql": {"postgresql", "postgres"},
"mysql": {"mysql"},
"redis": {"redis"},
"docker": {"docker"},
"kubernetes": {"kubernetes", "k8s"},
"aws": {"aws", "amazon"},
"azure": {"azure"},
"gcp": {"gcp", "google-cloud"},
"git": {"git"},
"github": {"github"},
"gitlab": {"gitlab"},
"ci/cd": {"ci-cd"},
"testing": {"testing"},
"api": {"api"},
"rest": {"rest", "api"},
"graphql": {"graphql"},
"database": {"database", "db"},
"frontend": {"frontend"},
"backend": {"backend"},
"fullstack": {"fullstack"},
"mobile": {"mobile"},
"web": {"web"},
"devops": {"devops"},
"security": {"security"},
"machine learning": {"ml", "machine-learning"},
"ai": {"ai", "artificial-intelligence"},
"blockchain": {"blockchain"},
"crypto": {"crypto", "cryptocurrency"},
}
// Convert text to lowercase for matching
lowerText := strings.ToLower(text)
var tags []string
tagSet := make(map[string]bool)
// Check for programming languages and technologies
for keyword, tagList := range tagMap {
if strings.Contains(lowerText, keyword) {
for _, tag := range tagList {
if !tagSet[tag] {
tags = append(tags, tag)
tagSet[tag] = true
}
}
}
}
// Add tags to the end of the text if any were found
if len(tags) > 0 {
tagString := "\n\n"
for i, tag := range tags {
if i > 0 {
tagString += " "
}
tagString += "#" + tag
}
text += tagString
}
return text
}
func (s *APIV1Service) handleListOllamaModels(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
settings, err := s.loadAISettings(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if settings.Ollama == nil || !settings.Ollama.Enabled {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
}
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
Host: settings.Ollama.Host,
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
})
models, err := client.ListModels(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": fmt.Sprintf("failed to list Ollama models: %v", err)})
}
// Convert to the same format as frontend expects
var modelList []map[string]string
for _, model := range models {
modelList = append(modelList, map[string]string{
"id": model,
"name": model,
})
}
return c.JSON(http.StatusOK, map[string]interface{}{
"models": modelList,
})
}

View File

@@ -0,0 +1,191 @@
package v1
import (
"bytes"
"image"
"image/color"
"image/jpeg"
"testing"
"github.com/disintegration/imaging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShouldStripExif(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mimeType string
expected bool
}{
{
name: "JPEG should strip EXIF",
mimeType: "image/jpeg",
expected: true,
},
{
name: "JPG should strip EXIF",
mimeType: "image/jpg",
expected: true,
},
{
name: "TIFF should strip EXIF",
mimeType: "image/tiff",
expected: true,
},
{
name: "WebP should strip EXIF",
mimeType: "image/webp",
expected: true,
},
{
name: "HEIC should strip EXIF",
mimeType: "image/heic",
expected: true,
},
{
name: "HEIF should strip EXIF",
mimeType: "image/heif",
expected: true,
},
{
name: "PNG should not strip EXIF",
mimeType: "image/png",
expected: false,
},
{
name: "GIF should not strip EXIF",
mimeType: "image/gif",
expected: false,
},
{
name: "text file should not strip EXIF",
mimeType: "text/plain",
expected: false,
},
{
name: "PDF should not strip EXIF",
mimeType: "application/pdf",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := shouldStripExif(tt.mimeType)
assert.Equal(t, tt.expected, result)
})
}
}
func TestStripImageExif(t *testing.T) {
t.Parallel()
// Create a simple test image
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
// Fill with red color
for y := 0; y < 100; y++ {
for x := 0; x < 100; x++ {
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
}
}
// Encode as JPEG
var buf bytes.Buffer
err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90})
require.NoError(t, err)
originalData := buf.Bytes()
t.Run("strip JPEG metadata", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/jpeg")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.Equal(t, 100, decodedImg.Bounds().Dx())
assert.Equal(t, 100, decodedImg.Bounds().Dy())
})
t.Run("strip JPG metadata (alternate extension)", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/jpg")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("strip PNG metadata", func(t *testing.T) {
t.Parallel()
// Encode as PNG first
var pngBuf bytes.Buffer
err := imaging.Encode(&pngBuf, img, imaging.PNG)
require.NoError(t, err)
strippedData, err := stripImageExif(pngBuf.Bytes(), "image/png")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.Equal(t, 100, decodedImg.Bounds().Dx())
assert.Equal(t, 100, decodedImg.Bounds().Dy())
})
t.Run("handle WebP format by converting to JPEG", func(t *testing.T) {
t.Parallel()
// WebP format will be converted to JPEG
strippedData, err := stripImageExif(originalData, "image/webp")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("handle HEIC format by converting to JPEG", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/heic")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("return error for invalid image data", func(t *testing.T) {
t.Parallel()
invalidData := []byte("not an image")
_, err := stripImageExif(invalidData, "image/jpeg")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode image")
})
t.Run("return error for empty image data", func(t *testing.T) {
t.Parallel()
emptyData := []byte{}
_, err := stripImageExif(emptyData, "image/jpeg")
assert.Error(t, err)
})
}

View File

@@ -0,0 +1,646 @@
package v1
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/disintegration/imaging"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/filter"
"github.com/usememos/memos/plugin/storage/s3"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
// The upload memory buffer is 32 MiB.
// It should be kept low, so RAM usage doesn't get out of control.
// This is unrelated to maximum upload size limit, which is now set through system setting.
MaxUploadBufferSizeBytes = 32 << 20
MebiByte = 1024 * 1024
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
ThumbnailCacheFolder = ".thumbnail_cache"
// defaultJPEGQuality is the JPEG quality used when re-encoding images for EXIF stripping.
// Quality 95 maintains visual quality while ensuring metadata is removed.
defaultJPEGQuality = 95
)
var SupportedThumbnailMimeTypes = []string{
"image/png",
"image/jpeg",
}
// exifCapableImageTypes defines image formats that may contain EXIF metadata.
// These formats will have their EXIF metadata stripped on upload for privacy.
var exifCapableImageTypes = map[string]bool{
"image/jpeg": true,
"image/jpg": true,
"image/tiff": true,
"image/webp": true,
"image/heic": true,
"image/heif": true,
}
func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Validate required fields
if request.Attachment == nil {
return nil, status.Errorf(codes.InvalidArgument, "attachment is required")
}
if request.Attachment.Filename == "" {
return nil, status.Errorf(codes.InvalidArgument, "filename is required")
}
if !validateFilename(request.Attachment.Filename) {
return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format")
}
if request.Attachment.Type == "" {
ext := filepath.Ext(request.Attachment.Filename)
mimeType := mime.TypeByExtension(ext)
if mimeType == "" {
mimeType = http.DetectContentType(request.Attachment.Content)
}
// ParseMediaType to strip parameters
mediaType, _, err := mime.ParseMediaType(mimeType)
if err == nil {
request.Attachment.Type = mediaType
}
}
if request.Attachment.Type == "" {
request.Attachment.Type = "application/octet-stream"
}
if !isValidMimeType(request.Attachment.Type) {
return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format")
}
// Use provided attachment_id or generate a new one
attachmentUID := request.AttachmentId
if attachmentUID == "" {
attachmentUID = shortuuid.New()
}
create := &store.Attachment{
UID: attachmentUID,
CreatorID: user.ID,
Filename: request.Attachment.Filename,
Type: request.Attachment.Type,
}
// No upload size limit - accept files of any size
size := binary.Size(request.Attachment.Content)
create.Size = int64(size)
create.Blob = request.Attachment.Content
// Strip EXIF metadata from images for privacy protection.
// This removes sensitive information like GPS location, device details, etc.
if shouldStripExif(create.Type) {
if strippedBlob, err := stripImageExif(create.Blob, create.Type); err != nil {
// Log warning but continue with original image to ensure uploads don't fail.
slog.Warn("failed to strip EXIF metadata from image",
slog.String("type", create.Type),
slog.String("filename", create.Filename),
slog.String("error", err.Error()))
} else {
create.Blob = strippedBlob
create.Size = int64(len(strippedBlob))
}
}
if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil {
return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err)
}
if request.Attachment.Memo != nil {
memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo)
}
create.MemoID = &memo.ID
}
attachment, err := s.Store.CreateAttachment(ctx, create)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err)
}
return convertAttachmentFromStore(attachment), nil
}
func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Set default page size
pageSize := int(request.PageSize)
if pageSize <= 0 {
pageSize = 50
}
if pageSize > 1000 {
pageSize = 1000
}
// Parse page token for offset
offset := 0
if request.PageToken != "" {
// Simple implementation: page token is the offset as string
// In production, you might want to use encrypted tokens
if parsed, err := fmt.Sscanf(request.PageToken, "%d", &offset); err != nil || parsed != 1 {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token")
}
}
findAttachment := &store.FindAttachment{
CreatorID: &user.ID,
Limit: &pageSize,
Offset: &offset,
}
// Parse filter if provided
if request.Filter != "" {
if err := s.validateAttachmentFilter(ctx, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
findAttachment.Filters = append(findAttachment.Filters, request.Filter)
}
attachments, err := s.Store.ListAttachments(ctx, findAttachment)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
}
response := &v1pb.ListAttachmentsResponse{}
for _, attachment := range attachments {
response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment))
}
// For simplicity, set total size to the number of returned attachments.
// In a full implementation, you'd want a separate count query
response.TotalSize = int32(len(response.Attachments))
// Set next page token if we got the full page size (indicating there might be more)
if len(attachments) == pageSize {
response.NextPageToken = fmt.Sprintf("%d", offset+pageSize)
}
return response, nil
}
func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttachmentRequest) (*v1pb.Attachment, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Check access permission based on linked memo visibility.
if err := s.checkAttachmentAccess(ctx, attachment); err != nil {
return nil, err
}
return convertAttachmentFromStore(attachment), nil
}
func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.UpdateAttachmentRequest) (*v1pb.Attachment, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Attachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Only the creator or admin can update the attachment.
if attachment.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
currentTs := time.Now().Unix()
update := &store.UpdateAttachment{
ID: attachment.ID,
UpdatedTs: &currentTs,
}
for _, field := range request.UpdateMask.Paths {
if field == "filename" {
if !validateFilename(request.Attachment.Filename) {
return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format")
}
update.Filename = &request.Attachment.Filename
}
}
if err := s.Store.UpdateAttachment(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
}
return s.GetAttachment(ctx, &v1pb.GetAttachmentRequest{
Name: request.Attachment.Name,
})
}
func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.DeleteAttachmentRequest) (*emptypb.Empty, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
UID: &attachmentUID,
CreatorID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Delete the attachment from the database.
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
ID: attachment.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment: %v", err)
}
return &emptypb.Empty{}, nil
}
func convertAttachmentFromStore(attachment *store.Attachment) *v1pb.Attachment {
attachmentMessage := &v1pb.Attachment{
Name: fmt.Sprintf("%s%s", AttachmentNamePrefix, attachment.UID),
CreateTime: timestamppb.New(time.Unix(attachment.CreatedTs, 0)),
Filename: attachment.Filename,
Type: attachment.Type,
Size: attachment.Size,
}
if attachment.MemoUID != nil && *attachment.MemoUID != "" {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, *attachment.MemoUID)
attachmentMessage.Memo = &memoName
}
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
attachmentMessage.ExternalLink = attachment.Reference
}
return attachmentMessage
}
// SaveAttachmentBlob save the blob of attachment based on the storage config.
func SaveAttachmentBlob(ctx context.Context, profile *profile.Profile, stores *store.Store, create *store.Attachment) error {
instanceStorageSetting, err := stores.GetInstanceStorageSetting(ctx)
if err != nil {
return errors.Wrap(err, "Failed to find instance storage setting")
}
if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_LOCAL {
filepathTemplate := "assets/{timestamp}_{filename}"
if instanceStorageSetting.FilepathTemplate != "" {
filepathTemplate = instanceStorageSetting.FilepathTemplate
}
internalPath := filepathTemplate
if !strings.Contains(internalPath, "{filename}") {
internalPath = filepath.Join(internalPath, "{filename}")
}
internalPath = replaceFilenameWithPathTemplate(internalPath, create.Filename)
internalPath = filepath.ToSlash(internalPath)
// Ensure the directory exists.
osPath := filepath.FromSlash(internalPath)
if !filepath.IsAbs(osPath) {
osPath = filepath.Join(profile.Data, osPath)
}
dir := filepath.Dir(osPath)
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
return errors.Wrap(err, "Failed to create directory")
}
// Write the blob to the file.
if err := os.WriteFile(osPath, create.Blob, 0644); err != nil {
return errors.Wrap(err, "Failed to write file")
}
create.Reference = internalPath
create.Blob = nil
create.StorageType = storepb.AttachmentStorageType_LOCAL
} else if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_S3 {
s3Config := instanceStorageSetting.S3Config
if s3Config == nil {
return errors.Errorf("No activated external storage found")
}
s3Client, err := s3.NewClient(ctx, s3Config)
if err != nil {
return errors.Wrap(err, "Failed to create s3 client")
}
filepathTemplate := instanceStorageSetting.FilepathTemplate
if !strings.Contains(filepathTemplate, "{filename}") {
filepathTemplate = filepath.Join(filepathTemplate, "{filename}")
}
filepathTemplate = replaceFilenameWithPathTemplate(filepathTemplate, create.Filename)
key, err := s3Client.UploadObject(ctx, filepathTemplate, create.Type, bytes.NewReader(create.Blob))
if err != nil {
return errors.Wrap(err, "Failed to upload via s3 client")
}
presignURL, err := s3Client.PresignGetObject(ctx, key)
if err != nil {
return errors.Wrap(err, "Failed to presign via s3 client")
}
create.Reference = presignURL
create.Blob = nil
create.StorageType = storepb.AttachmentStorageType_S3
create.Payload = &storepb.AttachmentPayload{
Payload: &storepb.AttachmentPayload_S3Object_{
S3Object: &storepb.AttachmentPayload_S3Object{
S3Config: s3Config,
Key: key,
LastPresignedTime: timestamppb.New(time.Now()),
},
},
}
}
return nil
}
func (s *APIV1Service) GetAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
// For local storage, read the file from the local disk.
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
attachmentPath := filepath.FromSlash(attachment.Reference)
if !filepath.IsAbs(attachmentPath) {
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
}
file, err := os.Open(attachmentPath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open the file")
}
defer file.Close()
blob, err := io.ReadAll(file)
if err != nil {
return nil, errors.Wrap(err, "failed to read the file")
}
return blob, nil
}
// For S3 storage, download the file from S3.
if attachment.StorageType == storepb.AttachmentStorageType_S3 {
if attachment.Payload == nil {
return nil, errors.New("attachment payload is missing")
}
s3Object := attachment.Payload.GetS3Object()
if s3Object == nil {
return nil, errors.New("S3 object payload is missing")
}
if s3Object.S3Config == nil {
return nil, errors.New("S3 config is missing")
}
if s3Object.Key == "" {
return nil, errors.New("S3 object key is missing")
}
s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config)
if err != nil {
return nil, errors.Wrap(err, "failed to create S3 client")
}
blob, err := s3Client.GetObject(context.Background(), s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to get object from S3")
}
return blob, nil
}
// For database storage, return the blob from the database.
return attachment.Blob, nil
}
var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`)
func replaceFilenameWithPathTemplate(path, filename string) string {
t := time.Now()
path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string {
switch s {
case "{filename}":
return filename
case "{timestamp}":
return fmt.Sprintf("%d", t.Unix())
case "{year}":
return fmt.Sprintf("%d", t.Year())
case "{month}":
return fmt.Sprintf("%02d", t.Month())
case "{day}":
return fmt.Sprintf("%02d", t.Day())
case "{hour}":
return fmt.Sprintf("%02d", t.Hour())
case "{minute}":
return fmt.Sprintf("%02d", t.Minute())
case "{second}":
return fmt.Sprintf("%02d", t.Second())
case "{uuid}":
return util.GenUUID()
default:
return s
}
})
return path
}
func validateFilename(filename string) bool {
// Reject path traversal attempts and make sure no additional directories are created
if !filepath.IsLocal(filename) || strings.ContainsAny(filename, "/\\") {
return false
}
// Reject filenames starting or ending with spaces or periods
if strings.HasPrefix(filename, " ") || strings.HasSuffix(filename, " ") ||
strings.HasPrefix(filename, ".") || strings.HasSuffix(filename, ".") {
return false
}
return true
}
func isValidMimeType(mimeType string) bool {
// Reject empty or excessively long MIME types
if mimeType == "" || len(mimeType) > 255 {
return false
}
// MIME type must match the pattern: type/subtype
// Allow common characters in MIME types per RFC 2045
matched, _ := regexp.MatchString(`^[a-zA-Z0-9][a-zA-Z0-9!#$&^_.+-]{0,126}/[a-zA-Z0-9][a-zA-Z0-9!#$&^_.+-]{0,126}$`, mimeType)
return matched
}
func (s *APIV1Service) validateAttachmentFilter(ctx context.Context, filterStr string) error {
if filterStr == "" {
return errors.New("filter cannot be empty")
}
engine, err := filter.DefaultAttachmentEngine()
if err != nil {
return err
}
var dialect filter.DialectName
switch s.Profile.Driver {
case "mysql":
dialect = filter.DialectMySQL
case "postgres":
dialect = filter.DialectPostgres
default:
dialect = filter.DialectSQLite
}
if _, err := engine.CompileToStatement(ctx, filterStr, filter.RenderOptions{Dialect: dialect}); err != nil {
return errors.Wrap(err, "failed to compile filter")
}
return nil
}
// checkAttachmentAccess verifies the user has permission to access the attachment.
// For unlinked attachments (no memo), only the creator can access.
// For linked attachments, access follows the memo's visibility rules.
func (s *APIV1Service) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment) error {
user, _ := s.fetchCurrentUser(ctx)
// For unlinked attachments, only the creator can access.
if attachment.MemoID == nil {
if user == nil {
return status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if attachment.CreatorID != user.ID && !isSuperUser(user) {
return status.Errorf(codes.PermissionDenied, "permission denied")
}
return nil
}
// For linked attachments, check memo visibility.
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
if err != nil {
return status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility == store.Public {
return nil
}
if user == nil {
return status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return status.Errorf(codes.PermissionDenied, "permission denied")
}
return nil
}
// shouldStripExif checks if the MIME type is an image format that may contain EXIF metadata.
// Returns true for formats like JPEG, TIFF, WebP, HEIC, and HEIF which commonly contain
// privacy-sensitive metadata such as GPS coordinates, camera settings, and device information.
func shouldStripExif(mimeType string) bool {
return exifCapableImageTypes[mimeType]
}
// stripImageExif removes EXIF metadata from image files by decoding and re-encoding them.
// This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps.
//
// The function preserves the correct image orientation by applying EXIF orientation tags
// during decoding before stripping all metadata. Images are re-encoded with high quality
// to minimize visual degradation.
//
// Supported formats:
// - JPEG/JPG: Re-encoded as JPEG with quality 95
// - PNG: Re-encoded as PNG (lossless)
// - TIFF/WebP/HEIC/HEIF: Re-encoded as JPEG with quality 95
//
// Returns the cleaned image data without any EXIF metadata, or an error if processing fails.
func stripImageExif(imageData []byte, mimeType string) ([]byte, error) {
// Decode image with automatic EXIF orientation correction.
// This ensures the image displays correctly after metadata removal.
img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true))
if err != nil {
return nil, errors.Wrap(err, "failed to decode image")
}
// Re-encode the image without EXIF metadata.
var buf bytes.Buffer
var encodeErr error
if mimeType == "image/png" {
// Preserve PNG format for lossless encoding
encodeErr = imaging.Encode(&buf, img, imaging.PNG)
} else {
// For JPEG, TIFF, WebP, HEIC, HEIF - re-encode as JPEG.
// This ensures EXIF is stripped and provides good compression.
encodeErr = imaging.Encode(&buf, img, imaging.JPEG, imaging.JPEGQuality(defaultJPEGQuality))
}
if encodeErr != nil {
return nil, errors.Wrap(encodeErr, "failed to encode image")
}
return buf.Bytes(), nil
}

View File

@@ -0,0 +1,612 @@
package v1
import (
"context"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
const (
unmatchedUsernameAndPasswordError = "unmatched username and password"
)
// GetCurrentUser returns the authenticated user's information.
// Validates the access token and returns user details.
//
// Authentication: Required (access token).
// Returns: User information.
func (s *APIV1Service) GetCurrentUser(ctx context.Context, _ *v1pb.GetCurrentUserRequest) (*v1pb.GetCurrentUserResponse, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
}
if user == nil {
// Clear auth cookies
if err := s.clearAuthCookies(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies: %v", err)
}
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
return &v1pb.GetCurrentUserResponse{
User: convertUserFromStore(user),
}, nil
}
// SignIn authenticates a user with credentials and returns tokens.
// On success, returns an access token and sets a refresh token cookie.
//
// Supports two authentication methods:
// 1. Password-based authentication (username + password).
// 2. SSO authentication (OAuth2 authorization code).
//
// Authentication: Not required (public endpoint).
// Returns: User info, access token, and token expiry.
func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.SignInResponse, error) {
var existingUser *store.User
// Authentication Method 1: Password-based authentication
if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &passwordCredentials.Username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(passwordCredentials.Password)); err != nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
instanceGeneralSetting, err := s.Store.GetInstanceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance general setting, error: %v", err)
}
// Check if the password auth in is allowed.
if instanceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
}
existingUser = user
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
// Authentication Method 2: SSO (OAuth2) authentication
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &ssoCredentials.IdpId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.InvalidArgument, "identity provider not found")
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
}
// Pass code_verifier for PKCE support (empty string if not provided for backward compatibility)
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code, ssoCredentials.CodeVerifier)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err)
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err)
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
// Check if the user is allowed to sign up.
instanceGeneralSetting, err := s.Store.GetInstanceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance general setting, error: %v", err)
}
if instanceGeneralSetting.DisallowUserRegistration {
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
}
// Create a new user with the user info from the identity provider.
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
AvatarURL: userInfo.AvatarURL,
}
password, err := util.RandomString(20)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
}
}
existingUser = user
}
if existingUser == nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid credentials")
}
if existingUser.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", existingUser.Username)
}
accessToken, accessExpiresAt, err := s.doSignIn(ctx, existingUser)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to sign in: %v", err)
}
return &v1pb.SignInResponse{
User: convertUserFromStore(existingUser),
AccessToken: accessToken,
AccessTokenExpiresAt: timestamppb.New(accessExpiresAt),
}, nil
}
// doSignIn performs the actual sign-in operation by creating a session and setting the cookie.
//
// This function:
// 1. Generates refresh token and access token.
// 2. Stores refresh token metadata in user_setting.
// 3. Sets refresh token as HttpOnly cookie.
// 4. Returns access token and its expiry time.
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User) (string, time.Time, error) {
// Generate refresh token
tokenID := util.GenUUID()
refreshToken, refreshExpiresAt, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(s.Secret))
if err != nil {
return "", time.Time{}, status.Errorf(codes.Internal, "failed to generate refresh token: %v", err)
}
// Store refresh token metadata
clientInfo := s.extractClientInfo(ctx)
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(refreshExpiresAt),
CreatedAt: timestamppb.Now(),
ClientInfo: clientInfo,
}
if err := s.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord); err != nil {
slog.Error("failed to store refresh token", "error", err)
}
// Set refresh token cookie
refreshCookie := s.buildRefreshTokenCookie(ctx, refreshToken, refreshExpiresAt)
if err := SetResponseHeader(ctx, "Set-Cookie", refreshCookie); err != nil {
return "", time.Time{}, status.Errorf(codes.Internal, "failed to set refresh token cookie: %v", err)
}
// Generate access token
accessToken, accessExpiresAt, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte(s.Secret),
)
if err != nil {
return "", time.Time{}, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
return accessToken, accessExpiresAt, nil
}
// SignOut terminates the user's authentication.
// Revokes the refresh token and clears the authentication cookie.
//
// Authentication: Required (access token).
// Returns: Empty response on success.
func (s *APIV1Service) SignOut(ctx context.Context, _ *v1pb.SignOutRequest) (*emptypb.Empty, error) {
// Get user from access token claims
claims := auth.GetUserClaims(ctx)
if claims != nil {
// Revoke refresh token if we can identify it
refreshToken := ""
if md, ok := metadata.FromIncomingContext(ctx); ok {
if cookies := md.Get("cookie"); len(cookies) > 0 {
refreshToken = auth.ExtractRefreshTokenFromCookie(cookies[0])
}
}
if refreshToken != "" {
refreshClaims, err := auth.ParseRefreshToken(refreshToken, []byte(s.Secret))
if err == nil {
// Remove refresh token from user_setting by token_id
_ = s.Store.RemoveUserRefreshToken(ctx, claims.UserID, refreshClaims.TokenID)
}
}
}
// Clear refresh token cookie
if err := s.clearAuthCookies(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies, error: %v", err)
}
return &emptypb.Empty{}, nil
}
// RefreshToken exchanges a valid refresh token for a new access token.
//
// This endpoint implements refresh token rotation with sliding window sessions:
// 1. Extracts the refresh token from the HttpOnly cookie (memos_refresh)
// 2. Validates the refresh token against the database (checking expiry and revocation)
// 3. Rotates the refresh token: generates a new one with fresh 30-day expiry
// 4. Generates a new short-lived access token (15 minutes)
// 5. Sets the new refresh token as HttpOnly cookie
// 6. Returns the new access token and its expiry time
//
// Token rotation provides:
// - Sliding window sessions: active users stay logged in indefinitely
// - Better security: stolen refresh tokens become invalid after legitimate refresh
//
// Authentication: Requires valid refresh token in cookie (public endpoint)
// Returns: New access token and expiry timestamp.
func (s *APIV1Service) RefreshToken(ctx context.Context, _ *v1pb.RefreshTokenRequest) (*v1pb.RefreshTokenResponse, error) {
// Extract refresh token from cookie
refreshToken := ""
if md, ok := metadata.FromIncomingContext(ctx); ok {
if cookies := md.Get("cookie"); len(cookies) > 0 {
refreshToken = auth.ExtractRefreshTokenFromCookie(cookies[0])
}
}
if refreshToken == "" {
return nil, status.Errorf(codes.Unauthenticated, "refresh token not found")
}
// Validate refresh token and get old token ID for rotation
authenticator := auth.NewAuthenticator(s.Store, s.Secret)
user, oldTokenID, err := authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid refresh token: %v", err)
}
// --- Refresh Token Rotation ---
// Generate new refresh token with fresh 30-day expiry (sliding window)
newTokenID := util.GenUUID()
newRefreshToken, newRefreshExpiresAt, err := auth.GenerateRefreshToken(user.ID, newTokenID, []byte(s.Secret))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate refresh token: %v", err)
}
// Store new refresh token (add before remove to handle race conditions)
clientInfo := s.extractClientInfo(ctx)
newRefreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: newTokenID,
ExpiresAt: timestamppb.New(newRefreshExpiresAt),
CreatedAt: timestamppb.Now(),
ClientInfo: clientInfo,
}
if err := s.Store.AddUserRefreshToken(ctx, user.ID, newRefreshTokenRecord); err != nil {
return nil, status.Errorf(codes.Internal, "failed to store refresh token: %v", err)
}
// Remove old refresh token
if err := s.Store.RemoveUserRefreshToken(ctx, user.ID, oldTokenID); err != nil {
// Log but don't fail - old token will expire naturally
slog.Warn("failed to remove old refresh token", "error", err, "userID", user.ID, "tokenID", oldTokenID)
}
// Set new refresh token cookie
newRefreshCookie := s.buildRefreshTokenCookie(ctx, newRefreshToken, newRefreshExpiresAt)
if err := SetResponseHeader(ctx, "Set-Cookie", newRefreshCookie); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set refresh token cookie: %v", err)
}
// --- End Rotation ---
// Generate new access token
accessToken, expiresAt, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte(s.Secret),
)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
return &v1pb.RefreshTokenResponse{
AccessToken: accessToken,
ExpiresAt: timestamppb.New(expiresAt),
}, nil
}
func (s *APIV1Service) clearAuthCookies(ctx context.Context) error {
// Clear refresh token cookie
refreshCookie := s.buildRefreshTokenCookie(ctx, "", time.Time{})
if err := SetResponseHeader(ctx, "Set-Cookie", refreshCookie); err != nil {
return errors.Wrap(err, "failed to set refresh cookie")
}
return nil
}
func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken string, expireTime time.Time) string {
attrs := []string{
fmt.Sprintf("%s=%s", auth.RefreshTokenCookieName, refreshToken),
"Path=/",
"HttpOnly",
}
if expireTime.IsZero() {
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
} else {
// RFC 6265 requires cookie expiration dates to use GMT timezone
// Convert to UTC and format with explicit "GMT" to ensure browser compatibility
attrs = append(attrs, "Expires="+expireTime.UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT"))
}
// Try to determine if the request is HTTPS by checking the origin header
// Default to non-HTTPS (Lax SameSite) if metadata is not available
isHTTPS := false
if md, ok := metadata.FromIncomingContext(ctx); ok {
for _, v := range md.Get("origin") {
if strings.HasPrefix(v, "https://") {
isHTTPS = true
break
}
}
}
if isHTTPS {
attrs = append(attrs, "SameSite=Lax", "Secure")
} else {
attrs = append(attrs, "SameSite=Lax")
}
return strings.Join(attrs, "; ")
}
func (s *APIV1Service) fetchCurrentUser(ctx context.Context) (*store.User, error) {
userID := auth.GetUserID(ctx)
if userID == 0 {
return nil, nil
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, err
}
if user == nil {
return nil, errors.Errorf("user %d not found", userID)
}
return user, nil
}
// extractClientInfo extracts comprehensive client information from the request context.
//
// This function parses metadata from the gRPC context to extract:
// - User Agent: Raw user agent string for detailed parsing
// - IP Address: Client IP from X-Forwarded-For or X-Real-IP headers
// - Device Type: "mobile", "tablet", or "desktop" (parsed from user agent)
// - Operating System: OS name and version (e.g., "iOS 17.1", "Windows 10/11")
// - Browser: Browser name and version (e.g., "Chrome 120.0.0.0")
//
// This information enables users to:
// - See all active sessions with device details
// - Identify suspicious login attempts
// - Revoke specific sessions from unknown devices.
func (s *APIV1Service) extractClientInfo(ctx context.Context) *storepb.RefreshTokensUserSetting_ClientInfo {
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
// Extract user agent from metadata if available
if md, ok := metadata.FromIncomingContext(ctx); ok {
if userAgents := md.Get("user-agent"); len(userAgents) > 0 {
userAgent := userAgents[0]
clientInfo.UserAgent = userAgent
// Parse user agent to extract device type, OS, browser info
s.parseUserAgent(userAgent, clientInfo)
}
if forwardedFor := md.Get("x-forwarded-for"); len(forwardedFor) > 0 {
ipAddress := strings.Split(forwardedFor[0], ",")[0] // Get the first IP in case of multiple
ipAddress = strings.TrimSpace(ipAddress)
clientInfo.IpAddress = ipAddress
} else if realIP := md.Get("x-real-ip"); len(realIP) > 0 {
clientInfo.IpAddress = realIP[0]
}
}
return clientInfo
}
// parseUserAgent extracts device type, OS, and browser information from user agent string.
//
// Detection logic:
// - Device Type: Checks for keywords like "mobile", "tablet", "ipad"
// - OS: Pattern matches for iOS, Android, Windows, macOS, Linux, Chrome OS
// - Browser: Identifies Edge, Chrome, Firefox, Safari, Opera
//
// Note: This is a simplified parser. For production use with high accuracy requirements,
// consider using a dedicated user agent parsing library.
func (*APIV1Service) parseUserAgent(userAgent string, clientInfo *storepb.RefreshTokensUserSetting_ClientInfo) {
if userAgent == "" {
return
}
userAgent = strings.ToLower(userAgent)
// Detect device type
if strings.Contains(userAgent, "ipad") || strings.Contains(userAgent, "tablet") {
clientInfo.DeviceType = "tablet"
} else if strings.Contains(userAgent, "mobile") || strings.Contains(userAgent, "android") ||
strings.Contains(userAgent, "iphone") || strings.Contains(userAgent, "ipod") ||
strings.Contains(userAgent, "windows phone") || strings.Contains(userAgent, "blackberry") {
clientInfo.DeviceType = "mobile"
} else {
clientInfo.DeviceType = "desktop"
}
// Detect operating system
if strings.Contains(userAgent, "iphone os") || strings.Contains(userAgent, "cpu os") {
// Extract iOS version
if idx := strings.Index(userAgent, "cpu os "); idx != -1 {
versionStart := idx + 7
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "iOS " + version
} else {
clientInfo.Os = "iOS"
}
} else if idx := strings.Index(userAgent, "iphone os "); idx != -1 {
versionStart := idx + 10
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "iOS " + version
} else {
clientInfo.Os = "iOS"
}
} else {
clientInfo.Os = "iOS"
}
} else if strings.Contains(userAgent, "android") {
// Extract Android version
if idx := strings.Index(userAgent, "android "); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], ";")
if versionEnd == -1 {
versionEnd = strings.Index(userAgent[versionStart:], ")")
}
if versionEnd != -1 {
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Os = "Android " + version
} else {
clientInfo.Os = "Android"
}
} else {
clientInfo.Os = "Android"
}
} else if strings.Contains(userAgent, "windows nt 10.0") {
clientInfo.Os = "Windows 10/11"
} else if strings.Contains(userAgent, "windows nt 6.3") {
clientInfo.Os = "Windows 8.1"
} else if strings.Contains(userAgent, "windows nt 6.1") {
clientInfo.Os = "Windows 7"
} else if strings.Contains(userAgent, "windows") {
clientInfo.Os = "Windows"
} else if strings.Contains(userAgent, "mac os x") {
// Extract macOS version
if idx := strings.Index(userAgent, "mac os x "); idx != -1 {
versionStart := idx + 9
versionEnd := strings.Index(userAgent[versionStart:], ";")
if versionEnd == -1 {
versionEnd = strings.Index(userAgent[versionStart:], ")")
}
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "macOS " + version
} else {
clientInfo.Os = "macOS"
}
} else {
clientInfo.Os = "macOS"
}
} else if strings.Contains(userAgent, "linux") {
clientInfo.Os = "Linux"
} else if strings.Contains(userAgent, "cros") {
clientInfo.Os = "Chrome OS"
}
// Detect browser
if strings.Contains(userAgent, "edg/") {
// Extract Edge version
if idx := strings.Index(userAgent, "edg/"); idx != -1 {
versionStart := idx + 4
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Edge " + version
} else {
clientInfo.Browser = "Edge"
}
} else if strings.Contains(userAgent, "chrome/") && !strings.Contains(userAgent, "edg") {
// Extract Chrome version
if idx := strings.Index(userAgent, "chrome/"); idx != -1 {
versionStart := idx + 7
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Chrome " + version
} else {
clientInfo.Browser = "Chrome"
}
} else if strings.Contains(userAgent, "firefox/") {
// Extract Firefox version
if idx := strings.Index(userAgent, "firefox/"); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Firefox " + version
} else {
clientInfo.Browser = "Firefox"
}
} else if strings.Contains(userAgent, "safari/") && !strings.Contains(userAgent, "chrome") && !strings.Contains(userAgent, "edg") {
// Extract Safari version
if idx := strings.Index(userAgent, "version/"); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Safari " + version
} else {
clientInfo.Browser = "Safari"
}
} else if strings.Contains(userAgent, "opera/") || strings.Contains(userAgent, "opr/") {
clientInfo.Browser = "Opera"
}
}

View File

@@ -0,0 +1,179 @@
package v1
import (
"context"
"testing"
"google.golang.org/grpc/metadata"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestParseUserAgent(t *testing.T) {
service := &APIV1Service{}
tests := []struct {
name string
userAgent string
expectedDevice string
expectedOS string
expectedBrowser string
}{
{
name: "Chrome on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Chrome 119.0.0.0",
},
{
name: "Safari on macOS",
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15",
expectedDevice: "desktop",
expectedOS: "macOS 10.15.7",
expectedBrowser: "Safari 17.0",
},
{
name: "Chrome on Android Mobile",
userAgent: "Mozilla/5.0 (Linux; Android 13; SM-G998B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Mobile Safari/537.36",
expectedDevice: "mobile",
expectedOS: "Android 13",
expectedBrowser: "Chrome 119.0.0.0",
},
{
name: "Safari on iPhone",
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
expectedDevice: "mobile",
expectedOS: "iOS 17.0",
expectedBrowser: "Safari 17.0",
},
{
name: "Firefox on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/119.0",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Firefox 119.0",
},
{
name: "Edge on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Edge 119.0.0.0",
},
{
name: "iPad Safari",
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
expectedDevice: "tablet",
expectedOS: "iOS 17.0",
expectedBrowser: "Safari 17.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
service.parseUserAgent(tt.userAgent, clientInfo)
if clientInfo.DeviceType != tt.expectedDevice {
t.Errorf("Expected device type %s, got %s", tt.expectedDevice, clientInfo.DeviceType)
}
if clientInfo.Os != tt.expectedOS {
t.Errorf("Expected OS %s, got %s", tt.expectedOS, clientInfo.Os)
}
if clientInfo.Browser != tt.expectedBrowser {
t.Errorf("Expected browser %s, got %s", tt.expectedBrowser, clientInfo.Browser)
}
})
}
}
func TestExtractClientInfo(t *testing.T) {
service := &APIV1Service{}
// Test with metadata containing user agent and IP
md := metadata.New(map[string]string{
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
"x-forwarded-for": "203.0.113.1, 198.51.100.1",
"x-real-ip": "203.0.113.1",
})
ctx := metadata.NewIncomingContext(context.Background(), md)
clientInfo := service.extractClientInfo(ctx)
if clientInfo.UserAgent == "" {
t.Error("Expected user agent to be set")
}
if clientInfo.IpAddress != "203.0.113.1" {
t.Errorf("Expected IP address to be 203.0.113.1, got %s", clientInfo.IpAddress)
}
if clientInfo.DeviceType != "desktop" {
t.Errorf("Expected device type to be desktop, got %s", clientInfo.DeviceType)
}
if clientInfo.Os != "Windows 10/11" {
t.Errorf("Expected OS to be Windows 10/11, got %s", clientInfo.Os)
}
if clientInfo.Browser != "Chrome 119.0.0.0" {
t.Errorf("Expected browser to be Chrome 119.0.0.0, got %s", clientInfo.Browser)
}
}
// TestClientInfoExamples demonstrates the enhanced client info extraction with various user agents.
func TestClientInfoExamples(t *testing.T) {
service := &APIV1Service{}
examples := []struct {
description string
userAgent string
}{
{
description: "Modern Chrome on Windows 11",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
},
{
description: "Safari on iPhone 15 Pro",
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
},
{
description: "Chrome on Samsung Galaxy",
userAgent: "Mozilla/5.0 (Linux; Android 14; SM-S918B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36",
},
{
description: "Firefox on Ubuntu",
userAgent: "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/120.0",
},
{
description: "Edge on Windows 10",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
},
{
description: "Safari on iPad Air",
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
},
}
for _, example := range examples {
t.Run(example.description, func(t *testing.T) {
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
service.parseUserAgent(example.userAgent, clientInfo)
t.Logf("User Agent: %s", example.userAgent)
t.Logf("Device Type: %s", clientInfo.DeviceType)
t.Logf("Operating System: %s", clientInfo.Os)
t.Logf("Browser: %s", clientInfo.Browser)
t.Log("---")
// Ensure all fields are populated
if clientInfo.DeviceType == "" {
t.Error("Device type should not be empty")
}
if clientInfo.Os == "" {
t.Error("OS should not be empty")
}
if clientInfo.Browser == "" {
t.Error("Browser should not be empty")
}
})
}
}

View File

@@ -0,0 +1,68 @@
package v1
import (
"encoding/base64"
"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
const (
// DefaultPageSize is the default page size for requests.
DefaultPageSize = 10
// MaxPageSize is the maximum page size for requests.
MaxPageSize = 1000
)
func convertStateFromStore(rowStatus store.RowStatus) v1pb.State {
switch rowStatus {
case store.Normal:
return v1pb.State_NORMAL
case store.Archived:
return v1pb.State_ARCHIVED
default:
return v1pb.State_STATE_UNSPECIFIED
}
}
func convertStateToStore(state v1pb.State) store.RowStatus {
switch state {
case v1pb.State_ARCHIVED:
return store.Archived
default:
return store.Normal
}
}
func getPageToken(limit int, offset int) (string, error) {
return marshalPageToken(&v1pb.PageToken{
Limit: int32(limit),
Offset: int32(offset),
})
}
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
b, err := proto.Marshal(pageToken)
if err != nil {
return "", errors.Wrapf(err, "failed to marshal page token")
}
return base64.StdEncoding.EncodeToString(b), nil
}
func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return errors.Wrapf(err, "failed to decode page token")
}
if err := proto.Unmarshal(b, pageToken); err != nil {
return errors.Wrapf(err, "failed to unmarshal page token")
}
return nil
}
func isSuperUser(user *store.User) bool {
return user.Role == store.RoleAdmin
}

View File

@@ -0,0 +1,80 @@
package v1
import (
"net/http"
"connectrpc.com/connect"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/usememos/memos/proto/gen/api/v1/apiv1connect"
)
// ConnectServiceHandler wraps APIV1Service to implement Connect handler interfaces.
// It adapts the existing gRPC service implementations to work with Connect's
// request/response wrapper types.
//
// This wrapper pattern allows us to:
// - Reuse existing gRPC service implementations
// - Support both native gRPC and Connect protocols
// - Maintain a single source of truth for business logic.
type ConnectServiceHandler struct {
*APIV1Service
}
// NewConnectServiceHandler creates a new Connect service handler.
func NewConnectServiceHandler(svc *APIV1Service) *ConnectServiceHandler {
return &ConnectServiceHandler{APIV1Service: svc}
}
// RegisterConnectHandlers registers all Connect service handlers on the given mux.
func (s *ConnectServiceHandler) RegisterConnectHandlers(mux *http.ServeMux, opts ...connect.HandlerOption) {
// Register all service handlers
handlers := []struct {
path string
handler http.Handler
}{
wrap(apiv1connect.NewInstanceServiceHandler(s, opts...)),
wrap(apiv1connect.NewAuthServiceHandler(s, opts...)),
wrap(apiv1connect.NewUserServiceHandler(s, opts...)),
wrap(apiv1connect.NewMemoServiceHandler(s, opts...)),
wrap(apiv1connect.NewAttachmentServiceHandler(s, opts...)),
wrap(apiv1connect.NewShortcutServiceHandler(s, opts...)),
wrap(apiv1connect.NewActivityServiceHandler(s, opts...)),
wrap(apiv1connect.NewIdentityProviderServiceHandler(s, opts...)),
}
for _, h := range handlers {
mux.Handle(h.path, h.handler)
}
}
// wrap converts (path, handler) return value to a struct for cleaner iteration.
func wrap(path string, handler http.Handler) struct {
path string
handler http.Handler
} {
return struct {
path string
handler http.Handler
}{path, handler}
}
// convertGRPCError converts gRPC status errors to Connect errors.
// This preserves the error code semantics between the two protocols.
func convertGRPCError(err error) error {
if err == nil {
return nil
}
if st, ok := status.FromError(err); ok {
return connect.NewError(grpcCodeToConnectCode(st.Code()), err)
}
return connect.NewError(connect.CodeInternal, err)
}
// grpcCodeToConnectCode converts gRPC status codes to Connect error codes.
// gRPC and Connect use the same error code semantics, so this is a direct cast.
// See: https://connectrpc.com/docs/protocol/#error-codes
func grpcCodeToConnectCode(code codes.Code) connect.Code {
return connect.Code(code)
}

View File

@@ -0,0 +1,237 @@
package v1
import (
"context"
"errors"
"fmt"
"log/slog"
"reflect"
"runtime/debug"
"connectrpc.com/connect"
pkgerrors "github.com/pkg/errors"
"google.golang.org/grpc/metadata"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// MetadataInterceptor converts Connect HTTP headers to gRPC metadata.
//
// This ensures service methods can use metadata.FromIncomingContext() to access
// headers like User-Agent, X-Forwarded-For, etc., regardless of whether the
// request came via Connect RPC or gRPC-Gateway.
type MetadataInterceptor struct{}
// NewMetadataInterceptor creates a new metadata interceptor.
func NewMetadataInterceptor() *MetadataInterceptor {
return &MetadataInterceptor{}
}
func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
// Convert HTTP headers to gRPC metadata
header := req.Header()
md := metadata.MD{}
// Copy important headers for client info extraction
if ua := header.Get("User-Agent"); ua != "" {
md.Set("user-agent", ua)
}
if xff := header.Get("X-Forwarded-For"); xff != "" {
md.Set("x-forwarded-for", xff)
}
if xri := header.Get("X-Real-Ip"); xri != "" {
md.Set("x-real-ip", xri)
}
// Forward Cookie header for authentication methods that need it (e.g., RefreshToken)
if cookie := header.Get("Cookie"); cookie != "" {
md.Set("cookie", cookie)
}
// Set metadata in context so services can use metadata.FromIncomingContext()
ctx = metadata.NewIncomingContext(ctx, md)
// Execute the request
resp, err := next(ctx, req)
// Prevent browser caching of API responses to avoid stale data issues
// See: https://github.com/usememos/memos/issues/5470
if !isNilAnyResponse(resp) && resp.Header() != nil {
resp.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
resp.Header().Set("Pragma", "no-cache")
resp.Header().Set("Expires", "0")
}
return resp, err
}
}
func isNilAnyResponse(resp connect.AnyResponse) bool {
if resp == nil {
return true
}
val := reflect.ValueOf(resp)
return val.Kind() == reflect.Ptr && val.IsNil()
}
func (*MetadataInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}
func (*MetadataInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}
// LoggingInterceptor logs Connect RPC requests with appropriate log levels.
//
// Log levels:
// - INFO: Successful requests and expected client errors (not found, permission denied, etc.)
// - ERROR: Server errors (internal, unavailable, etc.)
type LoggingInterceptor struct {
logStacktrace bool
}
// NewLoggingInterceptor creates a new logging interceptor.
func NewLoggingInterceptor(logStacktrace bool) *LoggingInterceptor {
return &LoggingInterceptor{logStacktrace: logStacktrace}
}
func (in *LoggingInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
resp, err := next(ctx, req)
in.log(req.Spec().Procedure, err)
return resp, err
}
}
func (*LoggingInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next // No-op for server-side interceptor
}
func (*LoggingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next // Streaming not used in this service
}
func (in *LoggingInterceptor) log(procedure string, err error) {
level, msg := in.classifyError(err)
attrs := []slog.Attr{slog.String("method", procedure)}
if err != nil {
attrs = append(attrs, slog.String("error", err.Error()))
if in.logStacktrace {
attrs = append(attrs, slog.String("stacktrace", fmt.Sprintf("%+v", err)))
}
}
slog.LogAttrs(context.Background(), level, msg, attrs...)
}
func (*LoggingInterceptor) classifyError(err error) (slog.Level, string) {
if err == nil {
return slog.LevelInfo, "OK"
}
var connectErr *connect.Error
if !pkgerrors.As(err, &connectErr) {
return slog.LevelError, "unknown error"
}
// Client errors (expected, log at INFO)
switch connectErr.Code() {
case connect.CodeCanceled,
connect.CodeInvalidArgument,
connect.CodeNotFound,
connect.CodeAlreadyExists,
connect.CodePermissionDenied,
connect.CodeUnauthenticated,
connect.CodeResourceExhausted,
connect.CodeFailedPrecondition,
connect.CodeAborted,
connect.CodeOutOfRange:
return slog.LevelInfo, "client error"
default:
// Server errors
return slog.LevelError, "server error"
}
}
// RecoveryInterceptor recovers from panics in Connect handlers and returns an internal error.
type RecoveryInterceptor struct {
logStacktrace bool
}
// NewRecoveryInterceptor creates a new recovery interceptor.
func NewRecoveryInterceptor(logStacktrace bool) *RecoveryInterceptor {
return &RecoveryInterceptor{logStacktrace: logStacktrace}
}
func (in *RecoveryInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (resp connect.AnyResponse, err error) {
defer func() {
if r := recover(); r != nil {
in.logPanic(req.Spec().Procedure, r)
err = connect.NewError(connect.CodeInternal, pkgerrors.New("internal server error"))
}
}()
return next(ctx, req)
}
}
func (*RecoveryInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}
func (*RecoveryInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}
func (in *RecoveryInterceptor) logPanic(procedure string, panicValue any) {
attrs := []slog.Attr{
slog.String("method", procedure),
slog.Any("panic", panicValue),
}
if in.logStacktrace {
attrs = append(attrs, slog.String("stacktrace", string(debug.Stack())))
}
slog.LogAttrs(context.Background(), slog.LevelError, "panic recovered in Connect handler", attrs...)
}
// AuthInterceptor handles authentication for Connect handlers.
//
// It enforces authentication for all endpoints except those listed in PublicMethods.
// Role-based authorization (admin checks) remains in the service layer.
type AuthInterceptor struct {
authenticator *auth.Authenticator
}
// NewAuthInterceptor creates a new auth interceptor.
func NewAuthInterceptor(store *store.Store, secret string) *AuthInterceptor {
return &AuthInterceptor{
authenticator: auth.NewAuthenticator(store, secret),
}
}
func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
header := req.Header()
authHeader := header.Get("Authorization")
result := in.authenticator.Authenticate(ctx, authHeader)
// Enforce authentication for non-public methods
if result == nil && !IsPublicMethod(req.Spec().Procedure) {
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required"))
}
ctx = auth.ApplyToContext(ctx, result)
return next(ctx, req)
}
}
func (*AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}
func (*AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}

View File

@@ -0,0 +1,490 @@
package v1
import (
"context"
"connectrpc.com/connect"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
// This file contains all Connect service handler method implementations.
// Each method delegates to the underlying gRPC service implementation,
// converting between Connect and gRPC request/response types.
// InstanceService
func (s *ConnectServiceHandler) GetInstanceProfile(ctx context.Context, req *connect.Request[v1pb.GetInstanceProfileRequest]) (*connect.Response[v1pb.InstanceProfile], error) {
resp, err := s.APIV1Service.GetInstanceProfile(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetInstanceSetting(ctx context.Context, req *connect.Request[v1pb.GetInstanceSettingRequest]) (*connect.Response[v1pb.InstanceSetting], error) {
resp, err := s.APIV1Service.GetInstanceSetting(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateInstanceSetting(ctx context.Context, req *connect.Request[v1pb.UpdateInstanceSettingRequest]) (*connect.Response[v1pb.InstanceSetting], error) {
resp, err := s.APIV1Service.UpdateInstanceSetting(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// AuthService
//
// Auth service methods need special handling for response headers (cookies).
// We use connectWithHeaderCarrier helper to inject a header carrier into the context,
// which allows the service to set headers in a protocol-agnostic way.
func (s *ConnectServiceHandler) GetCurrentUser(ctx context.Context, req *connect.Request[v1pb.GetCurrentUserRequest]) (*connect.Response[v1pb.GetCurrentUserResponse], error) {
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.GetCurrentUserResponse, error) {
return s.APIV1Service.GetCurrentUser(ctx, req.Msg)
})
}
func (s *ConnectServiceHandler) SignIn(ctx context.Context, req *connect.Request[v1pb.SignInRequest]) (*connect.Response[v1pb.SignInResponse], error) {
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.SignInResponse, error) {
return s.APIV1Service.SignIn(ctx, req.Msg)
})
}
func (s *ConnectServiceHandler) SignOut(ctx context.Context, req *connect.Request[v1pb.SignOutRequest]) (*connect.Response[emptypb.Empty], error) {
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*emptypb.Empty, error) {
return s.APIV1Service.SignOut(ctx, req.Msg)
})
}
func (s *ConnectServiceHandler) RefreshToken(ctx context.Context, req *connect.Request[v1pb.RefreshTokenRequest]) (*connect.Response[v1pb.RefreshTokenResponse], error) {
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.RefreshTokenResponse, error) {
return s.APIV1Service.RefreshToken(ctx, req.Msg)
})
}
// UserService
func (s *ConnectServiceHandler) ListUsers(ctx context.Context, req *connect.Request[v1pb.ListUsersRequest]) (*connect.Response[v1pb.ListUsersResponse], error) {
resp, err := s.APIV1Service.ListUsers(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetUser(ctx context.Context, req *connect.Request[v1pb.GetUserRequest]) (*connect.Response[v1pb.User], error) {
resp, err := s.APIV1Service.GetUser(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateUser(ctx context.Context, req *connect.Request[v1pb.CreateUserRequest]) (*connect.Response[v1pb.User], error) {
resp, err := s.APIV1Service.CreateUser(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateUser(ctx context.Context, req *connect.Request[v1pb.UpdateUserRequest]) (*connect.Response[v1pb.User], error) {
resp, err := s.APIV1Service.UpdateUser(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteUser(ctx context.Context, req *connect.Request[v1pb.DeleteUserRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteUser(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListAllUserStats(ctx context.Context, req *connect.Request[v1pb.ListAllUserStatsRequest]) (*connect.Response[v1pb.ListAllUserStatsResponse], error) {
resp, err := s.APIV1Service.ListAllUserStats(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetUserStats(ctx context.Context, req *connect.Request[v1pb.GetUserStatsRequest]) (*connect.Response[v1pb.UserStats], error) {
resp, err := s.APIV1Service.GetUserStats(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetUserSetting(ctx context.Context, req *connect.Request[v1pb.GetUserSettingRequest]) (*connect.Response[v1pb.UserSetting], error) {
resp, err := s.APIV1Service.GetUserSetting(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateUserSetting(ctx context.Context, req *connect.Request[v1pb.UpdateUserSettingRequest]) (*connect.Response[v1pb.UserSetting], error) {
resp, err := s.APIV1Service.UpdateUserSetting(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListUserSettings(ctx context.Context, req *connect.Request[v1pb.ListUserSettingsRequest]) (*connect.Response[v1pb.ListUserSettingsResponse], error) {
resp, err := s.APIV1Service.ListUserSettings(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListPersonalAccessTokens(ctx context.Context, req *connect.Request[v1pb.ListPersonalAccessTokensRequest]) (*connect.Response[v1pb.ListPersonalAccessTokensResponse], error) {
resp, err := s.APIV1Service.ListPersonalAccessTokens(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreatePersonalAccessToken(ctx context.Context, req *connect.Request[v1pb.CreatePersonalAccessTokenRequest]) (*connect.Response[v1pb.CreatePersonalAccessTokenResponse], error) {
resp, err := s.APIV1Service.CreatePersonalAccessToken(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeletePersonalAccessToken(ctx context.Context, req *connect.Request[v1pb.DeletePersonalAccessTokenRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeletePersonalAccessToken(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListUserWebhooks(ctx context.Context, req *connect.Request[v1pb.ListUserWebhooksRequest]) (*connect.Response[v1pb.ListUserWebhooksResponse], error) {
resp, err := s.APIV1Service.ListUserWebhooks(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateUserWebhook(ctx context.Context, req *connect.Request[v1pb.CreateUserWebhookRequest]) (*connect.Response[v1pb.UserWebhook], error) {
resp, err := s.APIV1Service.CreateUserWebhook(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateUserWebhook(ctx context.Context, req *connect.Request[v1pb.UpdateUserWebhookRequest]) (*connect.Response[v1pb.UserWebhook], error) {
resp, err := s.APIV1Service.UpdateUserWebhook(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteUserWebhook(ctx context.Context, req *connect.Request[v1pb.DeleteUserWebhookRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteUserWebhook(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListUserNotifications(ctx context.Context, req *connect.Request[v1pb.ListUserNotificationsRequest]) (*connect.Response[v1pb.ListUserNotificationsResponse], error) {
resp, err := s.APIV1Service.ListUserNotifications(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateUserNotification(ctx context.Context, req *connect.Request[v1pb.UpdateUserNotificationRequest]) (*connect.Response[v1pb.UserNotification], error) {
resp, err := s.APIV1Service.UpdateUserNotification(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteUserNotification(ctx context.Context, req *connect.Request[v1pb.DeleteUserNotificationRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteUserNotification(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// MemoService
func (s *ConnectServiceHandler) CreateMemo(ctx context.Context, req *connect.Request[v1pb.CreateMemoRequest]) (*connect.Response[v1pb.Memo], error) {
resp, err := s.APIV1Service.CreateMemo(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemos(ctx context.Context, req *connect.Request[v1pb.ListMemosRequest]) (*connect.Response[v1pb.ListMemosResponse], error) {
resp, err := s.APIV1Service.ListMemos(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetMemo(ctx context.Context, req *connect.Request[v1pb.GetMemoRequest]) (*connect.Response[v1pb.Memo], error) {
resp, err := s.APIV1Service.GetMemo(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateMemo(ctx context.Context, req *connect.Request[v1pb.UpdateMemoRequest]) (*connect.Response[v1pb.Memo], error) {
resp, err := s.APIV1Service.UpdateMemo(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteMemo(ctx context.Context, req *connect.Request[v1pb.DeleteMemoRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteMemo(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) SetMemoAttachments(ctx context.Context, req *connect.Request[v1pb.SetMemoAttachmentsRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.SetMemoAttachments(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemoAttachments(ctx context.Context, req *connect.Request[v1pb.ListMemoAttachmentsRequest]) (*connect.Response[v1pb.ListMemoAttachmentsResponse], error) {
resp, err := s.APIV1Service.ListMemoAttachments(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) SetMemoRelations(ctx context.Context, req *connect.Request[v1pb.SetMemoRelationsRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.SetMemoRelations(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemoRelations(ctx context.Context, req *connect.Request[v1pb.ListMemoRelationsRequest]) (*connect.Response[v1pb.ListMemoRelationsResponse], error) {
resp, err := s.APIV1Service.ListMemoRelations(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateMemoComment(ctx context.Context, req *connect.Request[v1pb.CreateMemoCommentRequest]) (*connect.Response[v1pb.Memo], error) {
resp, err := s.APIV1Service.CreateMemoComment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemoComments(ctx context.Context, req *connect.Request[v1pb.ListMemoCommentsRequest]) (*connect.Response[v1pb.ListMemoCommentsResponse], error) {
resp, err := s.APIV1Service.ListMemoComments(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemoReactions(ctx context.Context, req *connect.Request[v1pb.ListMemoReactionsRequest]) (*connect.Response[v1pb.ListMemoReactionsResponse], error) {
resp, err := s.APIV1Service.ListMemoReactions(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpsertMemoReaction(ctx context.Context, req *connect.Request[v1pb.UpsertMemoReactionRequest]) (*connect.Response[v1pb.Reaction], error) {
resp, err := s.APIV1Service.UpsertMemoReaction(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteMemoReaction(ctx context.Context, req *connect.Request[v1pb.DeleteMemoReactionRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteMemoReaction(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// AttachmentService
func (s *ConnectServiceHandler) CreateAttachment(ctx context.Context, req *connect.Request[v1pb.CreateAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
resp, err := s.APIV1Service.CreateAttachment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListAttachments(ctx context.Context, req *connect.Request[v1pb.ListAttachmentsRequest]) (*connect.Response[v1pb.ListAttachmentsResponse], error) {
resp, err := s.APIV1Service.ListAttachments(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetAttachment(ctx context.Context, req *connect.Request[v1pb.GetAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
resp, err := s.APIV1Service.GetAttachment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateAttachment(ctx context.Context, req *connect.Request[v1pb.UpdateAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
resp, err := s.APIV1Service.UpdateAttachment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteAttachment(ctx context.Context, req *connect.Request[v1pb.DeleteAttachmentRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteAttachment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// ShortcutService
func (s *ConnectServiceHandler) ListShortcuts(ctx context.Context, req *connect.Request[v1pb.ListShortcutsRequest]) (*connect.Response[v1pb.ListShortcutsResponse], error) {
resp, err := s.APIV1Service.ListShortcuts(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetShortcut(ctx context.Context, req *connect.Request[v1pb.GetShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
resp, err := s.APIV1Service.GetShortcut(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateShortcut(ctx context.Context, req *connect.Request[v1pb.CreateShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
resp, err := s.APIV1Service.CreateShortcut(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateShortcut(ctx context.Context, req *connect.Request[v1pb.UpdateShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
resp, err := s.APIV1Service.UpdateShortcut(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteShortcut(ctx context.Context, req *connect.Request[v1pb.DeleteShortcutRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteShortcut(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// ActivityService
func (s *ConnectServiceHandler) ListActivities(ctx context.Context, req *connect.Request[v1pb.ListActivitiesRequest]) (*connect.Response[v1pb.ListActivitiesResponse], error) {
resp, err := s.APIV1Service.ListActivities(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetActivity(ctx context.Context, req *connect.Request[v1pb.GetActivityRequest]) (*connect.Response[v1pb.Activity], error) {
resp, err := s.APIV1Service.GetActivity(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// IdentityProviderService
func (s *ConnectServiceHandler) ListIdentityProviders(ctx context.Context, req *connect.Request[v1pb.ListIdentityProvidersRequest]) (*connect.Response[v1pb.ListIdentityProvidersResponse], error) {
resp, err := s.APIV1Service.ListIdentityProviders(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetIdentityProvider(ctx context.Context, req *connect.Request[v1pb.GetIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
resp, err := s.APIV1Service.GetIdentityProvider(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateIdentityProvider(ctx context.Context, req *connect.Request[v1pb.CreateIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
resp, err := s.APIV1Service.CreateIdentityProvider(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateIdentityProvider(ctx context.Context, req *connect.Request[v1pb.UpdateIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
resp, err := s.APIV1Service.UpdateIdentityProvider(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteIdentityProvider(ctx context.Context, req *connect.Request[v1pb.DeleteIdentityProviderRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteIdentityProvider(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}

View File

@@ -0,0 +1,124 @@
package v1
import (
"context"
"connectrpc.com/connect"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// headerCarrierKey is the context key for storing headers to be set in the response.
type headerCarrierKey struct{}
// HeaderCarrier stores headers that need to be set in the response.
//
// Problem: The codebase supports two protocols simultaneously:
// - Native gRPC: Uses grpc.SetHeader() to set response headers
// - Connect-RPC: Uses connect.Response.Header().Set() to set response headers
//
// Solution: HeaderCarrier provides a protocol-agnostic way to set headers.
// - Service methods call SetResponseHeader() regardless of protocol
// - For gRPC requests: SetResponseHeader uses grpc.SetHeader directly
// - For Connect requests: SetResponseHeader stores headers in HeaderCarrier
// - Connect wrappers extract headers from HeaderCarrier and apply to response
//
// This allows service methods to work with both protocols without knowing which one is being used.
type HeaderCarrier struct {
headers map[string]string
}
// newHeaderCarrier creates a new header carrier.
func newHeaderCarrier() *HeaderCarrier {
return &HeaderCarrier{
headers: make(map[string]string),
}
}
// Set adds a header to the carrier.
func (h *HeaderCarrier) Set(key, value string) {
h.headers[key] = value
}
// Get retrieves a header from the carrier.
func (h *HeaderCarrier) Get(key string) string {
return h.headers[key]
}
// All returns all headers.
func (h *HeaderCarrier) All() map[string]string {
return h.headers
}
// WithHeaderCarrier adds a header carrier to the context.
func WithHeaderCarrier(ctx context.Context) context.Context {
return context.WithValue(ctx, headerCarrierKey{}, newHeaderCarrier())
}
// GetHeaderCarrier retrieves the header carrier from the context.
// Returns nil if no carrier is present.
func GetHeaderCarrier(ctx context.Context) *HeaderCarrier {
if carrier, ok := ctx.Value(headerCarrierKey{}).(*HeaderCarrier); ok {
return carrier
}
return nil
}
// SetResponseHeader sets a header in the response.
//
// This function works for both gRPC and Connect protocols:
// - For gRPC: Uses grpc.SetHeader to set headers in gRPC metadata
// - For Connect: Stores in HeaderCarrier for Connect wrapper to apply later
//
// The protocol is automatically detected based on whether a HeaderCarrier
// exists in the context (injected by Connect wrappers).
func SetResponseHeader(ctx context.Context, key, value string) error {
// Try Connect first (check if we have a header carrier)
if carrier := GetHeaderCarrier(ctx); carrier != nil {
carrier.Set(key, value)
return nil
}
// Fall back to gRPC
return grpc.SetHeader(ctx, metadata.New(map[string]string{
key: value,
}))
}
// connectWithHeaderCarrier is a helper for Connect service wrappers that need to set response headers.
//
// It injects a HeaderCarrier into the context, calls the service method,
// and applies any headers from the carrier to the Connect response.
//
// The generic parameter T is the non-pointer protobuf message type (e.g., v1pb.CreateSessionResponse),
// while fn returns *T (the pointer type) as is standard for protobuf messages.
//
// Usage in Connect wrappers:
//
// func (s *ConnectServiceHandler) CreateSession(ctx context.Context, req *connect.Request[v1pb.CreateSessionRequest]) (*connect.Response[v1pb.CreateSessionResponse], error) {
// return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.CreateSessionResponse, error) {
// return s.APIV1Service.CreateSession(ctx, req.Msg)
// })
// }
func connectWithHeaderCarrier[T any](ctx context.Context, fn func(context.Context) (*T, error)) (*connect.Response[T], error) {
// Inject header carrier for Connect protocol
ctx = WithHeaderCarrier(ctx)
// Call the service method
resp, err := fn(ctx)
if err != nil {
return nil, convertGRPCError(err)
}
// Create Connect response
connectResp := connect.NewResponse(resp)
// Apply any headers set via the header carrier
if carrier := GetHeaderCarrier(ctx); carrier != nil {
for key, value := range carrier.All() {
connectResp.Header().Set(key, value)
}
}
return connectResp, nil
}

View File

@@ -0,0 +1,25 @@
package v1
import (
"context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
)
func (s *APIV1Service) Check(ctx context.Context,
_ *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) {
// Check if database is initialized by verifying instance basic setting exists
instanceBasicSetting, err := s.Store.GetInstanceBasicSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Unavailable, "database not initialized: %v", err)
}
// Verify schema version is set (empty means database not properly initialized)
if instanceBasicSetting.SchemaVersion == "" {
return nil, status.Errorf(codes.Unavailable, "schema version not set")
}
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil
}

View File

@@ -0,0 +1,238 @@
package v1
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb.CreateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListIdentityProvidersRequest) (*v1pb.ListIdentityProvidersResponse, error) {
identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
}
response := &v1pb.ListIdentityProvidersResponse{
IdentityProviders: []*v1pb.IdentityProvider{},
}
// Default to lowest-privilege role, update later based on real role
currentUserRole := store.RoleUser
currentUser, err := s.fetchCurrentUser(ctx)
if err == nil && currentUser != nil {
currentUserRole = currentUser.Role
}
for _, identityProvider := range identityProviders {
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole))
}
return response, nil
}
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
// Default to lowest-privilege role, update later based on real role
currentUserRole := store.RoleUser
currentUser, err := s.fetchCurrentUser(ctx)
if err == nil && currentUser != nil {
currentUserRole = currentUser.Role
}
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil
}
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
update := &store.UpdateIdentityProviderV1{
ID: id,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
}
for _, field := range request.UpdateMask.Paths {
switch field {
case "title":
update.Name = &request.IdentityProvider.Title
case "identifier_filter":
update.IdentifierFilter = &request.IdentityProvider.IdentifierFilter
case "config":
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
default:
// Ignore unsupported fields
}
}
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
// Check if the identity provider exists before trying to delete it
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
}
return &emptypb.Empty{}, nil
}
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
temp := &v1pb.IdentityProvider{
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
Title: identityProvider.Name,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
}
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2Config := identityProvider.Config.GetOauth2Config()
temp.Config = &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &v1pb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
},
},
},
}
}
return temp
}
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
temp := &storepb.IdentityProvider{
Id: id,
Name: identityProvider.Title,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
}
return temp
}
func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProvider_Type, config *v1pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
if identityProviderType == v1pb.IdentityProvider_OAUTH2 {
oauth2Config := config.GetOauth2Config()
return &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &storepb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
},
},
},
}
}
return nil
}
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
if userRole != store.RoleAdmin {
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
identityProvider.Config.GetOauth2Config().ClientSecret = ""
}
}
return identityProvider
}

View File

@@ -0,0 +1,285 @@
package v1
import (
"context"
"fmt"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// GetInstanceProfile returns the instance profile.
func (s *APIV1Service) GetInstanceProfile(ctx context.Context, _ *v1pb.GetInstanceProfileRequest) (*v1pb.InstanceProfile, error) {
admin, err := s.GetInstanceAdmin(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance admin: %v", err)
}
instanceProfile := &v1pb.InstanceProfile{
Version: s.Profile.Version,
Demo: s.Profile.Demo,
InstanceUrl: s.Profile.InstanceURL,
Admin: admin, // nil when not initialized
}
return instanceProfile, nil
}
func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.GetInstanceSettingRequest) (*v1pb.InstanceSetting, error) {
instanceSettingKeyString, err := ExtractInstanceSettingKeyFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid instance setting name: %v", err)
}
instanceSettingKey := storepb.InstanceSettingKey(storepb.InstanceSettingKey_value[instanceSettingKeyString])
// Get instance setting from store with default value.
switch instanceSettingKey {
case storepb.InstanceSettingKey_BASIC:
_, err = s.Store.GetInstanceBasicSetting(ctx)
case storepb.InstanceSettingKey_GENERAL:
_, err = s.Store.GetInstanceGeneralSetting(ctx)
case storepb.InstanceSettingKey_MEMO_RELATED:
_, err = s.Store.GetInstanceMemoRelatedSetting(ctx)
case storepb.InstanceSettingKey_STORAGE:
_, err = s.Store.GetInstanceStorageSetting(ctx)
default:
return nil, status.Errorf(codes.InvalidArgument, "unsupported instance setting key: %v", instanceSettingKey)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance setting: %v", err)
}
instanceSetting, err := s.Store.GetInstanceSetting(ctx, &store.FindInstanceSetting{
Name: instanceSettingKey.String(),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance setting: %v", err)
}
if instanceSetting == nil {
return nil, status.Errorf(codes.NotFound, "instance setting not found")
}
// For storage setting, only admin can get it.
if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if user.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
return convertInstanceSettingFromStore(instanceSetting), nil
}
func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb.UpdateInstanceSettingRequest) (*v1pb.InstanceSetting, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if user.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// TODO: Apply update_mask if specified
_ = request.UpdateMask
updateSetting := convertInstanceSettingToStore(request.Setting)
instanceSetting, err := s.Store.UpsertInstanceSetting(ctx, updateSetting)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert instance setting: %v", err)
}
return convertInstanceSettingFromStore(instanceSetting), nil
}
func convertInstanceSettingFromStore(setting *storepb.InstanceSetting) *v1pb.InstanceSetting {
instanceSetting := &v1pb.InstanceSetting{
Name: fmt.Sprintf("instance/settings/%s", setting.Key.String()),
}
switch setting.Value.(type) {
case *storepb.InstanceSetting_GeneralSetting:
instanceSetting.Value = &v1pb.InstanceSetting_GeneralSetting_{
GeneralSetting: convertInstanceGeneralSettingFromStore(setting.GetGeneralSetting()),
}
case *storepb.InstanceSetting_StorageSetting:
instanceSetting.Value = &v1pb.InstanceSetting_StorageSetting_{
StorageSetting: convertInstanceStorageSettingFromStore(setting.GetStorageSetting()),
}
case *storepb.InstanceSetting_MemoRelatedSetting:
instanceSetting.Value = &v1pb.InstanceSetting_MemoRelatedSetting_{
MemoRelatedSetting: convertInstanceMemoRelatedSettingFromStore(setting.GetMemoRelatedSetting()),
}
}
return instanceSetting
}
func convertInstanceSettingToStore(setting *v1pb.InstanceSetting) *storepb.InstanceSetting {
settingKeyString, _ := ExtractInstanceSettingKeyFromName(setting.Name)
instanceSetting := &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey(storepb.InstanceSettingKey_value[settingKeyString]),
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: convertInstanceGeneralSettingToStore(setting.GetGeneralSetting()),
},
}
switch instanceSetting.Key {
case storepb.InstanceSettingKey_GENERAL:
instanceSetting.Value = &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: convertInstanceGeneralSettingToStore(setting.GetGeneralSetting()),
}
case storepb.InstanceSettingKey_STORAGE:
instanceSetting.Value = &storepb.InstanceSetting_StorageSetting{
StorageSetting: convertInstanceStorageSettingToStore(setting.GetStorageSetting()),
}
case storepb.InstanceSettingKey_MEMO_RELATED:
instanceSetting.Value = &storepb.InstanceSetting_MemoRelatedSetting{
MemoRelatedSetting: convertInstanceMemoRelatedSettingToStore(setting.GetMemoRelatedSetting()),
}
default:
// Keep the default GeneralSetting value
}
return instanceSetting
}
func convertInstanceGeneralSettingFromStore(setting *storepb.InstanceGeneralSetting) *v1pb.InstanceSetting_GeneralSetting {
if setting == nil {
return nil
}
generalSetting := &v1pb.InstanceSetting_GeneralSetting{
DisallowUserRegistration: setting.DisallowUserRegistration,
DisallowPasswordAuth: setting.DisallowPasswordAuth,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
WeekStartDayOffset: setting.WeekStartDayOffset,
DisallowChangeUsername: setting.DisallowChangeUsername,
DisallowChangeNickname: setting.DisallowChangeNickname,
}
if setting.CustomProfile != nil {
generalSetting.CustomProfile = &v1pb.InstanceSetting_GeneralSetting_CustomProfile{
Title: setting.CustomProfile.Title,
Description: setting.CustomProfile.Description,
LogoUrl: setting.CustomProfile.LogoUrl,
}
}
return generalSetting
}
func convertInstanceGeneralSettingToStore(setting *v1pb.InstanceSetting_GeneralSetting) *storepb.InstanceGeneralSetting {
if setting == nil {
return nil
}
generalSetting := &storepb.InstanceGeneralSetting{
DisallowUserRegistration: setting.DisallowUserRegistration,
DisallowPasswordAuth: setting.DisallowPasswordAuth,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
WeekStartDayOffset: setting.WeekStartDayOffset,
DisallowChangeUsername: setting.DisallowChangeUsername,
DisallowChangeNickname: setting.DisallowChangeNickname,
}
if setting.CustomProfile != nil {
generalSetting.CustomProfile = &storepb.InstanceCustomProfile{
Title: setting.CustomProfile.Title,
Description: setting.CustomProfile.Description,
LogoUrl: setting.CustomProfile.LogoUrl,
}
}
return generalSetting
}
func convertInstanceStorageSettingFromStore(settingpb *storepb.InstanceStorageSetting) *v1pb.InstanceSetting_StorageSetting {
if settingpb == nil {
return nil
}
setting := &v1pb.InstanceSetting_StorageSetting{
StorageType: v1pb.InstanceSetting_StorageSetting_StorageType(settingpb.StorageType),
FilepathTemplate: settingpb.FilepathTemplate,
UploadSizeLimitMb: settingpb.UploadSizeLimitMb,
}
if settingpb.S3Config != nil {
setting.S3Config = &v1pb.InstanceSetting_StorageSetting_S3Config{
AccessKeyId: settingpb.S3Config.AccessKeyId,
AccessKeySecret: settingpb.S3Config.AccessKeySecret,
Endpoint: settingpb.S3Config.Endpoint,
Region: settingpb.S3Config.Region,
Bucket: settingpb.S3Config.Bucket,
UsePathStyle: settingpb.S3Config.UsePathStyle,
}
}
return setting
}
func convertInstanceStorageSettingToStore(setting *v1pb.InstanceSetting_StorageSetting) *storepb.InstanceStorageSetting {
if setting == nil {
return nil
}
settingpb := &storepb.InstanceStorageSetting{
StorageType: storepb.InstanceStorageSetting_StorageType(setting.StorageType),
FilepathTemplate: setting.FilepathTemplate,
UploadSizeLimitMb: setting.UploadSizeLimitMb,
}
if setting.S3Config != nil {
settingpb.S3Config = &storepb.StorageS3Config{
AccessKeyId: setting.S3Config.AccessKeyId,
AccessKeySecret: setting.S3Config.AccessKeySecret,
Endpoint: setting.S3Config.Endpoint,
Region: setting.S3Config.Region,
Bucket: setting.S3Config.Bucket,
UsePathStyle: setting.S3Config.UsePathStyle,
}
}
return settingpb
}
func convertInstanceMemoRelatedSettingFromStore(setting *storepb.InstanceMemoRelatedSetting) *v1pb.InstanceSetting_MemoRelatedSetting {
if setting == nil {
return nil
}
return &v1pb.InstanceSetting_MemoRelatedSetting{
DisallowPublicVisibility: setting.DisallowPublicVisibility,
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
ContentLengthLimit: setting.ContentLengthLimit,
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
Reactions: setting.Reactions,
}
}
func convertInstanceMemoRelatedSettingToStore(setting *v1pb.InstanceSetting_MemoRelatedSetting) *storepb.InstanceMemoRelatedSetting {
if setting == nil {
return nil
}
return &storepb.InstanceMemoRelatedSetting{
DisallowPublicVisibility: setting.DisallowPublicVisibility,
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
ContentLengthLimit: setting.ContentLengthLimit,
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
Reactions: setting.Reactions,
}
}
func (s *APIV1Service) GetInstanceAdmin(ctx context.Context) (*v1pb.User, error) {
adminUserType := store.RoleAdmin
user, err := s.Store.GetUser(ctx, &store.FindUser{
Role: &adminUserType,
})
if err != nil {
return nil, errors.Wrapf(err, "failed to find admin")
}
if user == nil {
return nil, nil
}
return convertUserFromStore(user), nil
}

View File

@@ -0,0 +1,136 @@
package v1
import (
"context"
"slices"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
// Delete attachments that are not in the request.
for _, attachment := range attachments {
found := false
for _, requestAttachment := range request.Attachments {
requestAttachmentUID, err := ExtractAttachmentUIDFromName(requestAttachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
}
if attachment.UID == requestAttachmentUID {
found = true
break
}
}
if !found {
if err = s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
ID: int32(attachment.ID),
MemoID: &memo.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
}
}
}
slices.Reverse(request.Attachments)
// Update attachments' memo_id in the request.
for index, attachment := range request.Attachments {
attachmentUID, err := ExtractAttachmentUIDFromName(attachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
}
tempAttachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if tempAttachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID)
}
updatedTs := time.Now().Unix() + int64(index)
if err := s.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
ID: tempAttachment.ID,
MemoID: &memo.ID,
UpdatedTs: &updatedTs,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.ListMemoAttachmentsRequest) (*v1pb.ListMemoAttachmentsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
}
response := &v1pb.ListMemoAttachmentsResponse{
Attachments: []*v1pb.Attachment{},
}
for _, attachment := range attachments {
response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment))
}
return response, nil
}

View File

@@ -0,0 +1,181 @@
package v1
import (
"context"
"fmt"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
referenceType := store.MemoRelationReference
// Delete all reference relations first.
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &memo.ID,
Type: &referenceType,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
}
for _, relation := range request.Relations {
// Ignore reflexive relations.
if request.Name == relation.RelatedMemo.Name {
continue
}
// Ignore comment relations as there's no need to update a comment's relation.
// Inserting/Deleting a comment is handled elsewhere.
if relation.Type == v1pb.MemoRelation_COMMENT {
continue
}
relatedMemoUID, err := ExtractMemoUIDFromName(relation.RelatedMemo.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &relatedMemoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get related memo")
}
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: relatedMemo.ID,
Type: convertMemoRelationTypeToStore(relation.Type),
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
var memoFilter string
if currentUser == nil {
memoFilter = `visibility == "PUBLIC"`
} else {
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
}
relationList := []*v1pb.MemoRelation{}
tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo.ID,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations: %v", err)
}
for _, raw := range tempList {
relation, err := s.convertMemoRelationFromStore(ctx, raw)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
}
relationList = append(relationList, relation)
}
tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &memo.ID,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list related memo relations: %v", err)
}
for _, raw := range tempList {
relation, err := s.convertMemoRelationFromStore(ctx, raw)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
}
relationList = append(relationList, relation)
}
response := &v1pb.ListMemoRelationsResponse{
Relations: relationList,
}
return response, nil
}
func (s *APIV1Service) convertMemoRelationFromStore(ctx context.Context, memoRelation *store.MemoRelation) (*v1pb.MemoRelation, error) {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.MemoID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
memoSnippet, err := s.getMemoContentSnippet(memo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get memo content snippet")
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.RelatedMemoID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get related memo: %v", err)
}
relatedMemoSnippet, err := s.getMemoContentSnippet(relatedMemo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get related memo content snippet")
}
return &v1pb.MemoRelation{
Memo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Snippet: memoSnippet,
},
RelatedMemo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
Snippet: relatedMemoSnippet,
},
Type: convertMemoRelationTypeFromStore(memoRelation.Type),
}, nil
}
func convertMemoRelationTypeFromStore(relationType store.MemoRelationType) v1pb.MemoRelation_Type {
switch relationType {
case store.MemoRelationReference:
return v1pb.MemoRelation_REFERENCE
case store.MemoRelationComment:
return v1pb.MemoRelation_COMMENT
default:
return v1pb.MemoRelation_TYPE_UNSPECIFIED
}
}
func convertMemoRelationTypeToStore(relationType v1pb.MemoRelation_Type) store.MemoRelationType {
switch relationType {
case v1pb.MemoRelation_COMMENT:
return store.MemoRelationComment
default:
return store.MemoRelationReference
}
}

View File

@@ -0,0 +1,866 @@
package v1
import (
"context"
"fmt"
"log/slog"
"strings"
"time"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/base"
"github.com/usememos/memos/plugin/webhook"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/runner/memopayload"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Use custom memo_id if provided, otherwise generate a new UUID
memoUID := strings.TrimSpace(request.MemoId)
if memoUID == "" {
memoUID = shortuuid.New()
} else if !base.UIDMatcher.MatchString(memoUID) {
// Validate custom memo ID format
return nil, status.Errorf(codes.InvalidArgument, "invalid memo_id format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
}
create := &store.Memo{
UID: memoUID,
CreatorID: user.ID,
Content: request.Memo.Content,
Visibility: convertVisibilityToStore(request.Memo.Visibility),
}
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
// Handle display_time first: if provided, use it to set the appropriate timestamp
// based on the instance setting (similar to UpdateMemo logic)
// Note: explicit create_time/update_time below will override this if provided
if request.Memo.DisplayTime != nil && request.Memo.DisplayTime.IsValid() {
displayTs := request.Memo.DisplayTime.AsTime().Unix()
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
create.UpdatedTs = displayTs
} else {
create.CreatedTs = displayTs
}
}
// Set custom timestamps if provided in the request
// These take precedence over display_time
if request.Memo.CreateTime != nil && request.Memo.CreateTime.IsValid() {
createdTs := request.Memo.CreateTime.AsTime().Unix()
create.CreatedTs = createdTs
}
if request.Memo.UpdateTime != nil && request.Memo.UpdateTime.IsValid() {
updatedTs := request.Memo.UpdateTime.AsTime().Unix()
create.UpdatedTs = updatedTs
}
if instanceMemoRelatedSetting.DisallowPublicVisibility && create.Visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
}
if len(create.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
if err := memopayload.RebuildMemoPayload(create, s.MarkdownService); err != nil {
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
}
if request.Memo.Location != nil {
create.Payload.Location = convertLocationToStore(request.Memo.Location)
}
memo, err := s.Store.CreateMemo(ctx, create)
if err != nil {
// Check for unique constraint violation (AIP-133 compliance)
errMsg := err.Error()
if strings.Contains(errMsg, "UNIQUE constraint failed") ||
strings.Contains(errMsg, "duplicate key") ||
strings.Contains(errMsg, "Duplicate entry") {
return nil, status.Errorf(codes.AlreadyExists, "memo with ID %q already exists", memoUID)
}
return nil, err
}
attachments := []*store.Attachment{}
if len(request.Memo.Attachments) > 0 {
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Attachments: request.Memo.Attachments,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo attachments")
}
a, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo attachments")
}
attachments = a
}
if len(request.Memo.Relations) > 0 {
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Relations: request.Memo.Relations,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo relations")
}
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is created.
if err := s.DispatchMemoCreatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
}
return memoMessage, nil
}
func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosRequest) (*v1pb.ListMemosResponse, error) {
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
}
if request.State == v1pb.State_ARCHIVED {
state := store.Archived
memoFind.RowStatus = &state
} else {
state := store.Normal
memoFind.RowStatus = &state
}
// Parse order_by field (replaces the old sort and direction fields)
if request.OrderBy != "" {
if err := s.parseMemoOrderBy(request.OrderBy, memoFind); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid order_by: %v", err)
}
} else {
// Default ordering by display_time desc
memoFind.OrderByTimeAsc = false
}
if request.Filter != "" {
if err := s.validateFilter(ctx, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
memoFind.Filters = append(memoFind.Filters, request.Filter)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else {
if memoFind.CreatorID == nil {
filter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
memoFind.Filters = append(memoFind.Filters, filter)
} else if *memoFind.CreatorID != currentUser.ID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
}
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
memoFind.OrderByUpdatedTs = true
}
var limit, offset int
if request.PageToken != "" {
var pageToken v1pb.PageToken
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = int(pageToken.Limit)
offset = int(pageToken.Offset)
} else {
limit = int(request.PageSize)
}
if limit <= 0 {
limit = DefaultPageSize
}
limitPlusOne := limit + 1
memoFind.Limit = &limitPlusOne
memoFind.Offset = &offset
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
memoMessages := []*v1pb.Memo{}
nextPageToken := ""
if len(memos) == limitPlusOne {
memos = memos[:limit]
nextPageToken, err = getPageToken(limit, offset+limit)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
}
}
if len(memos) == 0 {
response := &v1pb.ListMemosResponse{
Memos: memoMessages,
NextPageToken: nextPageToken,
}
return response, nil
}
reactionMap := make(map[string][]*store.Reaction)
contentIDs := make([]string, 0, len(memos))
attachmentMap := make(map[int32][]*store.Attachment)
memoIDs := make([]int32, 0, len(memos))
for _, m := range memos {
contentIDs = append(contentIDs, fmt.Sprintf("%s%s", MemoNamePrefix, m.UID))
memoIDs = append(memoIDs, m.ID)
}
// REACTIONS
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
for _, reaction := range reactions {
reactionMap[reaction.ContentID] = append(reactionMap[reaction.ContentID], reaction)
}
// ATTACHMENTS
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
for _, attachment := range attachments {
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
}
for _, memo := range memos {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
reactions := reactionMap[memoName]
attachments := attachmentMap[memo.ID]
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memoMessages = append(memoMessages, memoMessage)
}
response := &v1pb.ListMemosResponse{
Memos: memoMessages,
NextPageToken: nextPageToken,
}
return response, nil
}
func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
UID: &memoUID,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
return memoMessage, nil
}
func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Memo.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Only the creator or admin can update the memo.
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
update := &store.UpdateMemo{
ID: memo.ID,
}
for _, path := range request.UpdateMask.Paths {
if path == "content" {
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
}
if len(request.Memo.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
memo.Content = request.Memo.Content
if err := memopayload.RebuildMemoPayload(memo, s.MarkdownService); err != nil {
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
}
update.Content = &memo.Content
update.Payload = memo.Payload
} else if path == "visibility" {
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
visibility := convertVisibilityToStore(request.Memo.Visibility)
if instanceMemoRelatedSetting.DisallowPublicVisibility && visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
update.Visibility = &visibility
} else if path == "pinned" {
update.Pinned = &request.Memo.Pinned
} else if path == "state" {
rowStatus := convertStateToStore(request.Memo.State)
update.RowStatus = &rowStatus
} else if path == "create_time" {
createdTs := request.Memo.CreateTime.AsTime().Unix()
update.CreatedTs = &createdTs
} else if path == "update_time" {
updatedTs := time.Now().Unix()
if request.Memo.UpdateTime != nil {
updatedTs = request.Memo.UpdateTime.AsTime().Unix()
}
update.UpdatedTs = &updatedTs
} else if path == "display_time" {
displayTs := request.Memo.DisplayTime.AsTime().Unix()
memoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
if memoRelatedSetting.DisplayWithUpdateTime {
update.UpdatedTs = &displayTs
} else {
update.CreatedTs = &displayTs
}
} else if path == "location" {
payload := memo.Payload
payload.Location = convertLocationToStore(request.Memo.Location)
update.Payload = payload
} else if path == "attachments" {
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: request.Memo.Name,
Attachments: request.Memo.Attachments,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo attachments")
}
} else if path == "relations" {
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: request.Memo.Name,
Relations: request.Memo.Relations,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo relations")
}
}
}
if err = s.Store.UpdateMemo(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
}
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
ID: &memo.ID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Memo.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is updated.
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
}
return memoMessage, nil
}
func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoRequest) (*emptypb.Empty, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
UID: &memoUID,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Only the creator or admin can update the memo.
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments); err == nil {
// Try to dispatch webhook when memo is deleted.
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
}
}
// Delete memo comments first (store.DeleteMemo handles their relations and attachments)
commentType := store.MemoRelationComment
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{RelatedMemoID: &memo.ID, Type: &commentType})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo comments")
}
for _, relation := range relations {
if err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: relation.MemoID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo comment")
}
}
// Delete the memo (store.DeleteMemo handles relation and attachment cleanup)
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.CreateMemoCommentRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if relatedMemo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility before allowing comment.
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if relatedMemo.Visibility == store.Private && relatedMemo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Create the memo comment first.
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{
Memo: request.Comment,
MemoId: request.CommentId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo")
}
memoUID, err = ExtractMemoUIDFromName(memoComment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
// Build the relation between the comment memo and the original memo.
_, err = s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationComment,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
}
creatorID, err := ExtractUserIDFromName(memoComment.Creator)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo creator")
}
if memoComment.Visibility != v1pb.Visibility_PRIVATE && creatorID != relatedMemo.CreatorID {
activity, err := s.Store.CreateActivity(ctx, &store.Activity{
CreatorID: creatorID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memo.ID,
RelatedMemoId: relatedMemo.ID,
},
},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create activity")
}
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
SenderID: creatorID,
ReceiverID: relatedMemo.CreatorID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
ActivityId: &activity.ID,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to create inbox")
}
}
if err := s.DispatchMemoCommentCreatedWebhook(ctx, memoComment, relatedMemo.CreatorID); err != nil {
slog.Warn("Failed to dispatch memo comment created webhook", slog.Any("err", err))
}
return memoComment, nil
}
func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListMemoCommentsRequest) (*v1pb.ListMemoCommentsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
var memoFilter string
if currentUser == nil {
memoFilter = `visibility == "PUBLIC"`
} else {
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
}
memoRelationComment := store.MemoRelationComment
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &memo.ID,
Type: &memoRelationComment,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
}
if len(memoRelations) == 0 {
response := &v1pb.ListMemoCommentsResponse{
Memos: []*v1pb.Memo{},
}
return response, nil
}
memoRelationIDs := make([]int32, 0, len(memoRelations))
for _, m := range memoRelations {
memoRelationIDs = append(memoRelationIDs, m.MemoID)
}
memos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: memoRelationIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
memoIDToNameMap := make(map[int32]string)
contentIDs := make([]string, 0, len(memos))
memoIDsForAttachments := make([]int32, 0, len(memos))
for _, memo := range memos {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
memoIDToNameMap[memo.ID] = memoName
contentIDs = append(contentIDs, memoName)
memoIDsForAttachments = append(memoIDsForAttachments, memo.ID)
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
memoReactionsMap := make(map[string][]*store.Reaction)
for _, reaction := range reactions {
memoReactionsMap[reaction.ContentID] = append(memoReactionsMap[reaction.ContentID], reaction)
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDsForAttachments})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
attachmentMap := make(map[int32][]*store.Attachment)
for _, attachment := range attachments {
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
}
var memosResponse []*v1pb.Memo
for _, m := range memos {
memoName := memoIDToNameMap[m.ID]
reactions := memoReactionsMap[memoName]
attachments := attachmentMap[m.ID]
memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memosResponse = append(memosResponse, memoMessage)
}
response := &v1pb.ListMemoCommentsResponse{
Memos: memosResponse,
}
return response, nil
}
func (s *APIV1Service) getContentLengthLimit(ctx context.Context) (int, error) {
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return 0, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
return int(instanceMemoRelatedSetting.ContentLengthLimit), nil
}
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
func (s *APIV1Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")
}
// DispatchMemoUpdatedWebhook dispatches webhook when memo is updated.
func (s *APIV1Service) DispatchMemoUpdatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.updated")
}
// DispatchMemoDeletedWebhook dispatches webhook when memo is deleted.
func (s *APIV1Service) DispatchMemoDeletedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.deleted")
}
// DispatchMemoCommentCreatedWebhook dispatches webhook to the related memo owner when a comment is created.
func (s *APIV1Service) DispatchMemoCommentCreatedWebhook(ctx context.Context, commentMemo *v1pb.Memo, relatedMemoCreatorID int32) error {
webhooks, err := s.Store.GetUserWebhooks(ctx, relatedMemoCreatorID)
if err != nil {
return err
}
for _, hook := range webhooks {
payload, err := convertMemoToWebhookPayload(commentMemo)
if err != nil {
return errors.Wrap(err, "failed to convert memo to webhook payload")
}
payload.ActivityType = "memos.memo.comment.created"
payload.URL = hook.Url
webhook.PostAsync(payload)
}
return nil
}
func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1pb.Memo, activityType string) error {
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid memo creator")
}
webhooks, err := s.Store.GetUserWebhooks(ctx, creatorID)
if err != nil {
return err
}
for _, hook := range webhooks {
payload, err := convertMemoToWebhookPayload(memo)
if err != nil {
return errors.Wrap(err, "failed to convert memo to webhook payload")
}
payload.ActivityType = activityType
payload.URL = hook.Url
// Use asynchronous webhook dispatch
webhook.PostAsync(payload)
}
return nil
}
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookRequestPayload, error) {
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return nil, errors.Wrap(err, "invalid memo creator")
}
return &webhook.WebhookRequestPayload{
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creatorID),
Memo: memo,
}, nil
}
func (s *APIV1Service) getMemoContentSnippet(content string) (string, error) {
// Use goldmark service for snippet generation
snippet, err := s.MarkdownService.GenerateSnippet([]byte(content), 64)
if err != nil {
return "", errors.Wrap(err, "failed to generate snippet")
}
return snippet, nil
}
// parseMemoOrderBy parses the order_by field and sets the appropriate ordering in memoFind.
// Follows AIP-132: supports comma-separated list of fields with optional "desc" suffix.
// Example: "pinned desc, display_time desc" or "create_time asc".
func (*APIV1Service) parseMemoOrderBy(orderBy string, memoFind *store.FindMemo) error {
if strings.TrimSpace(orderBy) == "" {
return errors.New("empty order_by")
}
// Split by comma to support multiple sort fields per AIP-132.
fields := strings.Split(orderBy, ",")
// Track if we've seen pinned field.
hasPinned := false
for _, field := range fields {
parts := strings.Fields(strings.TrimSpace(field))
if len(parts) == 0 {
continue
}
fieldName := parts[0]
fieldDirection := "desc" // default per AIP-132 (we use desc as default for time fields)
if len(parts) > 1 {
fieldDirection = strings.ToLower(parts[1])
if fieldDirection != "asc" && fieldDirection != "desc" {
return errors.Errorf("invalid order direction: %s, must be 'asc' or 'desc'", parts[1])
}
}
switch fieldName {
case "pinned":
hasPinned = true
memoFind.OrderByPinned = true
// Note: pinned is always DESC (true first) regardless of direction specified.
case "display_time", "create_time", "name":
// Only set if this is the first time field we encounter.
if !memoFind.OrderByUpdatedTs {
memoFind.OrderByTimeAsc = fieldDirection == "asc"
}
case "update_time":
memoFind.OrderByUpdatedTs = true
memoFind.OrderByTimeAsc = fieldDirection == "asc"
default:
return errors.Errorf("unsupported order field: %s, supported fields are: pinned, display_time, create_time, update_time, name", fieldName)
}
}
// If only pinned was specified, still need to set a default time ordering.
if hasPinned && !memoFind.OrderByUpdatedTs && len(fields) == 1 {
memoFind.OrderByTimeAsc = false // default to desc
}
return nil
}

View File

@@ -0,0 +1,134 @@
package v1
import (
"context"
"fmt"
"time"
"github.com/pkg/errors"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get instance memo related setting")
}
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
name := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
memoMessage := &v1pb.Memo{
Name: name,
State: convertStateFromStore(memo.RowStatus),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, memo.CreatorID),
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
Content: memo.Content,
Visibility: convertVisibilityFromStore(memo.Visibility),
Pinned: memo.Pinned,
}
if memo.Payload != nil {
memoMessage.Tags = memo.Payload.Tags
memoMessage.Property = convertMemoPropertyFromStore(memo.Payload.Property)
memoMessage.Location = convertLocationFromStore(memo.Payload.Location)
}
if memo.ParentUID != nil {
parentName := fmt.Sprintf("%s%s", MemoNamePrefix, *memo.ParentUID)
memoMessage.Parent = &parentName
}
memoMessage.Reactions = []*v1pb.Reaction{}
for _, reaction := range reactions {
reactionResponse := convertReactionFromStore(reaction)
memoMessage.Reactions = append(memoMessage.Reactions, reactionResponse)
}
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo relations")
}
memoMessage.Relations = listMemoRelationsResponse.Relations
memoMessage.Attachments = []*v1pb.Attachment{}
for _, attachment := range attachments {
attachmentResponse := convertAttachmentFromStore(attachment)
memoMessage.Attachments = append(memoMessage.Attachments, attachmentResponse)
}
snippet, err := s.getMemoContentSnippet(memo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get memo content snippet")
}
memoMessage.Snippet = snippet
return memoMessage, nil
}
func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property {
if property == nil {
return nil
}
return &v1pb.Memo_Property{
HasLink: property.HasLink,
HasTaskList: property.HasTaskList,
HasCode: property.HasCode,
HasIncompleteTasks: property.HasIncompleteTasks,
}
}
func convertLocationFromStore(location *storepb.MemoPayload_Location) *v1pb.Location {
if location == nil {
return nil
}
return &v1pb.Location{
Placeholder: location.Placeholder,
Latitude: location.Latitude,
Longitude: location.Longitude,
}
}
func convertLocationToStore(location *v1pb.Location) *storepb.MemoPayload_Location {
if location == nil {
return nil
}
return &storepb.MemoPayload_Location{
Placeholder: location.Placeholder,
Latitude: location.Latitude,
Longitude: location.Longitude,
}
}
func convertVisibilityFromStore(visibility store.Visibility) v1pb.Visibility {
switch visibility {
case store.Private:
return v1pb.Visibility_PRIVATE
case store.Protected:
return v1pb.Visibility_PROTECTED
case store.Public:
return v1pb.Visibility_PUBLIC
default:
return v1pb.Visibility_VISIBILITY_UNSPECIFIED
}
}
func convertVisibilityToStore(visibility v1pb.Visibility) store.Visibility {
switch visibility {
case v1pb.Visibility_PROTECTED:
return store.Protected
case v1pb.Visibility_PUBLIC:
return store.Public
default:
return store.Private
}
}

View File

@@ -0,0 +1 @@
package v1

View File

@@ -0,0 +1,153 @@
package v1
import (
"context"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
// Extract memo UID and check visibility.
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
response := &v1pb.ListMemoReactionsResponse{
Reactions: []*v1pb.Reaction{},
}
for _, reaction := range reactions {
reactionMessage := convertReactionFromStore(reaction)
response.Reactions = append(response.Reactions, reactionMessage)
}
return response, nil
}
func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.UpsertMemoReactionRequest) (*v1pb.Reaction, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Extract memo UID and check visibility before allowing reaction.
memoUID, err := ExtractMemoUIDFromName(request.Reaction.ContentId)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: request.Reaction.ContentId,
ReactionType: request.Reaction.ReactionType,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
}
reactionMessage := convertReactionFromStore(reaction)
return reactionMessage, nil
}
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
_, reactionID, err := ExtractMemoReactionIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
}
// Get reaction and check ownership.
reaction, err := s.Store.GetReaction(ctx, &store.FindReaction{
ID: &reactionID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get reaction")
}
if reaction == nil {
// Return permission denied to avoid revealing if reaction exists.
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if reaction.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
ID: reactionID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
}
return &emptypb.Empty{}, nil
}
func convertReactionFromStore(reaction *store.Reaction) *v1pb.Reaction {
reactionUID := fmt.Sprintf("%d", reaction.ID)
// Generate nested resource name: memos/{memo}/reactions/{reaction}
// reaction.ContentID already contains "memos/{memo}"
return &v1pb.Reaction{
Name: fmt.Sprintf("%s/%s%s", reaction.ContentID, ReactionNamePrefix, reactionUID),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, reaction.CreatorID),
ContentId: reaction.ContentID,
ReactionType: reaction.ReactionType,
CreateTime: timestamppb.New(time.Unix(reaction.CreatedTs, 0)),
}
}

View File

@@ -0,0 +1,158 @@
package v1
import (
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
const (
InstanceSettingNamePrefix = "instance/settings/"
UserNamePrefix = "users/"
MemoNamePrefix = "memos/"
AttachmentNamePrefix = "attachments/"
ReactionNamePrefix = "reactions/"
InboxNamePrefix = "inboxes/"
IdentityProviderNamePrefix = "identity-providers/"
ActivityNamePrefix = "activities/"
WebhookNamePrefix = "webhooks/"
)
// GetNameParentTokens returns the tokens from a resource name.
func GetNameParentTokens(name string, tokenPrefixes ...string) ([]string, error) {
parts := strings.Split(name, "/")
if len(parts) != 2*len(tokenPrefixes) {
return nil, errors.Errorf("invalid request %q", name)
}
var tokens []string
for i, tokenPrefix := range tokenPrefixes {
if fmt.Sprintf("%s/", parts[2*i]) != tokenPrefix {
return nil, errors.Errorf("invalid prefix %q in request %q", tokenPrefix, name)
}
if parts[2*i+1] == "" {
return nil, errors.Errorf("invalid request %q with empty prefix %q", name, tokenPrefix)
}
tokens = append(tokens, parts[2*i+1])
}
return tokens, nil
}
func ExtractInstanceSettingKeyFromName(name string) (string, error) {
const prefix = "instance/settings/"
if !strings.HasPrefix(name, prefix) {
return "", errors.Errorf("invalid instance setting name: expected prefix %q, got %q", prefix, name)
}
settingKey := strings.TrimPrefix(name, prefix)
if settingKey == "" {
return "", errors.Errorf("invalid instance setting name: empty setting key in %q", name)
}
// Ensure there are no additional path segments
if strings.Contains(settingKey, "/") {
return "", errors.Errorf("invalid instance setting name: setting key cannot contain '/' in %q", name)
}
return settingKey, nil
}
// ExtractUserIDFromName returns the uid from a resource name.
func ExtractUserIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, UserNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid user ID %q", tokens[0])
}
return id, nil
}
// extractUserIdentifierFromName extracts the identifier (ID or username) from a user resource name.
// Supports: "users/101" or "users/steven"
// Returns the identifier string (e.g., "101" or "steven").
func extractUserIdentifierFromName(name string) string {
tokens, err := GetNameParentTokens(name, UserNamePrefix)
if err != nil || len(tokens) == 0 {
return ""
}
return tokens[0]
}
// ExtractMemoUIDFromName returns the memo UID from a resource name.
// e.g., "memos/uuid" -> "uuid".
func ExtractMemoUIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, MemoNamePrefix)
if err != nil {
return "", err
}
id := tokens[0]
return id, nil
}
// ExtractAttachmentUIDFromName returns the attachment UID from a resource name.
func ExtractAttachmentUIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, AttachmentNamePrefix)
if err != nil {
return "", err
}
id := tokens[0]
return id, nil
}
// ExtractMemoReactionIDFromName returns the memo UID and reaction ID from a resource name.
// e.g., "memos/abc/reactions/123" -> ("abc", 123).
func ExtractMemoReactionIDFromName(name string) (string, int32, error) {
tokens, err := GetNameParentTokens(name, MemoNamePrefix, ReactionNamePrefix)
if err != nil {
return "", 0, err
}
memoUID := tokens[0]
reactionID, err := util.ConvertStringToInt32(tokens[1])
if err != nil {
return "", 0, errors.Errorf("invalid reaction ID %q", tokens[1])
}
return memoUID, reactionID, nil
}
// ExtractInboxIDFromName returns the inbox ID from a resource name.
func ExtractInboxIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, InboxNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid inbox ID %q", tokens[0])
}
return id, nil
}
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
}
return id, nil
}
func ExtractActivityIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, ActivityNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid activity ID %q", tokens[0])
}
return id, nil
}

View File

@@ -0,0 +1,346 @@
package v1
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/filter"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// Helper function to extract user ID and shortcut ID from shortcut resource name.
// Format: users/{user}/shortcuts/{shortcut}.
func extractUserAndShortcutIDFromName(name string) (int32, string, error) {
parts := strings.Split(name, "/")
if len(parts) != 4 || parts[0] != "users" || parts[2] != "shortcuts" {
return 0, "", errors.Errorf("invalid shortcut name format: %s", name)
}
userID, err := util.ConvertStringToInt32(parts[1])
if err != nil {
return 0, "", errors.Errorf("invalid user ID %q", parts[1])
}
shortcutID := parts[3]
if shortcutID == "" {
return 0, "", errors.Errorf("empty shortcut ID in name: %s", name)
}
return userID, shortcutID, nil
}
// Helper function to construct shortcut resource name.
func constructShortcutName(userID int32, shortcutID string) string {
return fmt.Sprintf("users/%d/shortcuts/%s", userID, shortcutID)
}
func (s *APIV1Service) ListShortcuts(ctx context.Context, request *v1pb.ListShortcutsRequest) (*v1pb.ListShortcutsResponse, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user setting: %v", err)
}
if userSetting == nil {
return &v1pb.ListShortcutsResponse{
Shortcuts: []*v1pb.Shortcut{},
}, nil
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := []*v1pb.Shortcut{}
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
shortcuts = append(shortcuts, &v1pb.Shortcut{
Name: constructShortcutName(userID, shortcut.GetId()),
Title: shortcut.GetTitle(),
Filter: shortcut.GetFilter(),
})
}
return &v1pb.ListShortcutsResponse{
Shortcuts: shortcuts,
}, nil
}
func (s *APIV1Service) GetShortcut(ctx context.Context, request *v1pb.GetShortcutRequest) (*v1pb.Shortcut, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
if shortcut.GetId() == shortcutID {
return &v1pb.Shortcut{
Name: constructShortcutName(userID, shortcut.GetId()),
Title: shortcut.GetTitle(),
Filter: shortcut.GetFilter(),
}, nil
}
}
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateShortcutRequest) (*v1pb.Shortcut, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
newShortcut := &storepb.ShortcutsUserSetting_Shortcut{
Id: util.GenUUID(),
Title: request.Shortcut.GetTitle(),
Filter: request.Shortcut.GetFilter(),
}
if newShortcut.Title == "" {
return nil, status.Errorf(codes.InvalidArgument, "title is required")
}
if err := s.validateFilter(ctx, newShortcut.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
if request.ValidateOnly {
return &v1pb.Shortcut{
Name: constructShortcutName(userID, newShortcut.GetId()),
Title: newShortcut.GetTitle(),
Filter: newShortcut.GetFilter(),
}, nil
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
userSetting = &storepb.UserSetting{
UserId: userID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{
Shortcuts: &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
},
},
}
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
shortcuts = append(shortcuts, newShortcut)
shortcutsUserSetting.Shortcuts = shortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &v1pb.Shortcut{
Name: constructShortcutName(userID, newShortcut.GetId()),
Title: newShortcut.GetTitle(),
Filter: newShortcut.GetFilter(),
}, nil
}
func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateShortcutRequest) (*v1pb.Shortcut, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Shortcut.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
var foundShortcut *storepb.ShortcutsUserSetting_Shortcut
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
for _, shortcut := range shortcuts {
if shortcut.GetId() == shortcutID {
foundShortcut = shortcut
for _, field := range request.UpdateMask.Paths {
if field == "title" {
if request.Shortcut.GetTitle() == "" {
return nil, status.Errorf(codes.InvalidArgument, "title is required")
}
shortcut.Title = request.Shortcut.GetTitle()
} else if field == "filter" {
if err := s.validateFilter(ctx, request.Shortcut.GetFilter()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
shortcut.Filter = request.Shortcut.GetFilter()
}
}
}
newShortcuts = append(newShortcuts, shortcut)
}
if foundShortcut == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting.Shortcuts = newShortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, err
}
return &v1pb.Shortcut{
Name: constructShortcutName(userID, foundShortcut.GetId()),
Title: foundShortcut.GetTitle(),
Filter: foundShortcut.GetFilter(),
}, nil
}
func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteShortcutRequest) (*emptypb.Empty, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
found := false
for _, shortcut := range shortcuts {
if shortcut.GetId() != shortcutID {
newShortcuts = append(newShortcuts, shortcut)
} else {
found = true
}
}
if !found {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting.Shortcuts = newShortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) validateFilter(ctx context.Context, filterStr string) error {
if filterStr == "" {
return errors.New("filter cannot be empty")
}
engine, err := filter.DefaultEngine()
if err != nil {
return err
}
var dialect filter.DialectName
switch s.Profile.Driver {
case "mysql":
dialect = filter.DialectMySQL
case "postgres":
dialect = filter.DialectPostgres
default:
dialect = filter.DialectSQLite
}
if _, err := engine.CompileToStatement(ctx, filterStr, filter.RenderOptions{Dialect: dialect}); err != nil {
return errors.Wrap(err, "failed to compile filter")
}
return nil
}

View File

@@ -0,0 +1,263 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
v1 "github.com/usememos/memos/server/router/api/v1" //nolint:revive
"github.com/usememos/memos/store"
)
// TestListActivitiesWithDeletedMemos verifies that ListActivities gracefully handles
// activities that reference deleted memos instead of crashing the entire request.
func TestListActivitiesWithDeletedMemos(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users - one to create memo, one to comment
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
require.NoError(t, err)
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
require.NoError(t, err)
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
// Create a memo by userOne
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Original memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo1)
// Create a comment on the memo by userTwo (this will create an activity for userOne)
comment, err := ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
Name: memo1.Name,
Comment: &apiv1.Memo{
Content: "This is a comment",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, comment)
// Verify activity was created for the comment (check from userOne's perspective - they receive the notification)
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
initialActivityCount := len(activities.Activities)
require.Greater(t, initialActivityCount, 0, "Should have at least one activity")
// Delete the original memo (this deletes the comment too)
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
Name: memo1.Name,
})
require.NoError(t, err)
// List activities again - should succeed even though the memo is deleted
activities, err = ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
// Activities list should be empty or not contain the deleted memo activity
for _, activity := range activities.Activities {
if activity.Payload != nil && activity.Payload.GetMemoComment() != nil {
require.NotEqual(t, memo1.Name, activity.Payload.GetMemoComment().Memo,
"Activity should not reference deleted memo")
}
}
// After deletion, there should be fewer activities
require.LessOrEqual(t, len(activities.Activities), initialActivityCount-1,
"Should have filtered out the activity for the deleted memo")
}
// TestGetActivityWithDeletedMemo verifies that GetActivity returns a proper error
// when trying to fetch an activity that references a deleted memo.
func TestGetActivityWithDeletedMemo(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
require.NoError(t, err)
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
require.NoError(t, err)
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
// Create a memo by userOne
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Original memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo1)
// Create a comment to trigger activity creation by userTwo
comment, err := ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
Name: memo1.Name,
Comment: &apiv1.Memo{
Content: "Comment",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, comment)
// Get the activity ID by listing activities from userOne's perspective
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
require.Greater(t, len(activities.Activities), 0)
activityName := activities.Activities[0].Name
// Delete the memo
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
Name: memo1.Name,
})
require.NoError(t, err)
// Try to get the specific activity - should return NotFound error
_, err = ts.Service.GetActivity(userOneCtx, &apiv1.GetActivityRequest{
Name: activityName,
})
require.Error(t, err)
require.Contains(t, err.Error(), "activity references deleted content")
}
// TestActivitiesWithPartiallyDeletedMemos verifies that when some memos are deleted,
// other valid activities are still returned.
func TestActivitiesWithPartiallyDeletedMemos(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
require.NoError(t, err)
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
require.NoError(t, err)
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
// Create two memos by userOne
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "First memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
memo2, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Second memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Create comments on both by userTwo (creates activities for userOne)
_, err = ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
Name: memo1.Name,
Comment: &apiv1.Memo{
Content: "Comment on first",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
_, err = ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
Name: memo2.Name,
Comment: &apiv1.Memo{
Content: "Comment on second",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Should have 2 activities from userOne's perspective
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
require.Equal(t, 2, len(activities.Activities))
// Delete first memo
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
Name: memo1.Name,
})
require.NoError(t, err)
// List activities - should still work and return only the second memo's activity
activities, err = ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
require.Equal(t, 1, len(activities.Activities), "Should have 1 activity remaining")
// Verify the remaining activity relates to a valid memo
require.NotNil(t, activities.Activities[0].Payload.GetMemoComment())
require.Contains(t, activities.Activities[0].Payload.GetMemoComment().RelatedMemo, "memos/")
}
// TestActivityStoreDirectDeletion tests the scenario where a memo is deleted directly
// from the store (simulating database-level deletion or migration).
func TestActivityStoreDirectDeletion(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "test-user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a memo
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Create a comment
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
Name: memo1.Name,
Comment: &apiv1.Memo{
Content: "Test comment",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Extract memo UID from the comment name
commentMemoUID, err := v1.ExtractMemoUIDFromName(comment.Name)
require.NoError(t, err)
commentMemo, err := ts.Store.GetMemo(ctx, &store.FindMemo{
UID: &commentMemoUID,
})
require.NoError(t, err)
require.NotNil(t, commentMemo)
// Delete the comment memo directly from store (simulating orphaned activity)
err = ts.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: commentMemo.ID})
require.NoError(t, err)
// List activities should still succeed even with orphaned activity
activities, err := ts.Service.ListActivities(userCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
// Activities should be empty or not include the orphaned one
for _, activity := range activities.Activities {
if activity.Payload != nil && activity.Payload.GetMemoComment() != nil {
require.NotEqual(t, comment.Name, activity.Payload.GetMemoComment().Memo,
"Should not return activity with deleted memo")
}
}
}

View File

@@ -0,0 +1 @@
fake png content

View File

@@ -0,0 +1 @@
fake png content

View File

@@ -0,0 +1,2 @@
‰PNG


After

Width:  |  Height:  |  Size: 8 B

Binary file not shown.

View File

@@ -0,0 +1,59 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestCreateAttachment(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
ctx := context.Background()
user, err := ts.CreateRegularUser(ctx, "test_user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test case 1: Create attachment with empty type but known extension
t.Run("EmptyType_KnownExtension", func(t *testing.T) {
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "test.png",
Content: []byte("fake png content"),
},
})
require.NoError(t, err)
require.Equal(t, "image/png", attachment.Type)
})
// Test case 2: Create attachment with empty type and unknown extension, but detectable content
t.Run("EmptyType_UnknownExtension_ContentSniffing", func(t *testing.T) {
// PNG magic header: 89 50 4E 47 0D 0A 1A 0A
pngContent := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "test.unknown",
Content: pngContent,
},
})
require.NoError(t, err)
require.Equal(t, "image/png", attachment.Type)
})
// Test case 3: Empty type, unknown extension, random content -> fallback to application/octet-stream
t.Run("EmptyType_Fallback", func(t *testing.T) {
randomContent := []byte{0x00, 0x01, 0x02, 0x03}
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "test.data",
Content: randomContent,
},
})
require.NoError(t, err)
require.Equal(t, "application/octet-stream", attachment.Type)
})
}

View File

@@ -0,0 +1,655 @@
package test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func TestAuthenticatorAccessTokenV2(t *testing.T) {
ctx := context.Background()
t.Run("authenticates valid access token v2", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate access token v2
token, _, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte(ts.Secret),
)
require.NoError(t, err)
// Authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
claims, err := authenticator.AuthenticateByAccessTokenV2(token)
require.NoError(t, err)
assert.NotNil(t, claims)
assert.Equal(t, user.ID, claims.UserID)
assert.Equal(t, user.Username, claims.Username)
assert.Equal(t, string(user.Role), claims.Role)
assert.Equal(t, string(user.RowStatus), claims.Status)
})
t.Run("fails with invalid token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, err := authenticator.AuthenticateByAccessTokenV2("invalid-token")
assert.Error(t, err)
})
t.Run("fails with wrong secret", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate token with one secret
token, _, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte("secret-1"),
)
require.NoError(t, err)
// Try to authenticate with different secret
authenticator := auth.NewAuthenticator(ts.Store, "secret-2")
_, err = authenticator.AuthenticateByAccessTokenV2(token)
assert.Error(t, err)
})
}
func TestAuthenticatorRefreshToken(t *testing.T) {
ctx := context.Background()
t.Run("authenticates valid refresh token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create refresh token record in store
tokenID := util.GenUUID()
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
require.NoError(t, err)
// Generate refresh token JWT
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
require.NoError(t, err)
// Authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
authenticatedUser, returnedTokenID, err := authenticator.AuthenticateByRefreshToken(ctx, token)
require.NoError(t, err)
assert.NotNil(t, authenticatedUser)
assert.Equal(t, user.ID, authenticatedUser.ID)
assert.Equal(t, tokenID, returnedTokenID)
})
t.Run("fails with revoked token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
tokenID := util.GenUUID()
// Generate refresh token JWT but don't store it in database (simulates revocation)
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "revoked")
})
t.Run("fails with expired token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create expired refresh token record in store
tokenID := util.GenUUID()
expiredToken := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, expiredToken)
require.NoError(t, err)
// Generate refresh token JWT (JWT itself isn't expired yet)
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
})
t.Run("fails with archived user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create valid refresh token
tokenID := util.GenUUID()
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
require.NoError(t, err)
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
require.NoError(t, err)
// Archive the user
archivedStatus := store.Archived
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
RowStatus: &archivedStatus,
})
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "archived")
})
}
func TestAuthenticatorPAT(t *testing.T) {
ctx := context.Background()
t.Run("authenticates valid PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate PAT
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
// Store PAT in database
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
require.NoError(t, err)
// Authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
require.NoError(t, err)
assert.NotNil(t, authenticatedUser)
assert.NotNil(t, pat)
assert.Equal(t, user.ID, authenticatedUser.ID)
assert.Equal(t, tokenID, pat.TokenId)
})
t.Run("fails with invalid PAT format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err := authenticator.AuthenticateByPAT(ctx, "invalid-token-without-prefix")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid PAT format")
})
t.Run("fails with non-existent PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Generate a PAT but don't store it
token := auth.GeneratePersonalAccessToken()
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err := authenticator.AuthenticateByPAT(ctx, token)
assert.Error(t, err)
})
t.Run("fails with expired PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate and store expired PAT
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
expiredPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Expired PAT",
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, expiredPAT)
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
})
t.Run("succeeds with non-expiring PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate and store PAT without expiration
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Never-expiring PAT",
ExpiresAt: nil, // No expiration
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
require.NoError(t, err)
// Authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
require.NoError(t, err)
assert.NotNil(t, authenticatedUser)
assert.NotNil(t, pat)
assert.Nil(t, pat.ExpiresAt)
})
t.Run("fails with archived user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate and store PAT
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
require.NoError(t, err)
// Archive the user
archivedStatus := store.Archived
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
RowStatus: &archivedStatus,
})
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "archived")
})
}
func TestStoreRefreshTokenMethods(t *testing.T) {
ctx := context.Background()
t.Run("adds and retrieves refresh token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
tokenID := util.GenUUID()
token := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
require.NoError(t, err)
// Retrieve tokens
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, tokens, 1)
assert.Equal(t, tokenID, tokens[0].TokenId)
})
t.Run("retrieves specific refresh token by ID", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
tokenID := util.GenUUID()
token := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
require.NoError(t, err)
// Retrieve specific token
retrievedToken, err := ts.Store.GetUserRefreshTokenByID(ctx, user.ID, tokenID)
require.NoError(t, err)
assert.NotNil(t, retrievedToken)
assert.Equal(t, tokenID, retrievedToken.TokenId)
})
t.Run("removes refresh token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
tokenID := util.GenUUID()
token := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
require.NoError(t, err)
// Remove token
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID)
require.NoError(t, err)
// Verify removal
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, tokens, 0)
})
t.Run("handles multiple refresh tokens", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Add multiple tokens
tokenID1 := util.GenUUID()
tokenID2 := util.GenUUID()
token1 := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID1,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
token2 := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID2,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token1)
require.NoError(t, err)
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token2)
require.NoError(t, err)
// Retrieve all tokens
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, tokens, 2)
// Remove one token
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID1)
require.NoError(t, err)
// Verify only one token remains
tokens, err = ts.Store.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, tokens, 1)
assert.Equal(t, tokenID2, tokens[0].TokenId)
})
}
func TestStorePersonalAccessTokenMethods(t *testing.T) {
ctx := context.Background()
t.Run("adds and retrieves PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Retrieve PATs
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 1)
assert.Equal(t, tokenID, pats[0].TokenId)
assert.Equal(t, tokenHash, pats[0].TokenHash)
})
t.Run("removes PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Remove PAT
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID)
require.NoError(t, err)
// Verify removal
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 0)
})
t.Run("updates PAT last used time", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Update last used time
lastUsed := timestamppb.Now()
err = ts.Store.UpdatePATLastUsed(ctx, user.ID, tokenID, lastUsed)
require.NoError(t, err)
// Verify update
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 1)
assert.NotNil(t, pats[0].LastUsedAt)
})
t.Run("handles multiple PATs", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Add multiple PATs
token1 := auth.GeneratePersonalAccessToken()
tokenHash1 := auth.HashPersonalAccessToken(token1)
tokenID1 := util.GenUUID()
token2 := auth.GeneratePersonalAccessToken()
tokenHash2 := auth.HashPersonalAccessToken(token2)
tokenID2 := util.GenUUID()
pat1 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID1,
TokenHash: tokenHash1,
Description: "PAT 1",
CreatedAt: timestamppb.Now(),
}
pat2 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID2,
TokenHash: tokenHash2,
Description: "PAT 2",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat1)
require.NoError(t, err)
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat2)
require.NoError(t, err)
// Retrieve all PATs
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 2)
// Remove one PAT
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID1)
require.NoError(t, err)
// Verify only one PAT remains
pats, err = ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 1)
assert.Equal(t, tokenID2, pats[0].TokenId)
})
t.Run("finds user by PAT hash", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Find user by PAT hash
result, err := ts.Store.GetUserByPATHash(ctx, tokenHash)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, user.ID, result.UserID)
assert.NotNil(t, result.User)
assert.Equal(t, user.Username, result.User.Username)
assert.NotNil(t, result.PAT)
assert.Equal(t, tokenID, result.PAT.TokenId)
})
}

View File

@@ -0,0 +1,552 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestCreateIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("CreateIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
ctx := ts.CreateUserContext(ctx, hostUser.ID)
// Create OAuth2 identity provider
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test OAuth2 Provider",
IdentifierFilter: "",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "https://example.com/oauth/authorize",
TokenUrl: "https://example.com/oauth/token",
UserInfoUrl: "https://example.com/oauth/userinfo",
Scopes: []string{"openid", "profile", "email"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
DisplayName: "name",
Email: "email",
AvatarUrl: "avatar_url",
},
},
},
},
},
}
resp, err := ts.Service.CreateIdentityProvider(ctx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "Test OAuth2 Provider", resp.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
require.Contains(t, resp.Name, "identity-providers/")
require.NotNil(t, resp.Config.GetOauth2Config())
require.Equal(t, "test-client-id", resp.Config.GetOauth2Config().ClientId)
})
t.Run("CreateIdentityProvider permission denied for non-host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
// Set user context
ctx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err = ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("CreateIdentityProvider unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err := ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "user not authenticated")
})
}
func TestListIdentityProviders(t *testing.T) {
ctx := context.Background()
t.Run("ListIdentityProviders empty", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.ListIdentityProvidersRequest{}
resp, err := ts.Service.ListIdentityProviders(ctx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Empty(t, resp.IdentityProviders)
})
t.Run("ListIdentityProviders with providers", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a couple of identity providers
createReq1 := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider 1",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client1",
AuthUrl: "https://example1.com/auth",
TokenUrl: "https://example1.com/token",
UserInfoUrl: "https://example1.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
createReq2 := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider 2",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client2",
AuthUrl: "https://example2.com/auth",
TokenUrl: "https://example2.com/token",
UserInfoUrl: "https://example2.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq1)
require.NoError(t, err)
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq2)
require.NoError(t, err)
// List providers
listReq := &v1pb.ListIdentityProvidersRequest{}
resp, err := ts.Service.ListIdentityProviders(ctx, listReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.IdentityProviders, 2)
// Verify response contains expected providers
titles := []string{resp.IdentityProviders[0].Title, resp.IdentityProviders[1].Title}
require.Contains(t, titles, "Provider 1")
require.Contains(t, titles, "Provider 2")
})
}
func TestGetIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("GetIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "test-client",
ClientSecret: "test-secret",
AuthUrl: "https://example.com/auth",
TokenUrl: "https://example.com/token",
UserInfoUrl: "https://example.com/user",
Scopes: []string{"openid", "profile"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
DisplayName: "name",
Email: "email",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Get identity provider
getReq := &v1pb.GetIdentityProviderRequest{
Name: created.Name,
}
// Test unauthenticated, should not contain client secret
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, created.Name, resp.Name)
require.Equal(t, "Test Provider", resp.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
require.NotNil(t, resp.Config.GetOauth2Config())
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)
// Test as host user, should contain client secret
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
require.NoError(t, err)
require.NotNil(t, respHostUser)
require.Equal(t, created.Name, respHostUser.Name)
require.Equal(t, "Test Provider", respHostUser.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
require.NotNil(t, respHostUser.Config.GetOauth2Config())
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
})
t.Run("GetIdentityProvider not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.GetIdentityProviderRequest{
Name: "identity-providers/999",
}
_, err := ts.Service.GetIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("GetIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.GetIdentityProviderRequest{
Name: "invalid-name",
}
_, err := ts.Service.GetIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
}
func TestUpdateIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("UpdateIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Original Provider",
IdentifierFilter: "",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "original-client",
AuthUrl: "https://original.com/auth",
TokenUrl: "https://original.com/token",
UserInfoUrl: "https://original.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Update identity provider
updateReq := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: created.Name,
Title: "Updated Provider",
IdentifierFilter: "test@example.com",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "updated-client",
ClientSecret: "updated-secret",
AuthUrl: "https://updated.com/auth",
TokenUrl: "https://updated.com/token",
UserInfoUrl: "https://updated.com/user",
Scopes: []string{"openid", "profile", "email"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "sub",
DisplayName: "given_name",
Email: "email",
AvatarUrl: "picture",
},
},
},
},
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "identifier_filter", "config"},
},
}
updated, err := ts.Service.UpdateIdentityProvider(userCtx, updateReq)
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "Updated Provider", updated.Title)
require.Equal(t, "test@example.com", updated.IdentifierFilter)
require.Equal(t, "updated-client", updated.Config.GetOauth2Config().ClientId)
})
t.Run("UpdateIdentityProvider missing update mask", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: "identity-providers/1",
Title: "Updated Provider",
},
}
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "update_mask is required")
})
t.Run("UpdateIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: "invalid-name",
Title: "Updated Provider",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title"},
},
}
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
}
func TestDeleteIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("DeleteIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider to Delete",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client-to-delete",
AuthUrl: "https://example.com/auth",
TokenUrl: "https://example.com/token",
UserInfoUrl: "https://example.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Delete identity provider
deleteReq := &v1pb.DeleteIdentityProviderRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteIdentityProvider(userCtx, deleteReq)
require.NoError(t, err)
// Verify deletion
getReq := &v1pb.GetIdentityProviderRequest{
Name: created.Name,
}
_, err = ts.Service.GetIdentityProvider(ctx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("DeleteIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.DeleteIdentityProviderRequest{
Name: "invalid-name",
}
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
t.Run("DeleteIdentityProvider not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.DeleteIdentityProviderRequest{
Name: "identity-providers/999",
}
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
require.Error(t, err)
// Note: Delete might succeed even if item doesn't exist, depending on store implementation
})
}
func TestIdentityProviderPermissions(t *testing.T) {
ctx := context.Background()
t.Run("Only host users can create identity providers", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err = ts.Service.CreateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("Authentication required", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err := ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "user not authenticated")
})
}

View File

@@ -0,0 +1,54 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestInstanceAdminRetrieval(t *testing.T) {
ctx := context.Background()
t.Run("Instance becomes initialized after first admin user is created", func(t *testing.T) {
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Verify instance is not initialized initially
profile1, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.Nil(t, profile1.Admin, "Instance should not be initialized before first admin user")
// Create the first admin user
user, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
require.NotNil(t, user)
// Verify instance is now initialized
profile2, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.NotNil(t, profile2.Admin, "Instance should be initialized after first admin user is created")
require.Equal(t, user.Username, profile2.Admin.Username)
})
t.Run("Admin retrieval is cached by Store layer", func(t *testing.T) {
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Create admin user
user, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Multiple calls should return consistent admin user (from cache)
for i := 0; i < 5; i++ {
profile, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.NotNil(t, profile.Admin)
require.Equal(t, user.Username, profile.Admin.Username)
}
})
}

View File

@@ -0,0 +1,204 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestGetInstanceProfile(t *testing.T) {
ctx := context.Background()
t.Run("GetInstanceProfile returns instance profile", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetInstanceProfile directly
req := &v1pb.GetInstanceProfileRequest{}
resp, err := ts.Service.GetInstanceProfile(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
// Verify the response contains expected data
require.Equal(t, "test-1.0.0", resp.Version)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
// Instance should not be initialized since no admin users are created
require.Nil(t, resp.Admin)
})
t.Run("GetInstanceProfile with initialized instance", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user in the store
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
require.NotNil(t, hostUser)
// Call GetInstanceProfile directly
req := &v1pb.GetInstanceProfileRequest{}
resp, err := ts.Service.GetInstanceProfile(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
// Verify the response contains expected data with initialized flag
require.Equal(t, "test-1.0.0", resp.Version)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
// Instance should be initialized since an admin user exists
require.NotNil(t, resp.Admin)
require.Equal(t, hostUser.Username, resp.Admin.Username)
})
}
func TestGetInstanceProfile_Concurrency(t *testing.T) {
ctx := context.Background()
t.Run("Concurrent access to service", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user
_, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Make concurrent requests
numGoroutines := 10
results := make(chan *v1pb.InstanceProfile, numGoroutines)
errors := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
req := &v1pb.GetInstanceProfileRequest{}
resp, err := ts.Service.GetInstanceProfile(ctx, req)
if err != nil {
errors <- err
return
}
results <- resp
}()
}
// Collect all results
for i := 0; i < numGoroutines; i++ {
select {
case err := <-errors:
t.Fatalf("Goroutine returned error: %v", err)
case resp := <-results:
require.NotNil(t, resp)
require.Equal(t, "test-1.0.0", resp.Version)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
require.NotNil(t, resp.Admin)
}
}
})
}
func TestGetInstanceSetting(t *testing.T) {
ctx := context.Background()
t.Run("GetInstanceSetting - general setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetInstanceSetting for general setting
req := &v1pb.GetInstanceSettingRequest{
Name: "instance/settings/GENERAL",
}
resp, err := ts.Service.GetInstanceSetting(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "instance/settings/GENERAL", resp.Name)
// The general setting should have a general_setting field
generalSetting := resp.GetGeneralSetting()
require.NotNil(t, generalSetting)
// General setting should have default values
require.False(t, generalSetting.DisallowUserRegistration)
require.False(t, generalSetting.DisallowPasswordAuth)
require.Empty(t, generalSetting.AdditionalScript)
})
t.Run("GetInstanceSetting - storage setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user for storage setting access
hostUser, err := ts.CreateHostUser(ctx, "testhost")
require.NoError(t, err)
// Add user to context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Call GetInstanceSetting for storage setting
req := &v1pb.GetInstanceSettingRequest{
Name: "instance/settings/STORAGE",
}
resp, err := ts.Service.GetInstanceSetting(userCtx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "instance/settings/STORAGE", resp.Name)
// The storage setting should have a storage_setting field
storageSetting := resp.GetStorageSetting()
require.NotNil(t, storageSetting)
})
t.Run("GetInstanceSetting - memo related setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetInstanceSetting for memo related setting
req := &v1pb.GetInstanceSettingRequest{
Name: "instance/settings/MEMO_RELATED",
}
resp, err := ts.Service.GetInstanceSetting(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "instance/settings/MEMO_RELATED", resp.Name)
// The memo related setting should have a memo_related_setting field
memoRelatedSetting := resp.GetMemoRelatedSetting()
require.NotNil(t, memoRelatedSetting)
})
t.Run("GetInstanceSetting - invalid setting name", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetInstanceSetting with invalid name
req := &v1pb.GetInstanceSettingRequest{
Name: "invalid/setting/name",
}
_, err := ts.Service.GetInstanceSetting(ctx, req)
// Should return an error
require.Error(t, err)
require.Contains(t, err.Error(), "invalid instance setting name")
})
}

View File

@@ -0,0 +1,166 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestSetMemoAttachments(t *testing.T) {
ctx := context.Background()
t.Run("SetMemoAttachments success by memo owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create attachment
attachment, err := ts.Service.CreateAttachment(userCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Filename: "test.txt",
Size: 5,
Type: "text/plain",
Content: []byte("hello"),
},
})
require.NoError(t, err)
require.NotNil(t, attachment)
// Set memo attachments - should succeed
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{
{Name: attachment.Name},
},
})
require.NoError(t, err)
})
t.Run("SetMemoAttachments success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Host user can modify attachments - should succeed
_, err = ts.Service.SetMemoAttachments(hostCtx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.NoError(t, err)
})
t.Run("SetMemoAttachments permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// User2 tries to modify attachments - should fail
_, err = ts.Service.SetMemoAttachments(user2Ctx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("SetMemoAttachments unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Unauthenticated user tries to modify attachments - should fail
_, err = ts.Service.SetMemoAttachments(ctx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("SetMemoAttachments memo not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to set attachments on non-existent memo - should fail
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
Name: "memos/nonexistent-uid-12345",
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,169 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestSetMemoRelations(t *testing.T) {
ctx := context.Background()
t.Run("SetMemoRelations success by memo owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo1
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo 1",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo1)
// Create memo2
memo2, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo 2",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo2)
// Set memo relations - should succeed
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
Name: memo1.Name,
Relations: []*apiv1.MemoRelation{
{
RelatedMemo: &apiv1.MemoRelation_Memo{
Name: memo2.Name,
},
Type: apiv1.MemoRelation_REFERENCE,
},
},
})
require.NoError(t, err)
})
t.Run("SetMemoRelations success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Host user can modify relations - should succeed
_, err = ts.Service.SetMemoRelations(hostCtx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.NoError(t, err)
})
t.Run("SetMemoRelations permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// User2 tries to modify relations - should fail
_, err = ts.Service.SetMemoRelations(user2Ctx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("SetMemoRelations unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Unauthenticated user tries to modify relations - should fail
_, err = ts.Service.SetMemoRelations(ctx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("SetMemoRelations memo not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to set relations on non-existent memo - should fail
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
Name: "memos/nonexistent-uid-12345",
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,369 @@
package test
import (
"context"
"fmt"
"slices"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestListMemos(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create userOne
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
require.NoError(t, err)
require.NotNil(t, userOne)
// Create userOne context
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
// Create userTwo
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
require.NoError(t, err)
require.NotNil(t, userTwo)
// Create userTwo context
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
// Create attachmentOne by userOne
attachmentOne, err := ts.Service.CreateAttachment(userOneCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Name: "",
Filename: "hello.txt",
Size: 5,
Type: "text/plain",
Content: []byte{
104, 101, 108, 108, 111,
},
},
})
require.NoError(t, err)
require.NotNil(t, attachmentOne)
// Create attachmentTwo by userOne
attachmentTwo, err := ts.Service.CreateAttachment(userOneCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Name: "",
Filename: "world.txt",
Size: 5,
Type: "text/plain",
Content: []byte{
119, 111, 114, 108, 100,
},
},
})
require.NoError(t, err)
require.NotNil(t, attachmentTwo)
// Create memoOne with two attachments by userOne
memoOne, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Hellooo, any words after this sentence won't be in the snippet. This is the next sentence. And I also have two attachments.",
Visibility: apiv1.Visibility_PROTECTED,
Attachments: []*apiv1.Attachment{
&apiv1.Attachment{
Name: attachmentOne.Name,
},
&apiv1.Attachment{
Name: attachmentTwo.Name,
},
},
},
})
require.NoError(t, err)
require.NotNil(t, memoOne)
// Create memoTwo by userTwo referencing memoOne
memoTwo, err := ts.Service.CreateMemo(userTwoCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This is a memo reminding you to check the attachment attached to memoOne. I have referenced the memo below.⬇️",
Visibility: apiv1.Visibility_PROTECTED,
Relations: []*apiv1.MemoRelation{
&apiv1.MemoRelation{
RelatedMemo: &apiv1.MemoRelation_Memo{
Name: memoOne.Name,
},
},
},
},
})
require.NoError(t, err)
require.NotNil(t, memoTwo)
// Create memoThree by userOne
memoThree, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This is a very popular memo. I have 2 reactions!",
Visibility: apiv1.Visibility_PROTECTED,
},
})
require.NoError(t, err)
require.NotNil(t, memoThree)
// Create reaction from userOne on memoThree
reactionOne, err := ts.Service.UpsertMemoReaction(userOneCtx, &apiv1.UpsertMemoReactionRequest{
Name: memoThree.Name,
Reaction: &apiv1.Reaction{
ContentId: memoThree.Name,
ReactionType: "❤️",
},
})
require.NoError(t, err)
require.NotNil(t, reactionOne)
// Create reaction from userTwo on memoThree
reactionTwo, err := ts.Service.UpsertMemoReaction(userTwoCtx, &apiv1.UpsertMemoReactionRequest{
Name: memoThree.Name,
Reaction: &apiv1.Reaction{
ContentId: memoThree.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reactionTwo)
memos, err := ts.Service.ListMemos(userOneCtx, &apiv1.ListMemosRequest{PageSize: 10})
require.NoError(t, err)
require.NotNil(t, memos)
require.Equal(t, 3, len(memos.Memos))
// ///////////////
// VERIFY MEMO ONE
// ///////////////
memoOneResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoOne.GetName() })
require.NotEqual(t, memoOneResIdx, -1)
memoOneRes := memos.Memos[memoOneResIdx]
require.NotNil(t, memoOneRes)
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoOneRes.GetCreator())
require.Equal(t, apiv1.Visibility_PROTECTED, memoOneRes.GetVisibility())
require.Equal(t, memoOne.Content, memoOneRes.GetContent())
require.Equal(t, memoOne.Content[:64]+"...", memoOneRes.GetSnippet(), "memoOne's content is snipped past the 64 char limit")
require.Len(t, memoOneRes.Attachments, 2)
require.Len(t, memoOneRes.Relations, 1)
require.Empty(t, memoOneRes.Reactions)
// verify memoOne's attachments
// attachment one
attachmentOneResIdx := slices.IndexFunc(memoOneRes.Attachments, func(a *apiv1.Attachment) bool { return a.GetName() == attachmentOne.GetName() })
require.NotEqual(t, attachmentOneResIdx, -1)
attachmentOneRes := memoOneRes.Attachments[attachmentOneResIdx]
require.NotNil(t, attachmentOneRes)
require.Equal(t, attachmentOne.GetName(), attachmentOneRes.GetName())
require.Equal(t, attachmentOne.GetContent(), attachmentOneRes.GetContent())
// attachment two
attachmentTwoResIdx := slices.IndexFunc(memoOneRes.Attachments, func(a *apiv1.Attachment) bool { return a.GetName() == attachmentTwo.GetName() })
require.NotEqual(t, attachmentTwoResIdx, -1)
attachmentTwoRes := memoOneRes.Attachments[attachmentTwoResIdx]
require.NotNil(t, attachmentTwoRes)
require.Equal(t, attachmentTwo.GetName(), attachmentTwoRes.GetName())
require.Equal(t, attachmentTwo.GetName(), attachmentTwoRes.GetName())
require.Equal(t, attachmentTwo.GetContent(), attachmentTwoRes.GetContent())
// verify memoOne's relations
require.Len(t, memoOneRes.Relations, 1)
memoOneExpectedRelation := &apiv1.MemoRelation{
Memo: &apiv1.MemoRelation_Memo{Name: memoTwo.GetName()},
RelatedMemo: &apiv1.MemoRelation_Memo{Name: memoOne.GetName()},
}
require.Equal(t, memoOneExpectedRelation.Memo.GetName(), memoOneRes.Relations[0].Memo.GetName())
require.Equal(t, memoOneExpectedRelation.RelatedMemo.GetName(), memoOneRes.Relations[0].RelatedMemo.GetName())
// ///////////////
// VERIFY MEMO TWO
// ///////////////
memoTwoResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoTwo.GetName() })
require.NotEqual(t, memoTwoResIdx, -1)
memoTwoRes := memos.Memos[memoTwoResIdx]
require.NotNil(t, memoTwoRes)
require.Equal(t, fmt.Sprintf("users/%d", userTwo.ID), memoTwoRes.GetCreator())
require.Equal(t, apiv1.Visibility_PROTECTED, memoTwoRes.GetVisibility())
require.Equal(t, memoTwo.Content, memoTwoRes.GetContent())
require.Empty(t, memoTwoRes.Attachments)
require.Len(t, memoTwoRes.Relations, 1)
require.Empty(t, memoTwoRes.Reactions)
// verify memoTwo's relations
require.Len(t, memoTwoRes.Relations, 1)
memoTwoExpectedRelation := &apiv1.MemoRelation{
Memo: &apiv1.MemoRelation_Memo{Name: memoTwo.GetName()},
RelatedMemo: &apiv1.MemoRelation_Memo{Name: memoOne.GetName()},
}
require.Equal(t, memoTwoExpectedRelation.Memo.GetName(), memoTwoRes.Relations[0].Memo.GetName())
require.Equal(t, memoTwoExpectedRelation.RelatedMemo.GetName(), memoTwoRes.Relations[0].RelatedMemo.GetName())
// ///////////////
// VERIFY MEMO THREE
// ///////////////
memoThreeResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoThree.GetName() })
require.NotEqual(t, memoThreeResIdx, -1)
memoThreeRes := memos.Memos[memoThreeResIdx]
require.NotNil(t, memoThreeRes)
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoThreeRes.GetCreator())
require.Equal(t, apiv1.Visibility_PROTECTED, memoThreeRes.GetVisibility())
require.Equal(t, memoThree.Content, memoThreeRes.GetContent())
require.Empty(t, memoThreeRes.Attachments)
require.Empty(t, memoThreeRes.Relations)
require.Len(t, memoThreeRes.Reactions, 2)
// verify memoThree's reactions
require.Len(t, memoThreeRes.Reactions, 2)
// userOne's reaction
userOneReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userOne.ID) })
require.NotEqual(t, userOneReactionIdx, -1)
userOneReaction := memoThreeRes.Reactions[userOneReactionIdx]
require.NotNil(t, userOneReaction)
require.Equal(t, "❤️", userOneReaction.ReactionType)
// userTwo's reaction
userTwoReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userTwo.ID) })
require.NotEqual(t, userTwoReactionIdx, -1)
userTwoReaction := memoThreeRes.Reactions[userTwoReactionIdx]
require.NotNil(t, userTwoReaction)
require.Equal(t, "👍", userTwoReaction.ReactionType)
}
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483
func TestCreateMemoWithCustomTimestamps(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "test-user-timestamps")
require.NoError(t, err)
require.NotNil(t, user)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Define custom timestamps (January 1, 2020)
customCreateTime := time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
customUpdateTime := time.Date(2020, 1, 2, 12, 0, 0, 0, time.UTC)
customDisplayTime := time.Date(2020, 1, 3, 12, 0, 0, 0, time.UTC)
// Test 1: Create a memo with custom create_time
memoWithCreateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom creation time",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCreateTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithCreateTime)
require.Equal(t, customCreateTime.Unix(), memoWithCreateTime.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
// Test 2: Create a memo with custom update_time
memoWithUpdateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom update time",
Visibility: apiv1.Visibility_PRIVATE,
UpdateTime: timestamppb.New(customUpdateTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithUpdateTime)
require.Equal(t, customUpdateTime.Unix(), memoWithUpdateTime.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
// Test 3: Create a memo with custom display_time
// Note: display_time is computed from either created_ts or updated_ts based on instance setting
// Since DisplayWithUpdateTime defaults to false, display_time maps to created_ts
memoWithDisplayTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom display time",
Visibility: apiv1.Visibility_PRIVATE,
DisplayTime: timestamppb.New(customDisplayTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithDisplayTime)
// Since DisplayWithUpdateTime is false by default, display_time sets created_ts
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.DisplayTime.AsTime().Unix(), "display_time should match the custom timestamp")
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.CreateTime.AsTime().Unix(), "create_time should also match since display_time maps to created_ts")
// Test 4: Create a memo with all custom timestamps
// When both display_time and create_time are provided, create_time takes precedence
memoWithAllTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has all custom timestamps",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCreateTime),
UpdateTime: timestamppb.New(customUpdateTime),
DisplayTime: timestamppb.New(customDisplayTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithAllTimestamps)
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
require.Equal(t, customUpdateTime.Unix(), memoWithAllTimestamps.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
// display_time is computed from created_ts when DisplayWithUpdateTime is false
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.DisplayTime.AsTime().Unix(), "display_time should be derived from create_time")
// Test 5: Create a comment (memo relation) with custom timestamps
parentMemo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This is the parent memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, parentMemo)
customCommentCreateTime := time.Date(2021, 6, 15, 10, 30, 0, 0, time.UTC)
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
Name: parentMemo.Name,
Comment: &apiv1.Memo{
Content: "This is a comment with custom create time",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCommentCreateTime),
},
})
require.NoError(t, err)
require.NotNil(t, comment)
require.Equal(t, customCommentCreateTime.Unix(), comment.CreateTime.AsTime().Unix(), "comment create_time should match the custom timestamp")
// Test 6: Verify that memos without custom timestamps still get auto-generated ones
memoWithoutTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has auto-generated timestamps",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memoWithoutTimestamps)
require.NotNil(t, memoWithoutTimestamps.CreateTime, "create_time should be auto-generated")
require.NotNil(t, memoWithoutTimestamps.UpdateTime, "update_time should be auto-generated")
require.True(t, time.Now().Unix()-memoWithoutTimestamps.CreateTime.AsTime().Unix() < 5, "create_time should be recent (within 5 seconds)")
}

View File

@@ -0,0 +1,194 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestDeleteMemoReaction(t *testing.T) {
ctx := context.Background()
t.Run("DeleteMemoReaction success by reaction owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Delete reaction - should succeed
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.NoError(t, err)
})
t.Run("DeleteMemoReaction success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction by regular user
reaction, err := ts.Service.UpsertMemoReaction(regularUserCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Host user can delete reaction - should succeed
_, err = ts.Service.DeleteMemoReaction(hostCtx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.NoError(t, err)
})
t.Run("DeleteMemoReaction permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction by user1
reaction, err := ts.Service.UpsertMemoReaction(user1Ctx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// User2 tries to delete reaction - should fail with permission denied
_, err = ts.Service.DeleteMemoReaction(user2Ctx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("DeleteMemoReaction unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Unauthenticated user tries to delete reaction - should fail
_, err = ts.Service.DeleteMemoReaction(ctx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("DeleteMemoReaction not found returns permission denied", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to delete non-existent reaction - should fail with permission denied
// (not "not found" to avoid information disclosure)
// Use new nested resource format: memos/{memo}/reactions/{reaction}
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
Name: "memos/nonexistent/reactions/99999",
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
require.NotContains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,819 @@
package test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestListShortcuts(t *testing.T) {
ctx := context.Background()
t.Run("ListShortcuts success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// List shortcuts (should be empty initially)
req := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
resp, err := ts.Service.ListShortcuts(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Empty(t, resp.Shortcuts)
})
t.Run("ListShortcuts permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Set user1 context but try to list user2's shortcuts
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
}
_, err = ts.Service.ListShortcuts(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("ListShortcuts invalid parent format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.ListShortcutsRequest{
Parent: "invalid-parent-format",
}
_, err = ts.Service.ListShortcuts(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name")
})
t.Run("ListShortcuts unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.ListShortcutsRequest{
Parent: "users/1",
}
_, err := ts.Service.ListShortcuts(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
}
func TestGetShortcut(t *testing.T) {
ctx := context.Background()
t.Run("GetShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// First create a shortcut
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Now get the shortcut
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
resp, err := ts.Service.GetShortcut(userCtx, getReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, created.Name, resp.Name)
require.Equal(t, "Test Shortcut", resp.Title)
require.Equal(t, "tag in [\"test\"]", resp.Filter)
})
t.Run("GetShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to get shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.GetShortcut(user2Ctx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("GetShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.GetShortcutRequest{
Name: "invalid-shortcut-name",
}
_, err = ts.Service.GetShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("GetShortcut not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.GetShortcutRequest{
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
}
_, err = ts.Service.GetShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}
func TestCreateShortcut(t *testing.T) {
ctx := context.Background()
t.Run("CreateShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "My Shortcut",
Filter: "tag in [\"important\"]",
},
}
resp, err := ts.Service.CreateShortcut(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "My Shortcut", resp.Title)
require.Equal(t, "tag in [\"important\"]", resp.Filter)
require.Contains(t, resp.Name, fmt.Sprintf("users/%d/shortcuts/", user.ID))
// Verify the shortcut was created by listing
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, listResp.Shortcuts, 1)
require.Equal(t, "My Shortcut", listResp.Shortcuts[0].Title)
})
t.Run("CreateShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Set user1 context but try to create shortcut for user2
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
Shortcut: &v1pb.Shortcut{
Title: "Forbidden Shortcut",
Filter: "tag in [\"forbidden\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("CreateShortcut invalid parent format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: "invalid-parent",
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name")
})
t.Run("CreateShortcut invalid filter", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Invalid Filter Shortcut",
Filter: "invalid||filter))syntax",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid filter")
})
t.Run("CreateShortcut missing title", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Filter: "tag in [\"test\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "title is required")
})
}
func TestUpdateShortcut(t *testing.T) {
ctx := context.Background()
t.Run("UpdateShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Original Title",
Filter: "tag in [\"original\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Update the shortcut
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Title: "Updated Title",
Filter: "tag in [\"updated\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "Updated Title", updated.Title)
require.Equal(t, "tag in [\"updated\"]", updated.Filter)
require.Equal(t, created.Name, updated.Name)
})
t.Run("UpdateShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to update shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Title: "Hacked Title",
Filter: "tag in [\"hacked\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
_, err = ts.Service.UpdateShortcut(user2Ctx, updateReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("UpdateShortcut missing update mask", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user and context for authentication
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: fmt.Sprintf("users/%d/shortcuts/test", user.ID),
Title: "Updated Title",
},
}
_, err = ts.Service.UpdateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "update mask is required")
})
t.Run("UpdateShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: "invalid-shortcut-name",
Title: "Updated Title",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title"},
},
}
_, err := ts.Service.UpdateShortcut(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("UpdateShortcut invalid filter", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Try to update with invalid filter
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Filter: "invalid||filter))syntax",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"filter"},
},
}
_, err = ts.Service.UpdateShortcut(userCtx, updateReq)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid filter")
})
}
func TestDeleteShortcut(t *testing.T) {
ctx := context.Background()
t.Run("DeleteShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Shortcut to Delete",
Filter: "tag in [\"delete\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Delete the shortcut
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
require.NoError(t, err)
// Verify deletion by listing shortcuts
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Empty(t, listResp.Shortcuts)
// Also verify by trying to get the deleted shortcut
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.GetShortcut(userCtx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("DeleteShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to delete shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteShortcut(user2Ctx, deleteReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("DeleteShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.DeleteShortcutRequest{
Name: "invalid-shortcut-name",
}
_, err := ts.Service.DeleteShortcut(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("DeleteShortcut not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.DeleteShortcutRequest{
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
}
_, err = ts.Service.DeleteShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}
func TestShortcutFiltering(t *testing.T) {
ctx := context.Background()
t.Run("CreateShortcut with valid filters", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test various valid filter formats
validFilters := []string{
"tag in [\"work\"]",
"content.contains(\"meeting\")",
"tag in [\"work\"] && content.contains(\"meeting\")",
"tag in [\"work\"] || tag in [\"personal\"]",
"creator_id == 1",
"visibility == \"PUBLIC\"",
"has_task_list == true",
"has_task_list == false",
}
for i, filter := range validFilters {
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Valid Filter " + string(rune(i)),
Filter: filter,
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.NoError(t, err, "Filter should be valid: %s", filter)
}
})
t.Run("CreateShortcut with invalid filters", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test various invalid filter formats
invalidFilters := []string{
"tag in ", // incomplete expression
"invalid_field @in [\"value\"]", // unknown field
"tag in [\"work\"] &&", // incomplete expression
"tag in [\"work\"] || || tag in [\"test\"]", // double operator
"((tag in [\"work\"]", // unmatched parentheses
"tag in [\"work\"] && )", // mismatched parentheses
"tag == \"work\"", // wrong operator (== not supported for tags)
"tag in work", // missing brackets
}
for _, filter := range invalidFilters {
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Invalid Filter Test",
Filter: filter,
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err, "Filter should be invalid: %s", filter)
require.Contains(t, err.Error(), "invalid filter", "Error should mention invalid filter for: %s", filter)
}
})
}
func TestShortcutCRUDComplete(t *testing.T) {
ctx := context.Background()
t.Run("Complete CRUD lifecycle", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// 1. Create multiple shortcuts
shortcut1Req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Work Notes",
Filter: "tag in [\"work\"]",
},
}
shortcut2Req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Personal Notes",
Filter: "tag in [\"personal\"]",
},
}
created1, err := ts.Service.CreateShortcut(userCtx, shortcut1Req)
require.NoError(t, err)
require.Equal(t, "Work Notes", created1.Title)
created2, err := ts.Service.CreateShortcut(userCtx, shortcut2Req)
require.NoError(t, err)
require.Equal(t, "Personal Notes", created2.Title)
// 2. List shortcuts and verify both exist
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, listResp.Shortcuts, 2)
// 3. Get individual shortcuts
getReq1 := &v1pb.GetShortcutRequest{Name: created1.Name}
getResp1, err := ts.Service.GetShortcut(userCtx, getReq1)
require.NoError(t, err)
require.Equal(t, created1.Name, getResp1.Name)
require.Equal(t, "Work Notes", getResp1.Title)
getReq2 := &v1pb.GetShortcutRequest{Name: created2.Name}
getResp2, err := ts.Service.GetShortcut(userCtx, getReq2)
require.NoError(t, err)
require.Equal(t, created2.Name, getResp2.Name)
require.Equal(t, "Personal Notes", getResp2.Title)
// 4. Update one shortcut
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created1.Name,
Title: "Work & Meeting Notes",
Filter: "tag in [\"work\"] || tag in [\"meeting\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
require.NoError(t, err)
require.Equal(t, "Work & Meeting Notes", updated.Title)
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", updated.Filter)
// 5. Verify update by getting it again
getUpdatedReq := &v1pb.GetShortcutRequest{Name: created1.Name}
getUpdatedResp, err := ts.Service.GetShortcut(userCtx, getUpdatedReq)
require.NoError(t, err)
require.Equal(t, "Work & Meeting Notes", getUpdatedResp.Title)
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", getUpdatedResp.Filter)
// 6. Delete one shortcut
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created2.Name,
}
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
require.NoError(t, err)
// 7. Verify deletion by listing (should only have 1 left)
finalListResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, finalListResp.Shortcuts, 1)
require.Equal(t, "Work & Meeting Notes", finalListResp.Shortcuts[0].Title)
// 8. Verify deleted shortcut can't be accessed
getDeletedReq := &v1pb.GetShortcutRequest{Name: created2.Name}
_, err = ts.Service.GetShortcut(userCtx, getDeletedReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,86 @@
package test
import (
"context"
"testing"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/plugin/markdown"
"github.com/usememos/memos/server/auth"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
teststore "github.com/usememos/memos/store/test"
)
// TestService holds the test service setup for API v1 services.
type TestService struct {
Service *apiv1.APIV1Service
Store *store.Store
Profile *profile.Profile
Secret string
}
// NewTestService creates a new test service with SQLite database.
func NewTestService(t *testing.T) *TestService {
ctx := context.Background()
// Create a test store with SQLite
testStore := teststore.NewTestingStore(ctx, t)
// Create a test profile
testProfile := &profile.Profile{
Demo: true,
Version: "test-1.0.0",
InstanceURL: "http://localhost:8080",
Driver: "sqlite",
DSN: ":memory:",
}
// Create APIV1Service with nil grpcServer since we're testing direct calls
secret := "test-secret"
markdownService := markdown.NewService(
markdown.WithTagExtension(),
)
service := &apiv1.APIV1Service{
Secret: secret,
Profile: testProfile,
Store: testStore,
MarkdownService: markdownService,
}
return &TestService{
Service: service,
Store: testStore,
Profile: testProfile,
Secret: secret,
}
}
// Cleanup closes resources after test.
func (ts *TestService) Cleanup() {
ts.Store.Close()
}
// CreateHostUser creates an admin user for testing.
func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) {
return ts.Store.CreateUser(ctx, &store.User{
Username: username,
Role: store.RoleAdmin,
Email: username + "@example.com",
})
}
// CreateRegularUser creates a regular user for testing.
func (ts *TestService) CreateRegularUser(ctx context.Context, username string) (*store.User, error) {
return ts.Store.CreateUser(ctx, &store.User{
Username: username,
Role: store.RoleUser,
Email: username + "@example.com",
})
}
// CreateUserContext creates a context with the given user's ID for authentication.
func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context {
// Use the context key from the auth package
return context.WithValue(ctx, auth.UserIDContextKey, userID)
}

View File

@@ -0,0 +1,173 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestCreateUserRegistration(t *testing.T) {
ctx := context.Background()
t.Run("CreateUser success when registration enabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// User registration is enabled by default, no need to set it explicitly
// Create user without authentication - should succeed
_, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.NoError(t, err)
})
t.Run("CreateUser blocked when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user first so we're not in first-user setup mode
_, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Disable user registration
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Try to create user without authentication - should fail
_, err = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not allowed")
})
t.Run("CreateUser succeeds for superuser even when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Disable user registration
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Host user can create users even when registration is disabled - should succeed
_, err = ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.NoError(t, err)
})
t.Run("CreateUser regular user cannot create users when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Disable user registration
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Regular user tries to create user when registration is disabled - should fail
_, err = ts.Service.CreateUser(regularUserCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not allowed")
})
t.Run("CreateUser host can assign roles", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Host user can create user with specific role - should succeed
createdUser, err := ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newadmin",
Email: "newadmin@example.com",
Password: "password123",
Role: apiv1.User_ADMIN,
},
})
require.NoError(t, err)
require.NotNil(t, createdUser)
require.Equal(t, apiv1.User_ADMIN, createdUser.Role)
})
t.Run("CreateUser unauthenticated user can only create regular user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user first so we're not in first-user setup mode
_, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// User registration is enabled by default
// Unauthenticated user tries to create admin user - role should be ignored
createdUser, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "wannabeadmin",
Email: "wannabeadmin@example.com",
Password: "password123",
Role: apiv1.User_ADMIN, // This should be ignored
},
})
require.NoError(t, err)
require.NotNil(t, createdUser)
require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role")
})
}

View File

@@ -0,0 +1,105 @@
package test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestGetUserStats_TagCount(t *testing.T) {
ctx := context.Background()
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test host user
user, err := ts.CreateHostUser(ctx, "test_user")
require.NoError(t, err)
// Create user context for authentication
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a memo with a single tag
memo, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-1",
CreatorID: user.ID,
Content: "This is a test memo with #test tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"test"},
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Test GetUserStats
userName := fmt.Sprintf("users/%d", user.ID)
response, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response)
// Check that the tag count is exactly 1, not 2
require.Contains(t, response.TagCount, "test")
require.Equal(t, int32(1), response.TagCount["test"], "Tag count should be 1 for a single occurrence")
// Create another memo with the same tag
memo2, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-2",
CreatorID: user.ID,
Content: "Another memo with #test tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"test"},
},
})
require.NoError(t, err)
require.NotNil(t, memo2)
// Test GetUserStats again
response2, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response2)
// Check that the tag count is exactly 2, not 3
require.Contains(t, response2.TagCount, "test")
require.Equal(t, int32(2), response2.TagCount["test"], "Tag count should be 2 for two occurrences")
// Test with a new unique tag
memo3, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-3",
CreatorID: user.ID,
Content: "Memo with #unique tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"unique"},
},
})
require.NoError(t, err)
require.NotNil(t, memo3)
// Test GetUserStats for the new tag
response3, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response3)
// Check that the unique tag count is exactly 1
require.Contains(t, response3.TagCount, "unique")
require.Equal(t, int32(1), response3.TagCount["unique"], "New tag count should be 1 for first occurrence")
// The original test tag should still be 2
require.Contains(t, response3.TagCount, "test")
require.Equal(t, int32(2), response3.TagCount["test"], "Original tag count should remain 2")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,236 @@
package v1
import (
"context"
"fmt"
"time"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUserStatsRequest) (*v1pb.ListAllUserStatsResponse, error) {
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get instance memo related setting")
}
normalStatus := store.Normal
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &normalStatus,
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else {
if memoFind.CreatorID == nil {
filter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
memoFind.Filters = append(memoFind.Filters, filter)
} else if *memoFind.CreatorID != currentUser.ID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
}
userMemoStatMap := make(map[int32]*v1pb.UserStats)
limit := 1000
offset := 0
memoFind.Limit = &limit
memoFind.Offset = &offset
for {
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
if len(memos) == 0 {
break
}
for _, memo := range memos {
// Initialize user stats if not exists
if _, exists := userMemoStatMap[memo.CreatorID]; !exists {
userMemoStatMap[memo.CreatorID] = &v1pb.UserStats{
Name: fmt.Sprintf("users/%d/stats", memo.CreatorID),
TagCount: make(map[string]int32),
MemoDisplayTimestamps: []*timestamppb.Timestamp{},
PinnedMemos: []string{},
MemoTypeStats: &v1pb.UserStats_MemoTypeStats{
LinkCount: 0,
CodeCount: 0,
TodoCount: 0,
UndoCount: 0,
},
}
}
stats := userMemoStatMap[memo.CreatorID]
// Add display timestamp
displayTs := memo.CreatedTs
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
stats.MemoDisplayTimestamps = append(stats.MemoDisplayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
// Count memo stats
stats.TotalMemoCount++
// Count tags and other properties
if memo.Payload != nil {
for _, tag := range memo.Payload.Tags {
stats.TagCount[tag]++
}
if memo.Payload.Property != nil {
if memo.Payload.Property.HasLink {
stats.MemoTypeStats.LinkCount++
}
if memo.Payload.Property.HasCode {
stats.MemoTypeStats.CodeCount++
}
if memo.Payload.Property.HasTaskList {
stats.MemoTypeStats.TodoCount++
}
if memo.Payload.Property.HasIncompleteTasks {
stats.MemoTypeStats.UndoCount++
}
}
}
// Track pinned memos
if memo.Pinned {
stats.PinnedMemos = append(stats.PinnedMemos, fmt.Sprintf("users/%d/memos/%d", memo.CreatorID, memo.ID))
}
}
offset += limit
}
userMemoStats := []*v1pb.UserStats{}
for _, userMemoStat := range userMemoStatMap {
userMemoStats = append(userMemoStats, userMemoStat)
}
response := &v1pb.ListAllUserStatsResponse{
Stats: userMemoStats,
}
return response, nil
}
func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserStatsRequest) (*v1pb.UserStats, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
normalStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &userID,
// Exclude comments by default.
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &normalStatus,
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else if currentUser.ID != userID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get instance memo related setting")
}
displayTimestamps := []*timestamppb.Timestamp{}
tagCount := make(map[string]int32)
linkCount := int32(0)
codeCount := int32(0)
todoCount := int32(0)
undoCount := int32(0)
pinnedMemos := []string{}
totalMemoCount := int32(0)
limit := 1000
offset := 0
memoFind.Limit = &limit
memoFind.Offset = &offset
for {
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
if len(memos) == 0 {
break
}
totalMemoCount += int32(len(memos))
for _, memo := range memos {
displayTs := memo.CreatedTs
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
displayTimestamps = append(displayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
// Count different memo types based on content.
if memo.Payload != nil {
for _, tag := range memo.Payload.Tags {
tagCount[tag]++
}
if memo.Payload.Property != nil {
if memo.Payload.Property.HasLink {
linkCount++
}
if memo.Payload.Property.HasCode {
codeCount++
}
if memo.Payload.Property.HasTaskList {
todoCount++
}
if memo.Payload.Property.HasIncompleteTasks {
undoCount++
}
}
}
if memo.Pinned {
pinnedMemos = append(pinnedMemos, fmt.Sprintf("users/%d/memos/%d", userID, memo.ID))
}
}
offset += limit
}
userStats := &v1pb.UserStats{
Name: fmt.Sprintf("users/%d/stats", userID),
MemoDisplayTimestamps: displayTimestamps,
TagCount: tagCount,
PinnedMemos: pinnedMemos,
TotalMemoCount: totalMemoCount,
MemoTypeStats: &v1pb.UserStats_MemoTypeStats{
LinkCount: linkCount,
CodeCount: codeCount,
TodoCount: todoCount,
UndoCount: undoCount,
},
}
return userStats, nil
}

175
server/router/api/v1/v1.go Normal file
View File

@@ -0,0 +1,175 @@
package v1
import (
"context"
"net/http"
"connectrpc.com/connect"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"golang.org/x/sync/semaphore"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/plugin/markdown"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
type APIV1Service struct {
v1pb.UnimplementedInstanceServiceServer
v1pb.UnimplementedAuthServiceServer
v1pb.UnimplementedUserServiceServer
v1pb.UnimplementedMemoServiceServer
v1pb.UnimplementedAttachmentServiceServer
v1pb.UnimplementedShortcutServiceServer
v1pb.UnimplementedActivityServiceServer
v1pb.UnimplementedIdentityProviderServiceServer
Secret string
Profile *profile.Profile
Store *store.Store
MarkdownService markdown.Service
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
thumbnailSemaphore *semaphore.Weighted
}
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service {
markdownService := markdown.NewService(
markdown.WithTagExtension(),
)
return &APIV1Service{
Secret: secret,
Profile: profile,
Store: store,
MarkdownService: markdownService,
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
}
}
// RegisterGateway registers the gRPC-Gateway and Connect handlers with the given Echo instance.
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
// Auth middleware for gRPC-Gateway - runs after routing, has access to method name.
// Uses the same PublicMethods config as the Connect AuthInterceptor.
authenticator := auth.NewAuthenticator(s.Store, s.Secret)
gatewayAuthMiddleware := func(next runtime.HandlerFunc) runtime.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
ctx := r.Context()
// Get the RPC method name from context (set by grpc-gateway after routing)
rpcMethod, ok := runtime.RPCMethod(ctx)
// Extract credentials from HTTP headers
authHeader := r.Header.Get("Authorization")
result := authenticator.Authenticate(ctx, authHeader)
// Enforce authentication for non-public methods
// If rpcMethod cannot be determined, allow through, service layer will handle visibility checks
if result == nil && ok && !IsPublicMethod(rpcMethod) {
http.Error(w, `{"code": 16, "message": "authentication required"}`, http.StatusUnauthorized)
return
}
// Apply auth result to context (no-op when result is nil for public endpoints)
if result != nil {
ctx = auth.ApplyToContext(ctx, result)
r = r.WithContext(ctx)
}
next(w, r, pathParams)
}
}
// Create gRPC-Gateway mux with auth middleware.
gwMux := runtime.NewServeMux(
runtime.WithMiddlewares(gatewayAuthMiddleware),
)
if err := v1pb.RegisterInstanceServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterAuthServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterUserServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterMemoServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterAttachmentServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterShortcutServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterActivityServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterIdentityProviderServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
gwGroup := echoServer.Group("")
gwGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
handler := echo.WrapHandler(gwMux)
gwGroup.Any("/api/v1/*", handler)
gwGroup.Any("/file/*", handler)
// Connect handlers for browser clients (replaces grpc-web).
logStacktraces := s.Profile.Demo
connectInterceptors := connect.WithInterceptors(
NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first
NewLoggingInterceptor(logStacktraces),
NewRecoveryInterceptor(logStacktraces),
NewAuthInterceptor(s.Store, s.Secret),
)
connectMux := http.NewServeMux()
connectHandler := NewConnectServiceHandler(s)
connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors)
// Wrap with CORS for browser access
corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{
UnsafeAllowOriginFunc: func(_ *echo.Context, origin string) (string, bool, error) {
return origin, true, nil
},
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions},
AllowHeaders: []string{"*"},
AllowCredentials: true,
})
connectGroup := echoServer.Group("", corsHandler)
connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(connectMux))
// Register AI REST endpoints (direct HTTP, no Connect/gRPC required)
// Apply auth middleware so user context is populated (tries Bearer token then cookie)
aiAuthenticator := auth.NewAuthenticator(s.Store, s.Secret)
aiGroup := echoServer.Group("", func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
ctx := c.Request().Context()
authHeader := c.Request().Header.Get("Authorization")
cookieHeader := c.Request().Header.Get("Cookie")
result := aiAuthenticator.Authenticate(ctx, authHeader)
if result == nil && cookieHeader != "" {
// Try cookie-based auth (refresh token)
user, err := aiAuthenticator.AuthenticateToUser(ctx, authHeader, cookieHeader)
if err == nil && user != nil {
ctx = auth.SetUserInContext(ctx, user, "")
c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}
}
if result != nil {
ctx = auth.ApplyToContext(ctx, result)
c.SetRequest(c.Request().WithContext(ctx))
}
return next(c)
}
})
s.RegisterAIHTTPHandlers(aiGroup)
return nil
}

View File

@@ -0,0 +1,307 @@
# Fileserver Package
## Overview
The `fileserver` package handles all binary file serving for Memos using native HTTP handlers. It was created to replace gRPC-based binary serving, which had limitations with HTTP range requests (required for Safari video/audio playback).
## Responsibilities
- Serve attachment binary files (images, videos, audio, documents)
- Serve user avatar images
- Handle HTTP range requests for video/audio streaming
- Authenticate requests using JWT tokens or Personal Access Tokens
- Check permissions for private content
- Generate and serve image thumbnails
- Prevent XSS attacks on uploaded content
- Support S3 external storage
## Architecture
### Design Principles
1. **Separation of Concerns**: Binary files via HTTP, metadata via gRPC
2. **DRY**: Imports auth constants from `api/v1` package (single source of truth)
3. **Security First**: Authentication, authorization, and XSS prevention
4. **Performance**: Native HTTP streaming with proper caching headers
### Package Structure
```
fileserver/
├── fileserver.go # Main service and HTTP handlers
├── README.md # This file
└── fileserver_test.go # Tests (to be added)
```
## API Endpoints
### 1. Attachment Binary
```
GET /file/attachments/:uid/:filename[?thumbnail=true]
```
**Parameters:**
- `uid` - Attachment unique identifier
- `filename` - Original filename
- `thumbnail` (optional) - Return thumbnail for images
**Authentication:** Required for non-public memos
**Response:**
- `200 OK` - File content with proper Content-Type
- `206 Partial Content` - For range requests (video/audio)
- `401 Unauthorized` - Authentication required
- `403 Forbidden` - User not authorized
- `404 Not Found` - Attachment not found
**Headers:**
- `Content-Type` - MIME type of the file
- `Cache-Control: public, max-age=3600`
- `Accept-Ranges: bytes` - For video/audio
- `Content-Range` - For partial responses (206)
### 2. User Avatar
```
GET /file/users/:identifier/avatar
```
**Parameters:**
- `identifier` - User ID (e.g., `1`) or username (e.g., `steven`)
**Authentication:** Not required (avatars are public)
**Response:**
- `200 OK` - Avatar image (PNG/JPEG)
- `404 Not Found` - User not found or no avatar set
**Headers:**
- `Content-Type` - image/png or image/jpeg
- `Cache-Control: public, max-age=3600`
## Authentication
### Supported Methods
The fileserver supports the following authentication methods:
1. **JWT Access Token** (`Authorization: Bearer {token}`)
- Short-lived tokens (15 minutes) for API access
- Stateless validation using JWT signature
- Extracts user ID from token claims
2. **Personal Access Token (PAT)** (`Authorization: Bearer {pat}`)
- Long-lived tokens for programmatic access
- Validates against database for revocation
- Prefixed with specific identifier
### Authentication Flow
```
Request → getCurrentUser()
├─→ Try Session Cookie
│ ├─→ Parse cookie value
│ ├─→ Get user from DB
│ ├─→ Validate session
│ └─→ Return user (if valid)
└─→ Try JWT Token
├─→ Parse Authorization header
├─→ Verify JWT signature
├─→ Get user from DB
├─→ Validate token in access tokens list
└─→ Return user (if valid)
```
### Permission Model
**Attachments:**
- Unlinked: Public (no auth required)
- Public memo: Public (no auth required)
- Protected memo: Requires authentication
- Private memo: Creator only
**Avatars:**
- Always public (no auth required)
## Key Functions
### HTTP Handlers
#### `serveAttachmentFile(c echo.Context) error`
Main handler for attachment binary serving.
**Flow:**
1. Extract UID from URL parameter
2. Fetch attachment from database
3. Check permissions (memo visibility)
4. Get binary blob (local file, S3, or database)
5. Handle thumbnail request (if applicable)
6. Set security headers (XSS prevention)
7. Serve with range request support (video/audio)
#### `serveUserAvatar(c echo.Context) error`
Main handler for user avatar serving.
**Flow:**
1. Extract identifier (ID or username) from URL
2. Lookup user in database
3. Check if avatar exists
4. Decode base64 data URI
5. Serve with proper content type and caching
### Authentication
#### `getCurrentUser(ctx, c) (*store.User, error)`
Authenticates request using session cookie or JWT token.
#### `authenticateBySession(ctx, cookie) (*store.User, error)`
Validates session cookie and returns authenticated user.
#### `authenticateByJWT(ctx, token) (*store.User, error)`
Validates JWT access token and returns authenticated user.
### Permission Checks
#### `checkAttachmentPermission(ctx, c, attachment) error`
Validates user has permission to access attachment based on memo visibility.
### File Operations
#### `getAttachmentBlob(attachment) ([]byte, error)`
Retrieves binary content from local storage, S3, or database.
#### `getOrGenerateThumbnail(ctx, attachment) ([]byte, error)`
Returns cached thumbnail or generates new one (with semaphore limiting).
### Utilities
#### `getUserByIdentifier(ctx, identifier) (*store.User, error)`
Finds user by ID (int) or username (string).
#### `extractImageInfo(dataURI) (type, base64, error)`
Parses data URI to extract MIME type and base64 data.
## Dependencies
### External Packages
- `github.com/labstack/echo/v4` - HTTP router and middleware
- `github.com/golang-jwt/jwt/v5` - JWT parsing and validation
- `github.com/disintegration/imaging` - Image thumbnail generation
- `golang.org/x/sync/semaphore` - Concurrency control for thumbnails
### Internal Packages
- `server/auth` - Authentication utilities
- `store` - Database operations
- `internal/profile` - Server configuration
- `plugin/storage/s3` - S3 storage client
## Configuration
### Constants
Auth-related constants are imported from `server/auth`:
- `auth.RefreshTokenCookieName` - "memos_refresh"
- `auth.PersonalAccessTokenPrefix` - PAT identifier prefix
Package-specific constants:
- `ThumbnailCacheFolder` - ".thumbnail_cache"
- `thumbnailMaxSize` - 600px
- `SupportedThumbnailMimeTypes` - ["image/png", "image/jpeg"]
## Error Handling
All handlers return Echo HTTP errors with appropriate status codes:
```go
// Bad request
echo.NewHTTPError(http.StatusBadRequest, "message")
// Unauthorized (no auth)
echo.NewHTTPError(http.StatusUnauthorized, "message")
// Forbidden (auth but no permission)
echo.NewHTTPError(http.StatusForbidden, "message")
// Not found
echo.NewHTTPError(http.StatusNotFound, "message")
// Internal error
echo.NewHTTPError(http.StatusInternalServerError, "message").SetInternal(err)
```
## Security Considerations
### 1. XSS Prevention
SVG and HTML files are served as `application/octet-stream` to prevent script execution:
```go
if contentType == "image/svg+xml" ||
contentType == "text/html" ||
contentType == "application/xhtml+xml" {
contentType = "application/octet-stream"
}
```
### 2. Authentication
Private content requires valid JWT access token or Personal Access Token.
### 3. Authorization
Memo visibility rules enforced before serving attachments.
### 4. Input Validation
- Attachment UID validated from database
- User identifier validated (ID or username)
- Range requests validated before processing
## Performance Optimizations
### 1. Thumbnail Caching
Thumbnails cached on disk to avoid regeneration:
- Cache location: `{data_dir}/.thumbnail_cache/`
- Filename: `{attachment_id}{extension}`
- Semaphore limits concurrent generation (max 3)
### 2. HTTP Range Requests
Video/audio files use `http.ServeContent()` for efficient streaming:
- Automatic range parsing
- Efficient memory usage (streaming, not loading full file)
- Safari-compatible partial content responses
### 3. Caching Headers
All responses include cache headers:
```
Cache-Control: public, max-age=3600
```
### 4. S3 External Links
S3 files served via presigned URLs (no server download).
## Testing
### Unit Tests (To Add)
See SAFARI_FIX.md for recommended test coverage.
### Manual Testing
```bash
# Test attachment
curl "http://localhost:8081/file/attachments/{uid}/file.jpg"
# Test avatar by ID
curl "http://localhost:8081/file/users/1/avatar"
# Test avatar by username
curl "http://localhost:8081/file/users/steven/avatar"
# Test range request
curl -H "Range: bytes=0-999" "http://localhost:8081/file/attachments/{uid}/video.mp4"
```
## Future Improvements
See SAFARI_FIX.md section "Future Improvements" for planned enhancements.
## Related Documentation
- [SAFARI_FIX.md](../../../SAFARI_FIX.md) - Full migration guide
- [server/router/api/v1/auth.go](../api/v1/auth.go) - Auth constants source of truth
- [RFC 7233](https://tools.ietf.org/html/rfc7233) - HTTP Range Requests spec

View File

@@ -0,0 +1,587 @@
package fileserver
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/disintegration/imaging"
"github.com/labstack/echo/v5"
"github.com/pkg/errors"
"golang.org/x/sync/semaphore"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/storage/s3"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// Constants for file serving configuration.
const (
// ThumbnailCacheFolder is the folder name where thumbnail images are stored.
ThumbnailCacheFolder = ".thumbnail_cache"
// thumbnailMaxSize is the maximum dimension (width or height) for thumbnails.
thumbnailMaxSize = 600
// maxConcurrentThumbnails limits concurrent thumbnail generation to prevent memory exhaustion.
maxConcurrentThumbnails = 3
// cacheMaxAge is the max-age value for Cache-Control headers (1 hour).
cacheMaxAge = "public, max-age=3600"
)
// xssUnsafeTypes contains MIME types that could execute scripts if served directly.
// These are served as application/octet-stream to prevent XSS attacks.
var xssUnsafeTypes = map[string]bool{
"text/html": true,
"text/javascript": true,
"application/javascript": true,
"application/x-javascript": true,
"text/xml": true,
"application/xml": true,
"application/xhtml+xml": true,
"image/svg+xml": true,
}
// thumbnailSupportedTypes contains image MIME types that support thumbnail generation.
var thumbnailSupportedTypes = map[string]bool{
"image/png": true,
"image/jpeg": true,
"image/heic": true,
"image/heif": true,
"image/webp": true,
}
// avatarAllowedTypes contains MIME types allowed for user avatars.
var avatarAllowedTypes = map[string]bool{
"image/png": true,
"image/jpeg": true,
"image/jpg": true,
"image/gif": true,
"image/webp": true,
"image/heic": true,
"image/heif": true,
}
// SupportedThumbnailMimeTypes is the exported list of thumbnail-supported MIME types.
var SupportedThumbnailMimeTypes = []string{
"image/png",
"image/jpeg",
"image/heic",
"image/heif",
"image/webp",
}
// dataURIRegex parses data URI format: data:image/png;base64,iVBORw0KGgo...
var dataURIRegex = regexp.MustCompile(`^data:(?P<type>[^;]+);base64,(?P<base64>.+)`)
// FileServerService handles HTTP file serving with proper range request support.
type FileServerService struct {
Profile *profile.Profile
Store *store.Store
authenticator *auth.Authenticator
// thumbnailSemaphore limits concurrent thumbnail generation.
thumbnailSemaphore *semaphore.Weighted
}
// NewFileServerService creates a new file server service.
func NewFileServerService(profile *profile.Profile, store *store.Store, secret string) *FileServerService {
return &FileServerService{
Profile: profile,
Store: store,
authenticator: auth.NewAuthenticator(store, secret),
thumbnailSemaphore: semaphore.NewWeighted(maxConcurrentThumbnails),
}
}
// RegisterRoutes registers HTTP file serving routes.
func (s *FileServerService) RegisterRoutes(echoServer *echo.Echo) {
fileGroup := echoServer.Group("/file")
fileGroup.GET("/attachments/:uid/:filename", s.serveAttachmentFile)
fileGroup.GET("/users/:identifier/avatar", s.serveUserAvatar)
}
// =============================================================================
// HTTP Handlers
// =============================================================================
// serveAttachmentFile serves attachment binary content using native HTTP.
func (s *FileServerService) serveAttachmentFile(c *echo.Context) error {
ctx := c.Request().Context()
uid := c.Param("uid")
wantThumbnail := c.QueryParam("thumbnail") == "true"
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
UID: &uid,
GetBlob: true,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment").Wrap(err)
}
if attachment == nil {
return echo.NewHTTPError(http.StatusNotFound, "attachment not found")
}
if err := s.checkAttachmentPermission(ctx, c, attachment); err != nil {
return err
}
contentType := s.sanitizeContentType(attachment.Type)
// Stream video/audio to avoid loading entire file into memory.
if isMediaType(attachment.Type) {
return s.serveMediaStream(c, attachment, contentType)
}
return s.serveStaticFile(c, attachment, contentType, wantThumbnail)
}
// serveUserAvatar serves user avatar images.
func (s *FileServerService) serveUserAvatar(c *echo.Context) error {
ctx := c.Request().Context()
identifier := c.Param("identifier")
user, err := s.getUserByIdentifier(ctx, identifier)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user").Wrap(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusNotFound, "user not found")
}
if user.AvatarURL == "" {
return echo.NewHTTPError(http.StatusNotFound, "avatar not found")
}
imageType, imageData, err := s.parseDataURI(user.AvatarURL)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to parse avatar data").Wrap(err)
}
if !avatarAllowedTypes[imageType] {
return echo.NewHTTPError(http.StatusBadRequest, "invalid avatar image type")
}
setSecurityHeaders(c)
c.Response().Header().Set(echo.HeaderContentType, imageType)
c.Response().Header().Set(echo.HeaderCacheControl, cacheMaxAge)
return c.Blob(http.StatusOK, imageType, imageData)
}
// =============================================================================
// File Serving Methods
// =============================================================================
// serveMediaStream serves video/audio files using streaming to avoid memory exhaustion.
func (s *FileServerService) serveMediaStream(c *echo.Context, attachment *store.Attachment, contentType string) error {
setSecurityHeaders(c)
setMediaHeaders(c, contentType, attachment.Type)
switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
filePath, err := s.resolveLocalPath(attachment.Reference)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to resolve file path").Wrap(err)
}
http.ServeFile(c.Response(), c.Request(), filePath)
return nil
case storepb.AttachmentStorageType_S3:
presignURL, err := s.getS3PresignedURL(c.Request().Context(), attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to generate presigned URL").Wrap(err)
}
return c.Redirect(http.StatusTemporaryRedirect, presignURL)
default:
// Database storage fallback.
modTime := time.Unix(attachment.UpdatedTs, 0)
http.ServeContent(c.Response(), c.Request(), attachment.Filename, modTime, bytes.NewReader(attachment.Blob))
return nil
}
}
// serveStaticFile serves non-streaming files (images, documents, etc.).
func (s *FileServerService) serveStaticFile(c *echo.Context, attachment *store.Attachment, contentType string, wantThumbnail bool) error {
blob, err := s.getAttachmentBlob(attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").Wrap(err)
}
// Generate thumbnail for supported image types.
if wantThumbnail && thumbnailSupportedTypes[attachment.Type] {
if thumbnailBlob, err := s.getOrGenerateThumbnail(c.Request().Context(), attachment); err != nil {
slog.Warn("failed to get thumbnail", "error", err)
} else {
blob = thumbnailBlob
}
}
setSecurityHeaders(c)
setMediaHeaders(c, contentType, attachment.Type)
// Force download for non-media files to prevent XSS execution.
if !strings.HasPrefix(contentType, "image/") && contentType != "application/pdf" {
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%q", attachment.Filename))
}
return c.Blob(http.StatusOK, contentType, blob)
}
// =============================================================================
// Storage Operations
// =============================================================================
// getAttachmentBlob retrieves the binary content of an attachment from storage.
func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
return s.readLocalFile(attachment.Reference)
case storepb.AttachmentStorageType_S3:
return s.downloadFromS3(attachment)
default:
return attachment.Blob, nil
}
}
// getAttachmentReader returns a reader for streaming attachment content.
func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) {
switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
filePath, err := s.resolveLocalPath(attachment.Reference)
if err != nil {
return nil, err
}
file, err := os.Open(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open file")
}
return file, nil
case storepb.AttachmentStorageType_S3:
s3Client, s3Object, err := s.createS3Client(attachment)
if err != nil {
return nil, err
}
reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to stream from S3")
}
return reader, nil
default:
return io.NopCloser(bytes.NewReader(attachment.Blob)), nil
}
}
// resolveLocalPath converts a storage reference to an absolute file path.
func (s *FileServerService) resolveLocalPath(reference string) (string, error) {
filePath := filepath.FromSlash(reference)
if !filepath.IsAbs(filePath) {
filePath = filepath.Join(s.Profile.Data, filePath)
}
return filePath, nil
}
// readLocalFile reads the entire contents of a local file.
func (s *FileServerService) readLocalFile(reference string) ([]byte, error) {
filePath, err := s.resolveLocalPath(reference)
if err != nil {
return nil, err
}
file, err := os.Open(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open file")
}
defer file.Close()
blob, err := io.ReadAll(file)
if err != nil {
return nil, errors.Wrap(err, "failed to read file")
}
return blob, nil
}
// createS3Client creates an S3 client from attachment payload.
func (*FileServerService) createS3Client(attachment *store.Attachment) (*s3.Client, *storepb.AttachmentPayload_S3Object, error) {
if attachment.Payload == nil {
return nil, nil, errors.New("attachment payload is missing")
}
s3Object := attachment.Payload.GetS3Object()
if s3Object == nil {
return nil, nil, errors.New("S3 object payload is missing")
}
if s3Object.S3Config == nil {
return nil, nil, errors.New("S3 config is missing")
}
if s3Object.Key == "" {
return nil, nil, errors.New("S3 object key is missing")
}
client, err := s3.NewClient(context.Background(), s3Object.S3Config)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create S3 client")
}
return client, s3Object, nil
}
// downloadFromS3 downloads the entire object from S3.
func (s *FileServerService) downloadFromS3(attachment *store.Attachment) ([]byte, error) {
client, s3Object, err := s.createS3Client(attachment)
if err != nil {
return nil, err
}
blob, err := client.GetObject(context.Background(), s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to download from S3")
}
return blob, nil
}
// getS3PresignedURL generates a presigned URL for direct S3 access.
func (s *FileServerService) getS3PresignedURL(ctx context.Context, attachment *store.Attachment) (string, error) {
client, s3Object, err := s.createS3Client(attachment)
if err != nil {
return "", err
}
url, err := client.PresignGetObject(ctx, s3Object.Key)
if err != nil {
return "", errors.Wrap(err, "failed to presign URL")
}
return url, nil
}
// =============================================================================
// Thumbnail Generation
// =============================================================================
// getOrGenerateThumbnail returns the thumbnail image of the attachment.
// Uses semaphore to limit concurrent thumbnail generation and prevent memory exhaustion.
func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachment *store.Attachment) ([]byte, error) {
thumbnailPath, err := s.getThumbnailPath(attachment)
if err != nil {
return nil, err
}
// Fast path: return cached thumbnail if exists.
if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil {
return blob, nil
}
// Acquire semaphore to limit concurrent generation.
if err := s.thumbnailSemaphore.Acquire(ctx, 1); err != nil {
return nil, errors.Wrap(err, "failed to acquire semaphore")
}
defer s.thumbnailSemaphore.Release(1)
// Double-check after acquiring semaphore (another goroutine may have generated it).
if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil {
return blob, nil
}
return s.generateThumbnail(attachment, thumbnailPath)
}
// getThumbnailPath returns the file path for a cached thumbnail.
func (s *FileServerService) getThumbnailPath(attachment *store.Attachment) (string, error) {
cacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder)
if err := os.MkdirAll(cacheFolder, os.ModePerm); err != nil {
return "", errors.Wrap(err, "failed to create thumbnail cache folder")
}
filename := fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename))
return filepath.Join(cacheFolder, filename), nil
}
// readCachedThumbnail reads a thumbnail from the cache directory.
func (*FileServerService) readCachedThumbnail(path string) ([]byte, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
return io.ReadAll(file)
}
// generateThumbnail creates a new thumbnail and saves it to disk.
func (s *FileServerService) generateThumbnail(attachment *store.Attachment, thumbnailPath string) ([]byte, error) {
reader, err := s.getAttachmentReader(attachment)
if err != nil {
return nil, errors.Wrap(err, "failed to get attachment reader")
}
defer reader.Close()
img, err := imaging.Decode(reader, imaging.AutoOrientation(true))
if err != nil {
return nil, errors.Wrap(err, "failed to decode image")
}
width, height := img.Bounds().Dx(), img.Bounds().Dy()
thumbnailWidth, thumbnailHeight := calculateThumbnailDimensions(width, height)
thumbnailImage := imaging.Resize(img, thumbnailWidth, thumbnailHeight, imaging.Lanczos)
if err := imaging.Save(thumbnailImage, thumbnailPath); err != nil {
return nil, errors.Wrap(err, "failed to save thumbnail")
}
return s.readCachedThumbnail(thumbnailPath)
}
// calculateThumbnailDimensions calculates the target dimensions for a thumbnail.
// The largest dimension is constrained to thumbnailMaxSize while maintaining aspect ratio.
// Small images are not enlarged.
func calculateThumbnailDimensions(width, height int) (int, int) {
if max(width, height) <= thumbnailMaxSize {
return width, height
}
if width >= height {
return thumbnailMaxSize, 0 // Landscape: constrain width.
}
return 0, thumbnailMaxSize // Portrait: constrain height.
}
// =============================================================================
// Authentication & Authorization
// =============================================================================
// checkAttachmentPermission verifies the user has permission to access the attachment.
func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c *echo.Context, attachment *store.Attachment) error {
// For unlinked attachments, only the creator can access.
if attachment.MemoID == nil {
user, err := s.getCurrentUser(ctx, c)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").Wrap(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access")
}
if user.ID != attachment.CreatorID && user.Role != store.RoleAdmin {
return echo.NewHTTPError(http.StatusForbidden, "forbidden access")
}
return nil
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to find memo").Wrap(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, "memo not found")
}
if memo.Visibility == store.Public {
return nil
}
user, err := s.getCurrentUser(ctx, c)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").Wrap(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access")
}
if memo.Visibility == store.Private && user.ID != memo.CreatorID && user.Role != store.RoleAdmin {
return echo.NewHTTPError(http.StatusForbidden, "forbidden access")
}
return nil
}
// getCurrentUser retrieves the current authenticated user from the request.
// Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie.
func (s *FileServerService) getCurrentUser(ctx context.Context, c *echo.Context) (*store.User, error) {
authHeader := c.Request().Header.Get(echo.HeaderAuthorization)
cookieHeader := c.Request().Header.Get("Cookie")
return s.authenticator.AuthenticateToUser(ctx, authHeader, cookieHeader)
}
// getUserByIdentifier finds a user by either ID or username.
func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) {
if userID, err := util.ConvertStringToInt32(identifier); err == nil {
return s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
}
return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier})
}
// =============================================================================
// Helper Functions
// =============================================================================
// sanitizeContentType converts potentially dangerous MIME types to safe alternatives.
func (*FileServerService) sanitizeContentType(mimeType string) string {
contentType := mimeType
if strings.HasPrefix(contentType, "text/") {
contentType += "; charset=utf-8"
}
// Normalize for case-insensitive lookup.
if xssUnsafeTypes[strings.ToLower(mimeType)] {
return "application/octet-stream"
}
return contentType
}
// parseDataURI extracts MIME type and decoded data from a data URI.
func (*FileServerService) parseDataURI(dataURI string) (string, []byte, error) {
matches := dataURIRegex.FindStringSubmatch(dataURI)
if len(matches) != 3 {
return "", nil, errors.New("invalid data URI format")
}
imageType := matches[1]
imageData, err := base64.StdEncoding.DecodeString(matches[2])
if err != nil {
return "", nil, errors.Wrap(err, "failed to decode base64 data")
}
return imageType, imageData, nil
}
// isMediaType checks if the MIME type is video or audio.
func isMediaType(mimeType string) bool {
return strings.HasPrefix(mimeType, "video/") || strings.HasPrefix(mimeType, "audio/")
}
// setSecurityHeaders sets common security headers for all responses.
func setSecurityHeaders(c *echo.Context) {
h := c.Response().Header()
h.Set("X-Content-Type-Options", "nosniff")
h.Set("X-Frame-Options", "DENY")
h.Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';")
}
// setMediaHeaders sets headers for media file responses.
func setMediaHeaders(c *echo.Context, contentType, originalType string) {
h := c.Response().Header()
h.Set(echo.HeaderContentType, contentType)
h.Set(echo.HeaderCacheControl, cacheMaxAge)
// Support HDR/wide color gamut for images and videos.
if strings.HasPrefix(originalType, "image/") || strings.HasPrefix(originalType, "video/") {
h.Set("Color-Gamut", "srgb, p3, rec2020")
}
}

View File

@@ -0,0 +1,68 @@
package frontend
import (
"context"
"embed"
"io/fs"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/store"
)
//go:embed dist/*
var embeddedFiles embed.FS
type FrontendService struct {
Profile *profile.Profile
Store *store.Store
}
func NewFrontendService(profile *profile.Profile, store *store.Store) *FrontendService {
return &FrontendService{
Profile: profile,
Store: store,
}
}
func (*FrontendService) Serve(_ context.Context, e *echo.Echo) {
skipper := func(c *echo.Context) bool {
// Skip API routes.
if util.HasPrefixes(c.Path(), "/api", "/memos.api.v1") {
return true
}
// For index.html and root path, set no-cache headers to prevent browser caching
// This prevents sensitive data from being accessible via browser back button after logout
if c.Path() == "/" || c.Path() == "/index.html" {
c.Response().Header().Set(echo.HeaderCacheControl, "no-cache, no-store, must-revalidate")
c.Response().Header().Set("Pragma", "no-cache")
c.Response().Header().Set("Expires", "0")
return false
}
// Set Cache-Control header for static assets.
// Since Vite generates content-hashed filenames (e.g., index-BtVjejZf.js),
// we can cache aggressively but use immutable to prevent revalidation checks.
// For frequently redeployed instances, use shorter max-age (1 hour) to avoid
// serving stale assets after redeployment.
c.Response().Header().Set(echo.HeaderCacheControl, "public, max-age=3600, immutable") // 1 hour
return false
}
// Route to serve the main app with HTML5 fallback for SPA behavior.
e.Use(middleware.StaticWithConfig(middleware.StaticConfig{
Filesystem: getFileSystem("dist"),
HTML5: true, // Enable fallback to index.html
Skipper: skipper,
}))
}
func getFileSystem(path string) fs.FS {
sub, err := fs.Sub(embeddedFiles, path)
if err != nil {
panic(err)
}
return sub
}

View File

@@ -0,0 +1,66 @@
# MCP Server
This package implements a [Model Context Protocol (MCP)](https://modelcontextprotocol.io) server embedded in the Memos HTTP process. It exposes memo operations as MCP tools, making Memos accessible to any MCP-compatible AI client (Claude Desktop, Cursor, Zed, etc.).
## Endpoint
```
POST /mcp (tool calls, initialize)
GET /mcp (optional SSE stream for server-to-client messages)
```
Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26).
## Authentication
Every request must include a Personal Access Token (PAT):
```
Authorization: Bearer <your-PAT>
```
PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are not accepted. Requests without a valid PAT receive `HTTP 401`.
## Tools
All tools are scoped to the authenticated user's memos.
| Tool | Description | Required params | Optional params |
|---|---|---|---|
| `list_memos` | List memos | — | `page_size` (int, max 100), `filter` (CEL expression) |
| `get_memo` | Get a single memo | `name` | — |
| `search_memos` | Full-text search | `query` | — |
| `create_memo` | Create a memo | `content` | `visibility` |
| `update_memo` | Update content or visibility | `name` | `content`, `visibility` |
| `delete_memo` | Delete a memo | `name` | — |
**`name`** is the memo resource name, e.g. `memos/abc123`.
**`visibility`** accepts `PRIVATE` (default), `PROTECTED`, or `PUBLIC`.
**`filter`** accepts CEL expressions supported by the memo filter engine, e.g.:
- `content.contains("keyword")`
- `visibility == "PUBLIC"`
- `has_task_list`
## Connecting Claude Code
```bash
claude mcp add --transport http memos http://localhost:5230/mcp \
--header "Authorization: Bearer <your-PAT>"
```
Use `--scope user` to make it available across all projects:
```bash
claude mcp add --scope user --transport http memos http://localhost:5230/mcp \
--header "Authorization: Bearer <your-PAT>"
```
## Package Structure
| File | Responsibility |
|---|---|
| `mcp.go` | `MCPService` struct, constructor, route registration |
| `auth_middleware.go` | Echo middleware — validates Bearer token, sets user ID in context |
| `tools_memo.go` | Tool registration and six memo tool handlers |

56
server/router/mcp/mcp.go Normal file
View File

@@ -0,0 +1,56 @@
package mcp
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
type MCPService struct {
store *store.Store
authenticator *auth.Authenticator
}
func NewMCPService(store *store.Store, secret string) *MCPService {
return &MCPService{
store: store,
authenticator: auth.NewAuthenticator(store, secret),
}
}
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0",
mcpserver.WithToolCapabilities(false),
)
s.registerMemoTools(mcpSrv)
s.registerTagTools(mcpSrv)
s.registerMemoResources(mcpSrv)
s.registerPrompts(mcpSrv)
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv)
mcpGroup := echoServer.Group("")
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
authHeader := c.Request().Header.Get("Authorization")
if authHeader != "" {
result := s.authenticator.Authenticate(c.Request().Context(), authHeader)
if result == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired token"})
}
ctx := auth.ApplyToContext(c.Request().Context(), result)
c.SetRequest(c.Request().WithContext(ctx))
}
return next(c)
}
})
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
}

View File

@@ -0,0 +1,84 @@
package mcp
import (
"context"
"errors"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
)
func (s *MCPService) registerPrompts(mcpSrv *mcpserver.MCPServer) {
// capture — turns free-form user input into a structured create_memo call.
mcpSrv.AddPrompt(
mcp.NewPrompt("capture",
mcp.WithPromptDescription("Capture a thought, idea, or note as a new memo. "+
"Use this prompt when the user wants to quickly save something. "+
"The assistant will call create_memo with the provided content."),
mcp.WithArgument("content",
mcp.ArgumentDescription("The text to save as a memo"),
mcp.RequiredArgument(),
),
mcp.WithArgument("tags",
mcp.ArgumentDescription("Comma-separated tags to apply, e.g. \"work,project\""),
),
),
s.handleCapturePrompt,
)
// review — surfaces existing memos on a topic for summarisation.
mcpSrv.AddPrompt(
mcp.NewPrompt("review",
mcp.WithPromptDescription("Search and review memos on a given topic. "+
"The assistant will call search_memos and summarise the results."),
mcp.WithArgument("topic",
mcp.ArgumentDescription("Topic or keyword to search for"),
mcp.RequiredArgument(),
),
),
s.handleReviewPrompt,
)
}
func (*MCPService) handleCapturePrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
content := req.Params.Arguments["content"]
if content == "" {
return nil, errors.New("content argument is required")
}
tags := req.Params.Arguments["tags"]
instruction := fmt.Sprintf(
"Please save the following as a new private memo using the create_memo tool.\n\nContent:\n%s",
content,
)
if tags != "" {
instruction += fmt.Sprintf("\n\nAppend these tags inline using #tag syntax: %s", tags)
}
return &mcp.GetPromptResult{
Description: "Capture a memo",
Messages: []mcp.PromptMessage{
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
},
}, nil
}
func (*MCPService) handleReviewPrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
topic := req.Params.Arguments["topic"]
if topic == "" {
return nil, errors.New("topic argument is required")
}
instruction := fmt.Sprintf(
"Please use the search_memos tool to find memos about %q, then provide a concise summary of what has been written on this topic, grouped by theme. Include the memo names so the user can reference them.",
topic,
)
return &mcp.GetPromptResult{
Description: fmt.Sprintf("Review memos about %q", topic),
Messages: []mcp.PromptMessage{
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
},
}, nil
}

View File

@@ -0,0 +1,85 @@
package mcp
import (
"context"
"fmt"
"strings"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// Memo resource URI scheme: memo://memos/{uid}
// Clients can read any memo they have access to by URI without calling a tool.
func (s *MCPService) registerMemoResources(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddResourceTemplate(
mcp.NewResourceTemplate(
"memo://memos/{uid}",
"Memo",
mcp.WithTemplateDescription("A single Memos note identified by its UID. Returns the memo content as Markdown with a YAML frontmatter header containing metadata."),
mcp.WithTemplateMIMEType("text/markdown"),
),
s.handleReadMemoResource,
)
}
func (s *MCPService) handleReadMemoResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
userID := auth.GetUserID(ctx)
// URI format: memo://memos/{uid}
uid := strings.TrimPrefix(req.Params.URI, "memo://memos/")
if uid == req.Params.URI || uid == "" {
return nil, errors.Errorf("invalid memo URI %q: expected memo://memos/<uid>", req.Params.URI)
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
if memo == nil {
return nil, errors.Errorf("memo not found: %s", uid)
}
if err := checkMemoAccess(memo, userID); err != nil {
return nil, err
}
j := storeMemoToJSON(memo)
text := formatMemoMarkdown(j)
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "text/markdown",
Text: text,
},
}, nil
}
// formatMemoMarkdown renders a memo as Markdown with a YAML frontmatter header.
func formatMemoMarkdown(j memoJSON) string {
var sb strings.Builder
sb.WriteString("---\n")
fmt.Fprintf(&sb, "name: %s\n", j.Name)
fmt.Fprintf(&sb, "creator: %s\n", j.Creator)
fmt.Fprintf(&sb, "visibility: %s\n", j.Visibility)
fmt.Fprintf(&sb, "state: %s\n", j.State)
fmt.Fprintf(&sb, "pinned: %v\n", j.Pinned)
if len(j.Tags) > 0 {
fmt.Fprintf(&sb, "tags: [%s]\n", strings.Join(j.Tags, ", "))
}
fmt.Fprintf(&sb, "create_time: %d\n", j.CreateTime)
fmt.Fprintf(&sb, "update_time: %d\n", j.UpdateTime)
if j.Parent != "" {
fmt.Fprintf(&sb, "parent: %s\n", j.Parent)
}
sb.WriteString("---\n\n")
sb.WriteString(j.Content)
return sb.String()
}

View File

@@ -0,0 +1,599 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"github.com/lithammer/shortuuid/v4"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// tagRegexp matches #tag patterns in memo content.
// A tag must start with a letter and contain no whitespace or # characters.
var tagRegexp = regexp.MustCompile(`(?:^|\s)#([A-Za-z][^\s#]*)`)
// extractTags does a best-effort extraction of #tags from raw markdown content.
// It is used when creating or updating memos via MCP to pre-populate Payload.Tags.
// The full markdown service may later rebuild a more accurate payload.
func extractTags(content string) []string {
matches := tagRegexp.FindAllStringSubmatch(content, -1)
seen := make(map[string]struct{}, len(matches))
tags := make([]string, 0, len(matches))
for _, m := range matches {
tag := m[1]
if _, ok := seen[tag]; !ok {
seen[tag] = struct{}{}
tags = append(tags, tag)
}
}
return tags
}
// buildPayload constructs a MemoPayload with tags extracted from content.
// Returns nil when no tags are found so the store omits the payload entirely.
func buildPayload(content string) *storepb.MemoPayload {
tags := extractTags(content)
if len(tags) == 0 {
return nil
}
return &storepb.MemoPayload{Tags: tags}
}
// propertyJSON is the serialisable form of MemoPayload.Property.
type propertyJSON struct {
HasLink bool `json:"has_link"`
HasTaskList bool `json:"has_task_list"`
HasCode bool `json:"has_code"`
HasIncompleteTasks bool `json:"has_incomplete_tasks"`
}
// memoJSON is the canonical response shape for all MCP memo results.
// It serialises correctly with standard encoding/json (no proto marshalling needed).
type memoJSON struct {
Name string `json:"name"`
Creator string `json:"creator"`
CreateTime int64 `json:"create_time"`
UpdateTime int64 `json:"update_time"`
Content string `json:"content,omitempty"`
Visibility string `json:"visibility"`
Tags []string `json:"tags"`
Pinned bool `json:"pinned"`
State string `json:"state"`
Property *propertyJSON `json:"property,omitempty"`
Parent string `json:"parent,omitempty"`
}
func storeMemoToJSON(m *store.Memo) memoJSON {
j := memoJSON{
Name: "memos/" + m.UID,
Creator: fmt.Sprintf("users/%d", m.CreatorID),
CreateTime: m.CreatedTs,
UpdateTime: m.UpdatedTs,
Content: m.Content,
Visibility: string(m.Visibility),
Pinned: m.Pinned,
State: string(m.RowStatus),
Tags: []string{},
}
if m.Payload != nil {
if len(m.Payload.Tags) > 0 {
j.Tags = m.Payload.Tags
}
if p := m.Payload.Property; p != nil && (p.HasLink || p.HasTaskList || p.HasCode || p.HasIncompleteTasks) {
j.Property = &propertyJSON{
HasLink: p.HasLink,
HasTaskList: p.HasTaskList,
HasCode: p.HasCode,
HasIncompleteTasks: p.HasIncompleteTasks,
}
}
}
if m.ParentUID != nil {
j.Parent = "memos/" + *m.ParentUID
}
return j
}
// checkMemoAccess returns an error if the caller cannot read memo.
// userID == 0 means anonymous.
func checkMemoAccess(memo *store.Memo, userID int32) error {
switch memo.Visibility {
case store.Protected:
if userID == 0 {
return errors.New("permission denied")
}
case store.Private:
if memo.CreatorID != userID {
return errors.New("permission denied")
}
default:
// store.Public and any unknown visibility: allow
}
return nil
}
// applyVisibilityFilter restricts find to memos the caller may see.
func applyVisibilityFilter(find *store.FindMemo, userID int32) {
if userID == 0 {
find.VisibilityList = []store.Visibility{store.Public}
} else {
find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID))
}
}
// parseMemoUID extracts the UID from a "memos/<uid>" resource name.
func parseMemoUID(name string) (string, error) {
uid, ok := strings.CutPrefix(name, "memos/")
if !ok || uid == "" {
return "", errors.Errorf(`memo name must be in the format "memos/<uid>", got %q`, name)
}
return uid, nil
}
// parseVisibility validates a visibility string and returns the store constant.
func parseVisibility(s string) (store.Visibility, error) {
switch v := store.Visibility(s); v {
case store.Public, store.Protected, store.Private:
return v, nil
default:
return "", errors.Errorf("visibility must be PRIVATE, PROTECTED, or PUBLIC; got %q", s)
}
}
// parseRowStatus validates a state string and returns the store constant.
func parseRowStatus(s string) (store.RowStatus, error) {
switch rs := store.RowStatus(s); rs {
case store.Normal, store.Archived:
return rs, nil
default:
return "", errors.Errorf("state must be NORMAL or ARCHIVED; got %q", s)
}
}
func extractUserID(ctx context.Context) (int32, error) {
id := auth.GetUserID(ctx)
if id == 0 {
return 0, errors.New("unauthenticated: a personal access token is required")
}
return id, nil
}
func marshalJSON(v any) (string, error) {
b, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(b), nil
}
func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_memos",
mcp.WithDescription("List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos."),
mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
),
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
), s.handleListMemos)
mcpSrv.AddTool(mcp.NewTool("get_memo",
mcp.WithDescription("Get a single memo by resource name. Public memos are accessible without authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
), s.handleGetMemo)
mcpSrv.AddTool(mcp.NewTool("create_memo",
mcp.WithDescription("Create a new memo. Requires authentication."),
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")),
mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("Visibility (default: PRIVATE)"),
),
), s.handleCreateMemo)
mcpSrv.AddTool(mcp.NewTool("update_memo",
mcp.WithDescription("Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Description("New Markdown content")),
mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("New visibility"),
),
mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"),
),
), s.handleUpdateMemo)
mcpSrv.AddTool(mcp.NewTool("delete_memo",
mcp.WithDescription("Permanently delete a memo. Requires authentication and ownership."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
), s.handleDeleteMemo)
mcpSrv.AddTool(mcp.NewTool("search_memos",
mcp.WithDescription("Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only."),
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")),
), s.handleSearchMemos)
mcpSrv.AddTool(mcp.NewTool("list_memo_comments",
mcp.WithDescription("List comments on a memo. Visibility rules for comments match those of the parent memo."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
), s.handleListMemoComments)
mcpSrv.AddTool(mcp.NewTool("create_memo_comment",
mcp.WithDescription("Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")),
), s.handleCreateMemoComment)
}
func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
pageSize := req.GetInt("page_size", 20)
if pageSize <= 0 {
pageSize = 20
}
if pageSize > 100 {
pageSize = 100
}
page := req.GetInt("page", 0)
if page < 0 {
page = 0
}
var rowStatus *store.RowStatus
if state := req.GetString("state", "NORMAL"); state != "" {
rs, err := parseRowStatus(state)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
rowStatus = &rs
}
limit := pageSize + 1
offset := page * pageSize
find := &store.FindMemo{
ExcludeComments: true,
RowStatus: rowStatus,
Limit: &limit,
Offset: &offset,
OrderByPinned: req.GetBool("order_by_pinned", false),
}
applyVisibilityFilter(find, userID)
if filter := req.GetString("filter", ""); filter != "" {
find.Filters = append(find.Filters, filter)
}
memos, err := s.store.ListMemos(ctx, find)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
}
hasMore := len(memos) > pageSize
if hasMore {
memos = memos[:pageSize]
}
results := make([]memoJSON, len(memos))
for i, m := range memos {
results[i] = storeMemoToJSON(m)
}
type listResponse struct {
Memos []memoJSON `json:"memos"`
HasMore bool `json:"has_more"`
}
out, err := marshalJSON(listResponse{Memos: results, HasMore: hasMore})
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if memo == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoAccess(memo, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
out, err := marshalJSON(storeMemoToJSON(memo))
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
content := req.GetString("content", "")
if content == "" {
return mcp.NewToolResultError("content is required"), nil
}
visibility, err := parseVisibility(req.GetString("visibility", "PRIVATE"))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.CreateMemo(ctx, &store.Memo{
UID: shortuuid.New(),
CreatorID: userID,
Content: content,
Visibility: visibility,
Payload: buildPayload(content),
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil
}
out, err := marshalJSON(storeMemoToJSON(memo))
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if memo == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if memo.CreatorID != userID {
return mcp.NewToolResultError("permission denied"), nil
}
update := &store.UpdateMemo{ID: memo.ID}
args := req.GetArguments()
if v := req.GetString("content", ""); v != "" {
update.Content = &v
update.Payload = buildPayload(v)
}
if v := req.GetString("visibility", ""); v != "" {
vis, err := parseVisibility(v)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
update.Visibility = &vis
}
if v := req.GetString("state", ""); v != "" {
rs, err := parseRowStatus(v)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
update.RowStatus = &rs
}
if _, ok := args["pinned"]; ok {
pinned := req.GetBool("pinned", false)
update.Pinned = &pinned
}
if err := s.store.UpdateMemo(ctx, update); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil
}
updated, err := s.store.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil
}
out, err := marshalJSON(storeMemoToJSON(updated))
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if memo == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if memo.CreatorID != userID {
return mcp.NewToolResultError("permission denied"), nil
}
if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil
}
return mcp.NewToolResultText(`{"deleted":true}`), nil
}
func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
query := req.GetString("query", "")
if query == "" {
return mcp.NewToolResultError("query is required"), nil
}
limit := 50
zero := 0
rowStatus := store.Normal
find := &store.FindMemo{
ExcludeComments: true,
RowStatus: &rowStatus,
Limit: &limit,
Offset: &zero,
Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)},
}
applyVisibilityFilter(find, userID)
memos, err := s.store.ListMemos(ctx, find)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to search memos: %v", err)), nil
}
results := make([]memoJSON, len(memos))
for i, m := range memos {
results[i] = storeMemoToJSON(m)
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if parent == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoAccess(parent, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
relationType := store.MemoRelationComment
relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &parent.ID,
Type: &relationType,
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil
}
if len(relations) == 0 {
out, _ := marshalJSON([]memoJSON{})
return mcp.NewToolResultText(out), nil
}
commentIDs := make([]int32, len(relations))
for i, r := range relations {
commentIDs[i] = r.MemoID
}
memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: commentIDs})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list comments: %v", err)), nil
}
results := make([]memoJSON, 0, len(memos))
for _, m := range memos {
if checkMemoAccess(m, userID) == nil {
results = append(results, storeMemoToJSON(m))
}
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
content := req.GetString("content", "")
if content == "" {
return mcp.NewToolResultError("content is required"), nil
}
parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if parent == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoAccess(parent, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
comment, err := s.store.CreateMemo(ctx, &store.Memo{
UID: shortuuid.New(),
CreatorID: userID,
Content: content,
Visibility: parent.Visibility,
Payload: buildPayload(content),
ParentUID: &parent.UID,
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil
}
if _, err = s.store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: comment.ID,
RelatedMemoID: parent.ID,
Type: store.MemoRelationComment,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil
}
out, err := marshalJSON(storeMemoToJSON(comment))
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}

View File

@@ -0,0 +1,68 @@
package mcp
import (
"context"
"fmt"
"sort"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_tags",
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."),
), s.handleListTags)
}
type tagEntry struct {
Tag string `json:"tag"`
Count int `json:"count"`
}
func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
rowStatus := store.Normal
find := &store.FindMemo{
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &rowStatus,
}
applyVisibilityFilter(find, userID)
memos, err := s.store.ListMemos(ctx, find)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
}
counts := make(map[string]int)
for _, m := range memos {
if m.Payload == nil {
continue
}
for _, tag := range m.Payload.Tags {
counts[tag]++
}
}
entries := make([]tagEntry, 0, len(counts))
for tag, count := range counts {
entries = append(entries, tagEntry{Tag: tag, Count: count})
}
sort.Slice(entries, func(i, j int) bool {
if entries[i].Count != entries[j].Count {
return entries[i].Count > entries[j].Count
}
return entries[i].Tag < entries[j].Tag
})
out, err := marshalJSON(entries)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}

423
server/router/rss/rss.go Normal file
View File

@@ -0,0 +1,423 @@
package rss
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/gorilla/feeds"
"github.com/labstack/echo/v5"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/plugin/markdown"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
maxRSSItemCount = 100
defaultCacheDuration = 1 * time.Hour
maxCacheSize = 50 // Maximum number of cached feeds
)
var (
// Regex to match markdown headings at the start of a line.
markdownHeadingRegex = regexp.MustCompile(`^#{1,6}\s*`)
)
// cacheEntry represents a cached RSS feed with expiration.
type cacheEntry struct {
content string
etag string
lastModified time.Time
createdAt time.Time
}
type RSSService struct {
Profile *profile.Profile
Store *store.Store
MarkdownService markdown.Service
// Cache for RSS feeds
cache map[string]*cacheEntry
cacheMutex sync.RWMutex
}
type RSSHeading struct {
Title string
Description string
Language string
}
func NewRSSService(profile *profile.Profile, store *store.Store, markdownService markdown.Service) *RSSService {
return &RSSService{
Profile: profile,
Store: store,
MarkdownService: markdownService,
cache: make(map[string]*cacheEntry),
}
}
func (s *RSSService) RegisterRoutes(g *echo.Group) {
g.GET("/explore/rss.xml", s.GetExploreRSS)
g.GET("/u/:username/rss.xml", s.GetUserRSS)
}
func (s *RSSService) GetExploreRSS(c *echo.Context) error {
ctx := c.Request().Context()
cacheKey := "explore"
// Check cache first
if cached := s.getFromCache(cacheKey); cached != nil {
// Check ETag for conditional request
if c.Request().Header.Get("If-None-Match") == cached.etag {
return c.NoContent(http.StatusNotModified)
}
s.setRSSHeaders(c, cached.etag, cached.lastModified)
return c.String(http.StatusOK, cached.content)
}
normalStatus := store.Normal
limit := maxRSSItemCount
memoFind := store.FindMemo{
RowStatus: &normalStatus,
VisibilityList: []store.Visibility{store.Public},
Limit: &limit,
}
memoList, err := s.Store.ListMemos(ctx, &memoFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").Wrap(err)
}
baseURL := c.Scheme() + "://" + c.Request().Host
rss, lastModified, err := s.generateRSSFromMemoList(ctx, memoList, baseURL, nil)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate rss").Wrap(err)
}
// Cache the result
etag := s.putInCache(cacheKey, rss, lastModified)
s.setRSSHeaders(c, etag, lastModified)
return c.String(http.StatusOK, rss)
}
func (s *RSSService) GetUserRSS(c *echo.Context) error {
ctx := c.Request().Context()
username := c.Param("username")
cacheKey := "user:" + username
// Check cache first
if cached := s.getFromCache(cacheKey); cached != nil {
// Check ETag for conditional request
if c.Request().Header.Get("If-None-Match") == cached.etag {
return c.NoContent(http.StatusNotModified)
}
s.setRSSHeaders(c, cached.etag, cached.lastModified)
return c.String(http.StatusOK, cached.content)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").Wrap(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusNotFound, "User not found")
}
normalStatus := store.Normal
limit := maxRSSItemCount
memoFind := store.FindMemo{
CreatorID: &user.ID,
RowStatus: &normalStatus,
VisibilityList: []store.Visibility{store.Public},
Limit: &limit,
}
memoList, err := s.Store.ListMemos(ctx, &memoFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").Wrap(err)
}
baseURL := c.Scheme() + "://" + c.Request().Host
rss, lastModified, err := s.generateRSSFromMemoList(ctx, memoList, baseURL, user)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate rss").Wrap(err)
}
// Cache the result
etag := s.putInCache(cacheKey, rss, lastModified)
s.setRSSHeaders(c, etag, lastModified)
return c.String(http.StatusOK, rss)
}
func (s *RSSService) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string, user *store.User) (string, time.Time, error) {
rssHeading, err := getRSSHeading(ctx, s.Store)
if err != nil {
return "", time.Time{}, err
}
feed := &feeds.Feed{
Title: rssHeading.Title,
Link: &feeds.Link{Href: baseURL},
Description: rssHeading.Description,
Created: time.Now(),
}
var itemCountLimit = min(len(memoList), maxRSSItemCount)
if itemCountLimit == 0 {
// Return empty feed if no memos
rss, err := feed.ToRss()
return rss, time.Time{}, err
}
// Track the most recent update time for Last-Modified header
var lastModified time.Time
if len(memoList) > 0 {
lastModified = time.Unix(memoList[0].UpdatedTs, 0)
}
// Batch load all attachments for all memos to avoid N+1 query problem
memoIDs := make([]int32, itemCountLimit)
for i := 0; i < itemCountLimit; i++ {
memoIDs[i] = memoList[i].ID
}
allAttachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoIDList: memoIDs,
})
if err != nil {
return "", lastModified, err
}
// Group attachments by memo ID for quick lookup
attachmentsByMemoID := make(map[int32][]*store.Attachment)
for _, attachment := range allAttachments {
if attachment.MemoID != nil {
attachmentsByMemoID[*attachment.MemoID] = append(attachmentsByMemoID[*attachment.MemoID], attachment)
}
}
// Batch load all memo creators
creatorMap := make(map[int32]*store.User)
if user != nil {
// Single user feed - reuse the user object
creatorMap[user.ID] = user
} else {
// Multi-user feed - batch load all unique creators
creatorIDs := make(map[int32]bool)
for _, memo := range memoList[:itemCountLimit] {
creatorIDs[memo.CreatorID] = true
}
// Batch load all users with a single query by getting all users and filtering
// Note: This is more efficient than N separate queries
for creatorID := range creatorIDs {
creator, err := s.Store.GetUser(ctx, &store.FindUser{ID: &creatorID})
if err == nil && creator != nil {
creatorMap[creatorID] = creator
}
}
}
// Generate feed items
feed.Items = make([]*feeds.Item, itemCountLimit)
for i := 0; i < itemCountLimit; i++ {
memo := memoList[i]
// Generate item title from memo content
title := s.generateItemTitle(memo.Content)
// Render content as HTML
htmlContent, err := s.getRSSItemDescription(memo.Content)
if err != nil {
return "", lastModified, err
}
link := &feeds.Link{Href: baseURL + "/memos/" + memo.UID}
item := &feeds.Item{
Title: title,
Link: link,
Description: htmlContent, // Summary/excerpt
Content: htmlContent, // Full content in content:encoded
Created: time.Unix(memo.CreatedTs, 0),
Updated: time.Unix(memo.UpdatedTs, 0),
Id: link.Href,
}
// Add author information
if creator, ok := creatorMap[memo.CreatorID]; ok {
authorName := creator.Nickname
if authorName == "" {
authorName = creator.Username
}
item.Author = &feeds.Author{
Name: authorName,
Email: creator.Email,
}
}
// Note: gorilla/feeds doesn't support categories in RSS items
// Tags could be added to the description or content if needed
// Add first attachment as enclosure
if attachments, ok := attachmentsByMemoID[memo.ID]; ok && len(attachments) > 0 {
attachment := attachments[0]
enclosure := feeds.Enclosure{}
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
enclosure.Url = attachment.Reference
} else {
enclosure.Url = fmt.Sprintf("%s/file/attachments/%s/%s", baseURL, attachment.UID, attachment.Filename)
}
enclosure.Length = strconv.Itoa(int(attachment.Size))
enclosure.Type = attachment.Type
item.Enclosure = &enclosure
}
feed.Items[i] = item
}
rss, err := feed.ToRss()
if err != nil {
return "", lastModified, err
}
return rss, lastModified, nil
}
func (*RSSService) generateItemTitle(content string) string {
// Extract first line as title
lines := strings.Split(content, "\n")
title := strings.TrimSpace(lines[0])
// Remove markdown heading syntax using regex (handles # to ###### with optional spaces)
title = markdownHeadingRegex.ReplaceAllString(title, "")
title = strings.TrimSpace(title)
// Limit title length
const maxTitleLength = 100
if len(title) > maxTitleLength {
// Find last space before limit to avoid cutting words
cutoff := maxTitleLength
for i := min(maxTitleLength-1, len(title)-1); i > 0; i-- {
if title[i] == ' ' {
cutoff = i
break
}
}
if cutoff < maxTitleLength {
title = title[:cutoff] + "..."
} else {
// No space found, just truncate
title = title[:maxTitleLength] + "..."
}
}
// If title is empty, use a default
if title == "" {
title = "Memo"
}
return title
}
func (s *RSSService) getRSSItemDescription(content string) (string, error) {
html, err := s.MarkdownService.RenderHTML([]byte(content))
if err != nil {
return "", err
}
return html, nil
}
// getFromCache retrieves a cached feed entry if it exists and is not expired.
func (s *RSSService) getFromCache(key string) *cacheEntry {
s.cacheMutex.RLock()
entry, exists := s.cache[key]
s.cacheMutex.RUnlock()
if !exists {
return nil
}
// Check if cache entry is still valid
if time.Since(entry.createdAt) > defaultCacheDuration {
// Entry is expired, remove it
s.cacheMutex.Lock()
delete(s.cache, key)
s.cacheMutex.Unlock()
return nil
}
return entry
}
// putInCache stores a feed in the cache and returns its ETag.
func (s *RSSService) putInCache(key, content string, lastModified time.Time) string {
s.cacheMutex.Lock()
defer s.cacheMutex.Unlock()
// Generate ETag from content hash
hash := sha256.Sum256([]byte(content))
etag := fmt.Sprintf(`"%x"`, hash[:8])
// Implement simple LRU: if cache is too large, remove oldest entries
if len(s.cache) >= maxCacheSize {
var oldestKey string
var oldestTime time.Time
for k, v := range s.cache {
if oldestKey == "" || v.createdAt.Before(oldestTime) {
oldestKey = k
oldestTime = v.createdAt
}
}
if oldestKey != "" {
delete(s.cache, oldestKey)
}
}
s.cache[key] = &cacheEntry{
content: content,
etag: etag,
lastModified: lastModified,
createdAt: time.Now(),
}
return etag
}
// setRSSHeaders sets appropriate HTTP headers for RSS responses.
func (*RSSService) setRSSHeaders(c *echo.Context, etag string, lastModified time.Time) {
c.Response().Header().Set(echo.HeaderContentType, "application/rss+xml; charset=utf-8")
c.Response().Header().Set(echo.HeaderCacheControl, fmt.Sprintf("public, max-age=%d", int(defaultCacheDuration.Seconds())))
c.Response().Header().Set("ETag", etag)
if !lastModified.IsZero() {
c.Response().Header().Set("Last-Modified", lastModified.UTC().Format(http.TimeFormat))
}
}
func getRSSHeading(ctx context.Context, stores *store.Store) (RSSHeading, error) {
settings, err := stores.GetInstanceGeneralSetting(ctx)
if err != nil {
return RSSHeading{}, err
}
if settings == nil || settings.CustomProfile == nil {
return RSSHeading{
Title: "Memos",
Description: "An open source, lightweight note-taking service. Easily capture and share your great thoughts.",
Language: "en-us",
}, nil
}
customProfile := settings.CustomProfile
return RSSHeading{
Title: customProfile.Title,
Description: customProfile.Description,
Language: "en-us",
}, nil
}