first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled

This commit is contained in:
2026-03-04 06:30:47 +00:00
commit bb402d4ccc
777 changed files with 135661 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
package v1
// PublicMethods defines API endpoints that don't require authentication.
// All other endpoints require a valid session or access token.
//
// This is the SINGLE SOURCE OF TRUTH for public endpoints.
// Both Connect interceptor and gRPC-Gateway interceptor use this map.
//
// Format: Full gRPC procedure path as returned by req.Spec().Procedure (Connect)
// or info.FullMethod (gRPC interceptor).
var PublicMethods = map[string]struct{}{
// Auth Service - login/token endpoints must be accessible without auth
"/memos.api.v1.AuthService/SignIn": {},
"/memos.api.v1.AuthService/RefreshToken": {}, // Token refresh uses cookie, must be accessible when access token expired
// Instance Service - needed before login to show instance info
"/memos.api.v1.InstanceService/GetInstanceProfile": {},
"/memos.api.v1.InstanceService/GetInstanceSetting": {},
// User Service - public user profiles and stats
"/memos.api.v1.UserService/CreateUser": {}, // Allow first user registration
"/memos.api.v1.UserService/GetUser": {},
"/memos.api.v1.UserService/GetUserAvatar": {},
"/memos.api.v1.UserService/GetUserStats": {},
"/memos.api.v1.UserService/ListAllUserStats": {},
"/memos.api.v1.UserService/SearchUsers": {},
// Identity Provider Service - SSO buttons on login page
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": {},
// Memo Service - public memos (visibility filtering done in service layer)
"/memos.api.v1.MemoService/GetMemo": {},
"/memos.api.v1.MemoService/ListMemos": {},
"/memos.api.v1.MemoService/ListMemoComments": {},
}
// IsPublicMethod checks if a procedure path is public (no authentication required).
// Returns true for public methods, false for protected methods.
func IsPublicMethod(procedure string) bool {
_, ok := PublicMethods[procedure]
return ok
}

View File

@@ -0,0 +1,88 @@
package v1
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestPublicMethodsArePublic verifies that methods in PublicMethods are recognized as public.
func TestPublicMethodsArePublic(t *testing.T) {
publicMethods := []string{
// Auth Service
"/memos.api.v1.AuthService/SignIn",
"/memos.api.v1.AuthService/RefreshToken",
// Instance Service
"/memos.api.v1.InstanceService/GetInstanceProfile",
"/memos.api.v1.InstanceService/GetInstanceSetting",
// User Service
"/memos.api.v1.UserService/CreateUser",
"/memos.api.v1.UserService/GetUser",
"/memos.api.v1.UserService/GetUserAvatar",
"/memos.api.v1.UserService/GetUserStats",
"/memos.api.v1.UserService/ListAllUserStats",
"/memos.api.v1.UserService/SearchUsers",
// Identity Provider Service
"/memos.api.v1.IdentityProviderService/ListIdentityProviders",
// Memo Service
"/memos.api.v1.MemoService/GetMemo",
"/memos.api.v1.MemoService/ListMemos",
}
for _, method := range publicMethods {
t.Run(method, func(t *testing.T) {
assert.True(t, IsPublicMethod(method), "Expected %s to be public", method)
})
}
}
// TestProtectedMethodsRequireAuth verifies that non-public methods are recognized as protected.
func TestProtectedMethodsRequireAuth(t *testing.T) {
protectedMethods := []string{
// Auth Service - logout and get current user require auth
"/memos.api.v1.AuthService/SignOut",
"/memos.api.v1.AuthService/GetCurrentUser",
// Instance Service - admin operations
"/memos.api.v1.InstanceService/UpdateInstanceSetting",
// User Service - modification operations
"/memos.api.v1.UserService/ListUsers",
"/memos.api.v1.UserService/UpdateUser",
"/memos.api.v1.UserService/DeleteUser",
// Memo Service - write operations
"/memos.api.v1.MemoService/CreateMemo",
"/memos.api.v1.MemoService/UpdateMemo",
"/memos.api.v1.MemoService/DeleteMemo",
// Attachment Service - write operations
"/memos.api.v1.AttachmentService/CreateAttachment",
"/memos.api.v1.AttachmentService/DeleteAttachment",
// Shortcut Service
"/memos.api.v1.ShortcutService/CreateShortcut",
"/memos.api.v1.ShortcutService/ListShortcuts",
"/memos.api.v1.ShortcutService/UpdateShortcut",
"/memos.api.v1.ShortcutService/DeleteShortcut",
// Activity Service
"/memos.api.v1.ActivityService/GetActivity",
}
for _, method := range protectedMethods {
t.Run(method, func(t *testing.T) {
assert.False(t, IsPublicMethod(method), "Expected %s to require auth", method)
})
}
}
// TestUnknownMethodsRequireAuth verifies that unknown methods default to requiring auth.
func TestUnknownMethodsRequireAuth(t *testing.T) {
unknownMethods := []string{
"/unknown.Service/Method",
"/memos.api.v1.UnknownService/Method",
"",
"invalid",
}
for _, method := range unknownMethods {
t.Run(method, func(t *testing.T) {
assert.False(t, IsPublicMethod(method), "Unknown method %q should require auth", method)
})
}
}

View File

@@ -0,0 +1,155 @@
package v1
import (
"context"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListActivities(ctx context.Context, request *v1pb.ListActivitiesRequest) (*v1pb.ListActivitiesResponse, error) {
// Set default page size if not specified
pageSize := request.PageSize
if pageSize <= 0 || pageSize > 1000 {
pageSize = 100
}
// TODO: Implement pagination with page_token and use pageSize for limiting
// For now, we'll fetch all activities and the pageSize will be used in future pagination implementation
_ = pageSize // Acknowledge pageSize variable to avoid linter warning
activities, err := s.Store.ListActivities(ctx, &store.FindActivity{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list activities: %v", err)
}
var activityMessages []*v1pb.Activity
for _, activity := range activities {
activityMessage, err := s.convertActivityFromStore(ctx, activity)
if err != nil {
// Skip activities that reference deleted memos instead of failing the entire list
continue
}
if activityMessage != nil {
activityMessages = append(activityMessages, activityMessage)
}
}
return &v1pb.ListActivitiesResponse{
Activities: activityMessages,
// TODO: Implement next_page_token for pagination
}, nil
}
func (s *APIV1Service) GetActivity(ctx context.Context, request *v1pb.GetActivityRequest) (*v1pb.Activity, error) {
activityID, err := ExtractActivityIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid activity name: %v", err)
}
activity, err := s.Store.GetActivity(ctx, &store.FindActivity{
ID: &activityID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get activity: %v", err)
}
activityMessage, err := s.convertActivityFromStore(ctx, activity)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
}
if activityMessage == nil {
return nil, status.Errorf(codes.NotFound, "activity references deleted content")
}
return activityMessage, nil
}
// convertActivityFromStore converts a storage-layer activity to an API activity.
// This handles the mapping between internal activity representation and the public API,
// including proper type and level conversions.
// Returns nil if the activity references deleted content (to allow graceful skipping).
func (s *APIV1Service) convertActivityFromStore(ctx context.Context, activity *store.Activity) (*v1pb.Activity, error) {
payload, err := s.convertActivityPayloadFromStore(ctx, activity.Payload)
if err != nil {
return nil, err
}
// Skip activities that reference deleted memos
if payload == nil {
return nil, nil
}
// Convert store activity type to proto enum
var activityType v1pb.Activity_Type
switch activity.Type {
case store.ActivityTypeMemoComment:
activityType = v1pb.Activity_MEMO_COMMENT
default:
activityType = v1pb.Activity_TYPE_UNSPECIFIED
}
// Convert store activity level to proto enum
var activityLevel v1pb.Activity_Level
switch activity.Level {
case store.ActivityLevelInfo:
activityLevel = v1pb.Activity_INFO
default:
activityLevel = v1pb.Activity_LEVEL_UNSPECIFIED
}
return &v1pb.Activity{
Name: fmt.Sprintf("%s%d", ActivityNamePrefix, activity.ID),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, activity.CreatorID),
Type: activityType,
Level: activityLevel,
CreateTime: timestamppb.New(time.Unix(activity.CreatedTs, 0)),
Payload: payload,
}, nil
}
// convertActivityPayloadFromStore converts a storage-layer activity payload to an API payload.
// This resolves references (e.g., memo IDs) to resource names for the API.
// Returns nil if the activity references deleted content (to allow graceful skipping).
func (s *APIV1Service) convertActivityPayloadFromStore(ctx context.Context, payload *storepb.ActivityPayload) (*v1pb.ActivityPayload, error) {
v2Payload := &v1pb.ActivityPayload{}
if payload.MemoComment != nil {
// Fetch the comment memo
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &payload.MemoComment.MemoId,
ExcludeContent: true,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
// If the comment memo was deleted, skip this activity gracefully
if memo == nil {
return nil, nil
}
// Fetch the related memo (the one being commented on)
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &payload.MemoComment.RelatedMemoId,
ExcludeContent: true,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get related memo: %v", err)
}
// If the related memo was deleted, skip this activity gracefully
if relatedMemo == nil {
return nil, nil
}
v2Payload.Payload = &v1pb.ActivityPayload_MemoComment{
MemoComment: &v1pb.ActivityMemoCommentPayload{
Memo: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
RelatedMemo: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
},
}
}
return v2Payload, nil
}

View File

@@ -0,0 +1,397 @@
package v1
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/labstack/echo/v5"
groqai "github.com/usememos/memos/plugin/ai/groq"
ollamaai "github.com/usememos/memos/plugin/ai/ollama"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
const aiSettingsStoreKey = "ai_settings_v1"
// RegisterAIHTTPHandlers registers plain HTTP JSON handlers for the AI service.
// Uses direct raw DB storage to avoid proto generation dependency.
func (s *APIV1Service) RegisterAIHTTPHandlers(g *echo.Group) {
g.POST("/api/ai/settings", s.handleGetAISettings)
g.POST("/api/ai/settings/update", s.handleUpdateAISettings)
g.POST("/api/ai/complete", s.handleGenerateCompletion)
g.POST("/api/ai/test", s.handleTestAIProvider)
g.POST("/api/ai/models/ollama", s.handleListOllamaModels)
}
type aiSettingsJSON struct {
Groq *groqSettingsJSON `json:"groq,omitempty"`
Ollama *ollamaSettingsJSON `json:"ollama,omitempty"`
}
type groqSettingsJSON struct {
Enabled bool `json:"enabled"`
APIKey string `json:"apiKey"`
DefaultModel string `json:"defaultModel"`
}
type ollamaSettingsJSON struct {
Enabled bool `json:"enabled"`
Host string `json:"host"`
DefaultModel string `json:"defaultModel"`
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
}
func defaultAISettings() *aiSettingsJSON {
return &aiSettingsJSON{
Groq: &groqSettingsJSON{DefaultModel: "llama-3.1-8b-instant"},
Ollama: &ollamaSettingsJSON{Host: "http://localhost:11434", DefaultModel: "llama3", TimeoutSeconds: 120},
}
}
func (s *APIV1Service) loadAISettings(ctx context.Context) (*aiSettingsJSON, error) {
settings := defaultAISettings()
list, err := s.Store.GetDriver().ListInstanceSettings(ctx, &store.FindInstanceSetting{Name: aiSettingsStoreKey})
if err != nil || len(list) == 0 {
return settings, nil
}
_ = json.Unmarshal([]byte(list[0].Value), settings)
return settings, nil
}
func (s *APIV1Service) saveAISettings(ctx context.Context, settings *aiSettingsJSON) error {
b, err := json.Marshal(settings)
if err != nil {
return err
}
_, err = s.Store.GetDriver().UpsertInstanceSetting(ctx, &store.InstanceSetting{
Name: aiSettingsStoreKey,
Value: string(b),
})
return err
}
func (s *APIV1Service) handleGetAISettings(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
settings, err := s.loadAISettings(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, settings)
}
func (s *APIV1Service) handleUpdateAISettings(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
var req aiSettingsJSON
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
}
if err := s.saveAISettings(ctx, &req); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, req)
}
type generateCompletionRequest struct {
Provider int `json:"provider"`
Prompt string `json:"prompt"`
Model string `json:"model"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"maxTokens"`
AutoTag bool `json:"autoTag,omitempty"`
SpellCheck bool `json:"spellCheck,omitempty"`
}
type generateCompletionResponse struct {
Text string `json:"text"`
ModelUsed string `json:"modelUsed"`
PromptTokens int `json:"promptTokens"`
CompletionTokens int `json:"completionTokens"`
TotalTokens int `json:"totalTokens"`
}
func (s *APIV1Service) handleGenerateCompletion(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
var req generateCompletionRequest
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
}
if req.Prompt == "" {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "prompt is required"})
}
settings, err := s.loadAISettings(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
temp := float32(0.7)
if req.Temperature > 0 {
temp = float32(req.Temperature)
}
maxTokens := 1000
if req.MaxTokens > 0 {
maxTokens = req.MaxTokens
}
switch req.Provider {
case 1: // Groq
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Groq is not enabled or API key is missing. Configure it in Settings > AI."})
}
model := req.Model
if model == "" {
model = settings.Groq.DefaultModel
}
if model == "" {
model = "llama-3.1-8b-instant"
}
client := groqai.NewGroqClient(groqai.GroqConfig{
APIKey: settings.Groq.APIKey,
DefaultModel: model,
})
resp, err := client.GenerateCompletion(ctx, groqai.CompletionRequest{
Model: model,
Prompt: req.Prompt,
Temperature: temp,
MaxTokens: maxTokens,
})
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
// Apply post-processing if requested
text := resp.Text
if req.SpellCheck {
text = s.correctSpelling(ctx, text)
}
if req.AutoTag {
text = s.addAutoTags(ctx, text)
}
return c.JSON(http.StatusOK, generateCompletionResponse{
Text: text, ModelUsed: resp.Model,
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
})
case 2: // Ollama
if settings.Ollama == nil || !settings.Ollama.Enabled {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
}
model := req.Model
if model == "" {
model = settings.Ollama.DefaultModel
}
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
Host: settings.Ollama.Host,
DefaultModel: model,
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
})
resp, err := client.GenerateCompletion(ctx, ollamaai.CompletionRequest{
Model: model,
Prompt: req.Prompt,
Temperature: temp,
MaxTokens: maxTokens,
})
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
// Apply post-processing if requested
text := resp.Text
if req.SpellCheck {
text = s.correctSpelling(ctx, text)
}
if req.AutoTag {
text = s.addAutoTags(ctx, text)
}
return c.JSON(http.StatusOK, generateCompletionResponse{
Text: text, ModelUsed: resp.Model,
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
})
default:
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider. Use 1 for Groq, 2 for Ollama"})
}
}
func (s *APIV1Service) handleTestAIProvider(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
var req struct {
Provider int `json:"provider"`
}
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
}
settings, err := s.loadAISettings(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
switch req.Provider {
case 1:
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Groq not enabled or API key missing"})
}
client := groqai.NewGroqClient(groqai.GroqConfig{APIKey: settings.Groq.APIKey, DefaultModel: settings.Groq.DefaultModel})
if err := client.TestConnection(ctx); err != nil {
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
}
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Groq connected successfully"})
case 2:
if settings.Ollama == nil || !settings.Ollama.Enabled {
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Ollama not enabled"})
}
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
Host: settings.Ollama.Host,
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
})
if err := client.TestConnection(ctx); err != nil {
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
}
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Ollama connected successfully"})
default:
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider"})
}
}
// correctSpelling applies basic spelling corrections to the text
func (s *APIV1Service) correctSpelling(ctx context.Context, text string) string {
// TODO: Implement actual spell checking using a spell checker library
// For now, return the text as-is
return text
}
// addAutoTags analyzes the text and adds relevant #tags
func (s *APIV1Service) addAutoTags(ctx context.Context, text string) string {
// Common programming languages and technologies
tagMap := map[string][]string{
"javascript": {"javascript", "js"},
"python": {"python", "py"},
"java": {"java"},
"go": {"go", "golang"},
"rust": {"rust"},
"dart": {"dart"},
"flutter": {"flutter"},
"react": {"react", "reactjs"},
"vue": {"vue", "vuejs"},
"angular": {"angular"},
"node": {"nodejs", "node"},
"express": {"express"},
"mongodb": {"mongodb"},
"postgresql": {"postgresql", "postgres"},
"mysql": {"mysql"},
"redis": {"redis"},
"docker": {"docker"},
"kubernetes": {"kubernetes", "k8s"},
"aws": {"aws", "amazon"},
"azure": {"azure"},
"gcp": {"gcp", "google-cloud"},
"git": {"git"},
"github": {"github"},
"gitlab": {"gitlab"},
"ci/cd": {"ci-cd"},
"testing": {"testing"},
"api": {"api"},
"rest": {"rest", "api"},
"graphql": {"graphql"},
"database": {"database", "db"},
"frontend": {"frontend"},
"backend": {"backend"},
"fullstack": {"fullstack"},
"mobile": {"mobile"},
"web": {"web"},
"devops": {"devops"},
"security": {"security"},
"machine learning": {"ml", "machine-learning"},
"ai": {"ai", "artificial-intelligence"},
"blockchain": {"blockchain"},
"crypto": {"crypto", "cryptocurrency"},
}
// Convert text to lowercase for matching
lowerText := strings.ToLower(text)
var tags []string
tagSet := make(map[string]bool)
// Check for programming languages and technologies
for keyword, tagList := range tagMap {
if strings.Contains(lowerText, keyword) {
for _, tag := range tagList {
if !tagSet[tag] {
tags = append(tags, tag)
tagSet[tag] = true
}
}
}
}
// Add tags to the end of the text if any were found
if len(tags) > 0 {
tagString := "\n\n"
for i, tag := range tags {
if i > 0 {
tagString += " "
}
tagString += "#" + tag
}
text += tagString
}
return text
}
func (s *APIV1Service) handleListOllamaModels(c *echo.Context) error {
ctx := c.Request().Context()
if auth.GetUserID(ctx) == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
settings, err := s.loadAISettings(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if settings.Ollama == nil || !settings.Ollama.Enabled {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
}
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
Host: settings.Ollama.Host,
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
})
models, err := client.ListModels(ctx)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": fmt.Sprintf("failed to list Ollama models: %v", err)})
}
// Convert to the same format as frontend expects
var modelList []map[string]string
for _, model := range models {
modelList = append(modelList, map[string]string{
"id": model,
"name": model,
})
}
return c.JSON(http.StatusOK, map[string]interface{}{
"models": modelList,
})
}

View File

@@ -0,0 +1,191 @@
package v1
import (
"bytes"
"image"
"image/color"
"image/jpeg"
"testing"
"github.com/disintegration/imaging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShouldStripExif(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mimeType string
expected bool
}{
{
name: "JPEG should strip EXIF",
mimeType: "image/jpeg",
expected: true,
},
{
name: "JPG should strip EXIF",
mimeType: "image/jpg",
expected: true,
},
{
name: "TIFF should strip EXIF",
mimeType: "image/tiff",
expected: true,
},
{
name: "WebP should strip EXIF",
mimeType: "image/webp",
expected: true,
},
{
name: "HEIC should strip EXIF",
mimeType: "image/heic",
expected: true,
},
{
name: "HEIF should strip EXIF",
mimeType: "image/heif",
expected: true,
},
{
name: "PNG should not strip EXIF",
mimeType: "image/png",
expected: false,
},
{
name: "GIF should not strip EXIF",
mimeType: "image/gif",
expected: false,
},
{
name: "text file should not strip EXIF",
mimeType: "text/plain",
expected: false,
},
{
name: "PDF should not strip EXIF",
mimeType: "application/pdf",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := shouldStripExif(tt.mimeType)
assert.Equal(t, tt.expected, result)
})
}
}
func TestStripImageExif(t *testing.T) {
t.Parallel()
// Create a simple test image
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
// Fill with red color
for y := 0; y < 100; y++ {
for x := 0; x < 100; x++ {
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
}
}
// Encode as JPEG
var buf bytes.Buffer
err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90})
require.NoError(t, err)
originalData := buf.Bytes()
t.Run("strip JPEG metadata", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/jpeg")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.Equal(t, 100, decodedImg.Bounds().Dx())
assert.Equal(t, 100, decodedImg.Bounds().Dy())
})
t.Run("strip JPG metadata (alternate extension)", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/jpg")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("strip PNG metadata", func(t *testing.T) {
t.Parallel()
// Encode as PNG first
var pngBuf bytes.Buffer
err := imaging.Encode(&pngBuf, img, imaging.PNG)
require.NoError(t, err)
strippedData, err := stripImageExif(pngBuf.Bytes(), "image/png")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.Equal(t, 100, decodedImg.Bounds().Dx())
assert.Equal(t, 100, decodedImg.Bounds().Dy())
})
t.Run("handle WebP format by converting to JPEG", func(t *testing.T) {
t.Parallel()
// WebP format will be converted to JPEG
strippedData, err := stripImageExif(originalData, "image/webp")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("handle HEIC format by converting to JPEG", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/heic")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("return error for invalid image data", func(t *testing.T) {
t.Parallel()
invalidData := []byte("not an image")
_, err := stripImageExif(invalidData, "image/jpeg")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode image")
})
t.Run("return error for empty image data", func(t *testing.T) {
t.Parallel()
emptyData := []byte{}
_, err := stripImageExif(emptyData, "image/jpeg")
assert.Error(t, err)
})
}

View File

@@ -0,0 +1,646 @@
package v1
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/disintegration/imaging"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/filter"
"github.com/usememos/memos/plugin/storage/s3"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
// The upload memory buffer is 32 MiB.
// It should be kept low, so RAM usage doesn't get out of control.
// This is unrelated to maximum upload size limit, which is now set through system setting.
MaxUploadBufferSizeBytes = 32 << 20
MebiByte = 1024 * 1024
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
ThumbnailCacheFolder = ".thumbnail_cache"
// defaultJPEGQuality is the JPEG quality used when re-encoding images for EXIF stripping.
// Quality 95 maintains visual quality while ensuring metadata is removed.
defaultJPEGQuality = 95
)
var SupportedThumbnailMimeTypes = []string{
"image/png",
"image/jpeg",
}
// exifCapableImageTypes defines image formats that may contain EXIF metadata.
// These formats will have their EXIF metadata stripped on upload for privacy.
var exifCapableImageTypes = map[string]bool{
"image/jpeg": true,
"image/jpg": true,
"image/tiff": true,
"image/webp": true,
"image/heic": true,
"image/heif": true,
}
func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Validate required fields
if request.Attachment == nil {
return nil, status.Errorf(codes.InvalidArgument, "attachment is required")
}
if request.Attachment.Filename == "" {
return nil, status.Errorf(codes.InvalidArgument, "filename is required")
}
if !validateFilename(request.Attachment.Filename) {
return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format")
}
if request.Attachment.Type == "" {
ext := filepath.Ext(request.Attachment.Filename)
mimeType := mime.TypeByExtension(ext)
if mimeType == "" {
mimeType = http.DetectContentType(request.Attachment.Content)
}
// ParseMediaType to strip parameters
mediaType, _, err := mime.ParseMediaType(mimeType)
if err == nil {
request.Attachment.Type = mediaType
}
}
if request.Attachment.Type == "" {
request.Attachment.Type = "application/octet-stream"
}
if !isValidMimeType(request.Attachment.Type) {
return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format")
}
// Use provided attachment_id or generate a new one
attachmentUID := request.AttachmentId
if attachmentUID == "" {
attachmentUID = shortuuid.New()
}
create := &store.Attachment{
UID: attachmentUID,
CreatorID: user.ID,
Filename: request.Attachment.Filename,
Type: request.Attachment.Type,
}
// No upload size limit - accept files of any size
size := binary.Size(request.Attachment.Content)
create.Size = int64(size)
create.Blob = request.Attachment.Content
// Strip EXIF metadata from images for privacy protection.
// This removes sensitive information like GPS location, device details, etc.
if shouldStripExif(create.Type) {
if strippedBlob, err := stripImageExif(create.Blob, create.Type); err != nil {
// Log warning but continue with original image to ensure uploads don't fail.
slog.Warn("failed to strip EXIF metadata from image",
slog.String("type", create.Type),
slog.String("filename", create.Filename),
slog.String("error", err.Error()))
} else {
create.Blob = strippedBlob
create.Size = int64(len(strippedBlob))
}
}
if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil {
return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err)
}
if request.Attachment.Memo != nil {
memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo)
}
create.MemoID = &memo.ID
}
attachment, err := s.Store.CreateAttachment(ctx, create)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err)
}
return convertAttachmentFromStore(attachment), nil
}
func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Set default page size
pageSize := int(request.PageSize)
if pageSize <= 0 {
pageSize = 50
}
if pageSize > 1000 {
pageSize = 1000
}
// Parse page token for offset
offset := 0
if request.PageToken != "" {
// Simple implementation: page token is the offset as string
// In production, you might want to use encrypted tokens
if parsed, err := fmt.Sscanf(request.PageToken, "%d", &offset); err != nil || parsed != 1 {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token")
}
}
findAttachment := &store.FindAttachment{
CreatorID: &user.ID,
Limit: &pageSize,
Offset: &offset,
}
// Parse filter if provided
if request.Filter != "" {
if err := s.validateAttachmentFilter(ctx, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
findAttachment.Filters = append(findAttachment.Filters, request.Filter)
}
attachments, err := s.Store.ListAttachments(ctx, findAttachment)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
}
response := &v1pb.ListAttachmentsResponse{}
for _, attachment := range attachments {
response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment))
}
// For simplicity, set total size to the number of returned attachments.
// In a full implementation, you'd want a separate count query
response.TotalSize = int32(len(response.Attachments))
// Set next page token if we got the full page size (indicating there might be more)
if len(attachments) == pageSize {
response.NextPageToken = fmt.Sprintf("%d", offset+pageSize)
}
return response, nil
}
func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttachmentRequest) (*v1pb.Attachment, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Check access permission based on linked memo visibility.
if err := s.checkAttachmentAccess(ctx, attachment); err != nil {
return nil, err
}
return convertAttachmentFromStore(attachment), nil
}
func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.UpdateAttachmentRequest) (*v1pb.Attachment, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Attachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Only the creator or admin can update the attachment.
if attachment.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
currentTs := time.Now().Unix()
update := &store.UpdateAttachment{
ID: attachment.ID,
UpdatedTs: &currentTs,
}
for _, field := range request.UpdateMask.Paths {
if field == "filename" {
if !validateFilename(request.Attachment.Filename) {
return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format")
}
update.Filename = &request.Attachment.Filename
}
}
if err := s.Store.UpdateAttachment(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
}
return s.GetAttachment(ctx, &v1pb.GetAttachmentRequest{
Name: request.Attachment.Name,
})
}
func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.DeleteAttachmentRequest) (*emptypb.Empty, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
UID: &attachmentUID,
CreatorID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Delete the attachment from the database.
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
ID: attachment.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment: %v", err)
}
return &emptypb.Empty{}, nil
}
func convertAttachmentFromStore(attachment *store.Attachment) *v1pb.Attachment {
attachmentMessage := &v1pb.Attachment{
Name: fmt.Sprintf("%s%s", AttachmentNamePrefix, attachment.UID),
CreateTime: timestamppb.New(time.Unix(attachment.CreatedTs, 0)),
Filename: attachment.Filename,
Type: attachment.Type,
Size: attachment.Size,
}
if attachment.MemoUID != nil && *attachment.MemoUID != "" {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, *attachment.MemoUID)
attachmentMessage.Memo = &memoName
}
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
attachmentMessage.ExternalLink = attachment.Reference
}
return attachmentMessage
}
// SaveAttachmentBlob save the blob of attachment based on the storage config.
func SaveAttachmentBlob(ctx context.Context, profile *profile.Profile, stores *store.Store, create *store.Attachment) error {
instanceStorageSetting, err := stores.GetInstanceStorageSetting(ctx)
if err != nil {
return errors.Wrap(err, "Failed to find instance storage setting")
}
if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_LOCAL {
filepathTemplate := "assets/{timestamp}_{filename}"
if instanceStorageSetting.FilepathTemplate != "" {
filepathTemplate = instanceStorageSetting.FilepathTemplate
}
internalPath := filepathTemplate
if !strings.Contains(internalPath, "{filename}") {
internalPath = filepath.Join(internalPath, "{filename}")
}
internalPath = replaceFilenameWithPathTemplate(internalPath, create.Filename)
internalPath = filepath.ToSlash(internalPath)
// Ensure the directory exists.
osPath := filepath.FromSlash(internalPath)
if !filepath.IsAbs(osPath) {
osPath = filepath.Join(profile.Data, osPath)
}
dir := filepath.Dir(osPath)
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
return errors.Wrap(err, "Failed to create directory")
}
// Write the blob to the file.
if err := os.WriteFile(osPath, create.Blob, 0644); err != nil {
return errors.Wrap(err, "Failed to write file")
}
create.Reference = internalPath
create.Blob = nil
create.StorageType = storepb.AttachmentStorageType_LOCAL
} else if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_S3 {
s3Config := instanceStorageSetting.S3Config
if s3Config == nil {
return errors.Errorf("No activated external storage found")
}
s3Client, err := s3.NewClient(ctx, s3Config)
if err != nil {
return errors.Wrap(err, "Failed to create s3 client")
}
filepathTemplate := instanceStorageSetting.FilepathTemplate
if !strings.Contains(filepathTemplate, "{filename}") {
filepathTemplate = filepath.Join(filepathTemplate, "{filename}")
}
filepathTemplate = replaceFilenameWithPathTemplate(filepathTemplate, create.Filename)
key, err := s3Client.UploadObject(ctx, filepathTemplate, create.Type, bytes.NewReader(create.Blob))
if err != nil {
return errors.Wrap(err, "Failed to upload via s3 client")
}
presignURL, err := s3Client.PresignGetObject(ctx, key)
if err != nil {
return errors.Wrap(err, "Failed to presign via s3 client")
}
create.Reference = presignURL
create.Blob = nil
create.StorageType = storepb.AttachmentStorageType_S3
create.Payload = &storepb.AttachmentPayload{
Payload: &storepb.AttachmentPayload_S3Object_{
S3Object: &storepb.AttachmentPayload_S3Object{
S3Config: s3Config,
Key: key,
LastPresignedTime: timestamppb.New(time.Now()),
},
},
}
}
return nil
}
func (s *APIV1Service) GetAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
// For local storage, read the file from the local disk.
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
attachmentPath := filepath.FromSlash(attachment.Reference)
if !filepath.IsAbs(attachmentPath) {
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
}
file, err := os.Open(attachmentPath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open the file")
}
defer file.Close()
blob, err := io.ReadAll(file)
if err != nil {
return nil, errors.Wrap(err, "failed to read the file")
}
return blob, nil
}
// For S3 storage, download the file from S3.
if attachment.StorageType == storepb.AttachmentStorageType_S3 {
if attachment.Payload == nil {
return nil, errors.New("attachment payload is missing")
}
s3Object := attachment.Payload.GetS3Object()
if s3Object == nil {
return nil, errors.New("S3 object payload is missing")
}
if s3Object.S3Config == nil {
return nil, errors.New("S3 config is missing")
}
if s3Object.Key == "" {
return nil, errors.New("S3 object key is missing")
}
s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config)
if err != nil {
return nil, errors.Wrap(err, "failed to create S3 client")
}
blob, err := s3Client.GetObject(context.Background(), s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to get object from S3")
}
return blob, nil
}
// For database storage, return the blob from the database.
return attachment.Blob, nil
}
var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`)
func replaceFilenameWithPathTemplate(path, filename string) string {
t := time.Now()
path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string {
switch s {
case "{filename}":
return filename
case "{timestamp}":
return fmt.Sprintf("%d", t.Unix())
case "{year}":
return fmt.Sprintf("%d", t.Year())
case "{month}":
return fmt.Sprintf("%02d", t.Month())
case "{day}":
return fmt.Sprintf("%02d", t.Day())
case "{hour}":
return fmt.Sprintf("%02d", t.Hour())
case "{minute}":
return fmt.Sprintf("%02d", t.Minute())
case "{second}":
return fmt.Sprintf("%02d", t.Second())
case "{uuid}":
return util.GenUUID()
default:
return s
}
})
return path
}
func validateFilename(filename string) bool {
// Reject path traversal attempts and make sure no additional directories are created
if !filepath.IsLocal(filename) || strings.ContainsAny(filename, "/\\") {
return false
}
// Reject filenames starting or ending with spaces or periods
if strings.HasPrefix(filename, " ") || strings.HasSuffix(filename, " ") ||
strings.HasPrefix(filename, ".") || strings.HasSuffix(filename, ".") {
return false
}
return true
}
func isValidMimeType(mimeType string) bool {
// Reject empty or excessively long MIME types
if mimeType == "" || len(mimeType) > 255 {
return false
}
// MIME type must match the pattern: type/subtype
// Allow common characters in MIME types per RFC 2045
matched, _ := regexp.MatchString(`^[a-zA-Z0-9][a-zA-Z0-9!#$&^_.+-]{0,126}/[a-zA-Z0-9][a-zA-Z0-9!#$&^_.+-]{0,126}$`, mimeType)
return matched
}
func (s *APIV1Service) validateAttachmentFilter(ctx context.Context, filterStr string) error {
if filterStr == "" {
return errors.New("filter cannot be empty")
}
engine, err := filter.DefaultAttachmentEngine()
if err != nil {
return err
}
var dialect filter.DialectName
switch s.Profile.Driver {
case "mysql":
dialect = filter.DialectMySQL
case "postgres":
dialect = filter.DialectPostgres
default:
dialect = filter.DialectSQLite
}
if _, err := engine.CompileToStatement(ctx, filterStr, filter.RenderOptions{Dialect: dialect}); err != nil {
return errors.Wrap(err, "failed to compile filter")
}
return nil
}
// checkAttachmentAccess verifies the user has permission to access the attachment.
// For unlinked attachments (no memo), only the creator can access.
// For linked attachments, access follows the memo's visibility rules.
func (s *APIV1Service) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment) error {
user, _ := s.fetchCurrentUser(ctx)
// For unlinked attachments, only the creator can access.
if attachment.MemoID == nil {
if user == nil {
return status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if attachment.CreatorID != user.ID && !isSuperUser(user) {
return status.Errorf(codes.PermissionDenied, "permission denied")
}
return nil
}
// For linked attachments, check memo visibility.
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
if err != nil {
return status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility == store.Public {
return nil
}
if user == nil {
return status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return status.Errorf(codes.PermissionDenied, "permission denied")
}
return nil
}
// shouldStripExif checks if the MIME type is an image format that may contain EXIF metadata.
// Returns true for formats like JPEG, TIFF, WebP, HEIC, and HEIF which commonly contain
// privacy-sensitive metadata such as GPS coordinates, camera settings, and device information.
func shouldStripExif(mimeType string) bool {
return exifCapableImageTypes[mimeType]
}
// stripImageExif removes EXIF metadata from image files by decoding and re-encoding them.
// This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps.
//
// The function preserves the correct image orientation by applying EXIF orientation tags
// during decoding before stripping all metadata. Images are re-encoded with high quality
// to minimize visual degradation.
//
// Supported formats:
// - JPEG/JPG: Re-encoded as JPEG with quality 95
// - PNG: Re-encoded as PNG (lossless)
// - TIFF/WebP/HEIC/HEIF: Re-encoded as JPEG with quality 95
//
// Returns the cleaned image data without any EXIF metadata, or an error if processing fails.
func stripImageExif(imageData []byte, mimeType string) ([]byte, error) {
// Decode image with automatic EXIF orientation correction.
// This ensures the image displays correctly after metadata removal.
img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true))
if err != nil {
return nil, errors.Wrap(err, "failed to decode image")
}
// Re-encode the image without EXIF metadata.
var buf bytes.Buffer
var encodeErr error
if mimeType == "image/png" {
// Preserve PNG format for lossless encoding
encodeErr = imaging.Encode(&buf, img, imaging.PNG)
} else {
// For JPEG, TIFF, WebP, HEIC, HEIF - re-encode as JPEG.
// This ensures EXIF is stripped and provides good compression.
encodeErr = imaging.Encode(&buf, img, imaging.JPEG, imaging.JPEGQuality(defaultJPEGQuality))
}
if encodeErr != nil {
return nil, errors.Wrap(encodeErr, "failed to encode image")
}
return buf.Bytes(), nil
}

View File

@@ -0,0 +1,612 @@
package v1
import (
"context"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
const (
unmatchedUsernameAndPasswordError = "unmatched username and password"
)
// GetCurrentUser returns the authenticated user's information.
// Validates the access token and returns user details.
//
// Authentication: Required (access token).
// Returns: User information.
func (s *APIV1Service) GetCurrentUser(ctx context.Context, _ *v1pb.GetCurrentUserRequest) (*v1pb.GetCurrentUserResponse, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
}
if user == nil {
// Clear auth cookies
if err := s.clearAuthCookies(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies: %v", err)
}
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
return &v1pb.GetCurrentUserResponse{
User: convertUserFromStore(user),
}, nil
}
// SignIn authenticates a user with credentials and returns tokens.
// On success, returns an access token and sets a refresh token cookie.
//
// Supports two authentication methods:
// 1. Password-based authentication (username + password).
// 2. SSO authentication (OAuth2 authorization code).
//
// Authentication: Not required (public endpoint).
// Returns: User info, access token, and token expiry.
func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.SignInResponse, error) {
var existingUser *store.User
// Authentication Method 1: Password-based authentication
if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &passwordCredentials.Username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(passwordCredentials.Password)); err != nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
instanceGeneralSetting, err := s.Store.GetInstanceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance general setting, error: %v", err)
}
// Check if the password auth in is allowed.
if instanceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
}
existingUser = user
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
// Authentication Method 2: SSO (OAuth2) authentication
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &ssoCredentials.IdpId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.InvalidArgument, "identity provider not found")
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
}
// Pass code_verifier for PKCE support (empty string if not provided for backward compatibility)
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code, ssoCredentials.CodeVerifier)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err)
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err)
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
// Check if the user is allowed to sign up.
instanceGeneralSetting, err := s.Store.GetInstanceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance general setting, error: %v", err)
}
if instanceGeneralSetting.DisallowUserRegistration {
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
}
// Create a new user with the user info from the identity provider.
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
AvatarURL: userInfo.AvatarURL,
}
password, err := util.RandomString(20)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
}
}
existingUser = user
}
if existingUser == nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid credentials")
}
if existingUser.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", existingUser.Username)
}
accessToken, accessExpiresAt, err := s.doSignIn(ctx, existingUser)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to sign in: %v", err)
}
return &v1pb.SignInResponse{
User: convertUserFromStore(existingUser),
AccessToken: accessToken,
AccessTokenExpiresAt: timestamppb.New(accessExpiresAt),
}, nil
}
// doSignIn performs the actual sign-in operation by creating a session and setting the cookie.
//
// This function:
// 1. Generates refresh token and access token.
// 2. Stores refresh token metadata in user_setting.
// 3. Sets refresh token as HttpOnly cookie.
// 4. Returns access token and its expiry time.
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User) (string, time.Time, error) {
// Generate refresh token
tokenID := util.GenUUID()
refreshToken, refreshExpiresAt, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(s.Secret))
if err != nil {
return "", time.Time{}, status.Errorf(codes.Internal, "failed to generate refresh token: %v", err)
}
// Store refresh token metadata
clientInfo := s.extractClientInfo(ctx)
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(refreshExpiresAt),
CreatedAt: timestamppb.Now(),
ClientInfo: clientInfo,
}
if err := s.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord); err != nil {
slog.Error("failed to store refresh token", "error", err)
}
// Set refresh token cookie
refreshCookie := s.buildRefreshTokenCookie(ctx, refreshToken, refreshExpiresAt)
if err := SetResponseHeader(ctx, "Set-Cookie", refreshCookie); err != nil {
return "", time.Time{}, status.Errorf(codes.Internal, "failed to set refresh token cookie: %v", err)
}
// Generate access token
accessToken, accessExpiresAt, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte(s.Secret),
)
if err != nil {
return "", time.Time{}, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
return accessToken, accessExpiresAt, nil
}
// SignOut terminates the user's authentication.
// Revokes the refresh token and clears the authentication cookie.
//
// Authentication: Required (access token).
// Returns: Empty response on success.
func (s *APIV1Service) SignOut(ctx context.Context, _ *v1pb.SignOutRequest) (*emptypb.Empty, error) {
// Get user from access token claims
claims := auth.GetUserClaims(ctx)
if claims != nil {
// Revoke refresh token if we can identify it
refreshToken := ""
if md, ok := metadata.FromIncomingContext(ctx); ok {
if cookies := md.Get("cookie"); len(cookies) > 0 {
refreshToken = auth.ExtractRefreshTokenFromCookie(cookies[0])
}
}
if refreshToken != "" {
refreshClaims, err := auth.ParseRefreshToken(refreshToken, []byte(s.Secret))
if err == nil {
// Remove refresh token from user_setting by token_id
_ = s.Store.RemoveUserRefreshToken(ctx, claims.UserID, refreshClaims.TokenID)
}
}
}
// Clear refresh token cookie
if err := s.clearAuthCookies(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies, error: %v", err)
}
return &emptypb.Empty{}, nil
}
// RefreshToken exchanges a valid refresh token for a new access token.
//
// This endpoint implements refresh token rotation with sliding window sessions:
// 1. Extracts the refresh token from the HttpOnly cookie (memos_refresh)
// 2. Validates the refresh token against the database (checking expiry and revocation)
// 3. Rotates the refresh token: generates a new one with fresh 30-day expiry
// 4. Generates a new short-lived access token (15 minutes)
// 5. Sets the new refresh token as HttpOnly cookie
// 6. Returns the new access token and its expiry time
//
// Token rotation provides:
// - Sliding window sessions: active users stay logged in indefinitely
// - Better security: stolen refresh tokens become invalid after legitimate refresh
//
// Authentication: Requires valid refresh token in cookie (public endpoint)
// Returns: New access token and expiry timestamp.
func (s *APIV1Service) RefreshToken(ctx context.Context, _ *v1pb.RefreshTokenRequest) (*v1pb.RefreshTokenResponse, error) {
// Extract refresh token from cookie
refreshToken := ""
if md, ok := metadata.FromIncomingContext(ctx); ok {
if cookies := md.Get("cookie"); len(cookies) > 0 {
refreshToken = auth.ExtractRefreshTokenFromCookie(cookies[0])
}
}
if refreshToken == "" {
return nil, status.Errorf(codes.Unauthenticated, "refresh token not found")
}
// Validate refresh token and get old token ID for rotation
authenticator := auth.NewAuthenticator(s.Store, s.Secret)
user, oldTokenID, err := authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid refresh token: %v", err)
}
// --- Refresh Token Rotation ---
// Generate new refresh token with fresh 30-day expiry (sliding window)
newTokenID := util.GenUUID()
newRefreshToken, newRefreshExpiresAt, err := auth.GenerateRefreshToken(user.ID, newTokenID, []byte(s.Secret))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate refresh token: %v", err)
}
// Store new refresh token (add before remove to handle race conditions)
clientInfo := s.extractClientInfo(ctx)
newRefreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: newTokenID,
ExpiresAt: timestamppb.New(newRefreshExpiresAt),
CreatedAt: timestamppb.Now(),
ClientInfo: clientInfo,
}
if err := s.Store.AddUserRefreshToken(ctx, user.ID, newRefreshTokenRecord); err != nil {
return nil, status.Errorf(codes.Internal, "failed to store refresh token: %v", err)
}
// Remove old refresh token
if err := s.Store.RemoveUserRefreshToken(ctx, user.ID, oldTokenID); err != nil {
// Log but don't fail - old token will expire naturally
slog.Warn("failed to remove old refresh token", "error", err, "userID", user.ID, "tokenID", oldTokenID)
}
// Set new refresh token cookie
newRefreshCookie := s.buildRefreshTokenCookie(ctx, newRefreshToken, newRefreshExpiresAt)
if err := SetResponseHeader(ctx, "Set-Cookie", newRefreshCookie); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set refresh token cookie: %v", err)
}
// --- End Rotation ---
// Generate new access token
accessToken, expiresAt, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte(s.Secret),
)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
return &v1pb.RefreshTokenResponse{
AccessToken: accessToken,
ExpiresAt: timestamppb.New(expiresAt),
}, nil
}
func (s *APIV1Service) clearAuthCookies(ctx context.Context) error {
// Clear refresh token cookie
refreshCookie := s.buildRefreshTokenCookie(ctx, "", time.Time{})
if err := SetResponseHeader(ctx, "Set-Cookie", refreshCookie); err != nil {
return errors.Wrap(err, "failed to set refresh cookie")
}
return nil
}
func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken string, expireTime time.Time) string {
attrs := []string{
fmt.Sprintf("%s=%s", auth.RefreshTokenCookieName, refreshToken),
"Path=/",
"HttpOnly",
}
if expireTime.IsZero() {
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
} else {
// RFC 6265 requires cookie expiration dates to use GMT timezone
// Convert to UTC and format with explicit "GMT" to ensure browser compatibility
attrs = append(attrs, "Expires="+expireTime.UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT"))
}
// Try to determine if the request is HTTPS by checking the origin header
// Default to non-HTTPS (Lax SameSite) if metadata is not available
isHTTPS := false
if md, ok := metadata.FromIncomingContext(ctx); ok {
for _, v := range md.Get("origin") {
if strings.HasPrefix(v, "https://") {
isHTTPS = true
break
}
}
}
if isHTTPS {
attrs = append(attrs, "SameSite=Lax", "Secure")
} else {
attrs = append(attrs, "SameSite=Lax")
}
return strings.Join(attrs, "; ")
}
func (s *APIV1Service) fetchCurrentUser(ctx context.Context) (*store.User, error) {
userID := auth.GetUserID(ctx)
if userID == 0 {
return nil, nil
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, err
}
if user == nil {
return nil, errors.Errorf("user %d not found", userID)
}
return user, nil
}
// extractClientInfo extracts comprehensive client information from the request context.
//
// This function parses metadata from the gRPC context to extract:
// - User Agent: Raw user agent string for detailed parsing
// - IP Address: Client IP from X-Forwarded-For or X-Real-IP headers
// - Device Type: "mobile", "tablet", or "desktop" (parsed from user agent)
// - Operating System: OS name and version (e.g., "iOS 17.1", "Windows 10/11")
// - Browser: Browser name and version (e.g., "Chrome 120.0.0.0")
//
// This information enables users to:
// - See all active sessions with device details
// - Identify suspicious login attempts
// - Revoke specific sessions from unknown devices.
func (s *APIV1Service) extractClientInfo(ctx context.Context) *storepb.RefreshTokensUserSetting_ClientInfo {
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
// Extract user agent from metadata if available
if md, ok := metadata.FromIncomingContext(ctx); ok {
if userAgents := md.Get("user-agent"); len(userAgents) > 0 {
userAgent := userAgents[0]
clientInfo.UserAgent = userAgent
// Parse user agent to extract device type, OS, browser info
s.parseUserAgent(userAgent, clientInfo)
}
if forwardedFor := md.Get("x-forwarded-for"); len(forwardedFor) > 0 {
ipAddress := strings.Split(forwardedFor[0], ",")[0] // Get the first IP in case of multiple
ipAddress = strings.TrimSpace(ipAddress)
clientInfo.IpAddress = ipAddress
} else if realIP := md.Get("x-real-ip"); len(realIP) > 0 {
clientInfo.IpAddress = realIP[0]
}
}
return clientInfo
}
// parseUserAgent extracts device type, OS, and browser information from user agent string.
//
// Detection logic:
// - Device Type: Checks for keywords like "mobile", "tablet", "ipad"
// - OS: Pattern matches for iOS, Android, Windows, macOS, Linux, Chrome OS
// - Browser: Identifies Edge, Chrome, Firefox, Safari, Opera
//
// Note: This is a simplified parser. For production use with high accuracy requirements,
// consider using a dedicated user agent parsing library.
func (*APIV1Service) parseUserAgent(userAgent string, clientInfo *storepb.RefreshTokensUserSetting_ClientInfo) {
if userAgent == "" {
return
}
userAgent = strings.ToLower(userAgent)
// Detect device type
if strings.Contains(userAgent, "ipad") || strings.Contains(userAgent, "tablet") {
clientInfo.DeviceType = "tablet"
} else if strings.Contains(userAgent, "mobile") || strings.Contains(userAgent, "android") ||
strings.Contains(userAgent, "iphone") || strings.Contains(userAgent, "ipod") ||
strings.Contains(userAgent, "windows phone") || strings.Contains(userAgent, "blackberry") {
clientInfo.DeviceType = "mobile"
} else {
clientInfo.DeviceType = "desktop"
}
// Detect operating system
if strings.Contains(userAgent, "iphone os") || strings.Contains(userAgent, "cpu os") {
// Extract iOS version
if idx := strings.Index(userAgent, "cpu os "); idx != -1 {
versionStart := idx + 7
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "iOS " + version
} else {
clientInfo.Os = "iOS"
}
} else if idx := strings.Index(userAgent, "iphone os "); idx != -1 {
versionStart := idx + 10
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "iOS " + version
} else {
clientInfo.Os = "iOS"
}
} else {
clientInfo.Os = "iOS"
}
} else if strings.Contains(userAgent, "android") {
// Extract Android version
if idx := strings.Index(userAgent, "android "); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], ";")
if versionEnd == -1 {
versionEnd = strings.Index(userAgent[versionStart:], ")")
}
if versionEnd != -1 {
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Os = "Android " + version
} else {
clientInfo.Os = "Android"
}
} else {
clientInfo.Os = "Android"
}
} else if strings.Contains(userAgent, "windows nt 10.0") {
clientInfo.Os = "Windows 10/11"
} else if strings.Contains(userAgent, "windows nt 6.3") {
clientInfo.Os = "Windows 8.1"
} else if strings.Contains(userAgent, "windows nt 6.1") {
clientInfo.Os = "Windows 7"
} else if strings.Contains(userAgent, "windows") {
clientInfo.Os = "Windows"
} else if strings.Contains(userAgent, "mac os x") {
// Extract macOS version
if idx := strings.Index(userAgent, "mac os x "); idx != -1 {
versionStart := idx + 9
versionEnd := strings.Index(userAgent[versionStart:], ";")
if versionEnd == -1 {
versionEnd = strings.Index(userAgent[versionStart:], ")")
}
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "macOS " + version
} else {
clientInfo.Os = "macOS"
}
} else {
clientInfo.Os = "macOS"
}
} else if strings.Contains(userAgent, "linux") {
clientInfo.Os = "Linux"
} else if strings.Contains(userAgent, "cros") {
clientInfo.Os = "Chrome OS"
}
// Detect browser
if strings.Contains(userAgent, "edg/") {
// Extract Edge version
if idx := strings.Index(userAgent, "edg/"); idx != -1 {
versionStart := idx + 4
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Edge " + version
} else {
clientInfo.Browser = "Edge"
}
} else if strings.Contains(userAgent, "chrome/") && !strings.Contains(userAgent, "edg") {
// Extract Chrome version
if idx := strings.Index(userAgent, "chrome/"); idx != -1 {
versionStart := idx + 7
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Chrome " + version
} else {
clientInfo.Browser = "Chrome"
}
} else if strings.Contains(userAgent, "firefox/") {
// Extract Firefox version
if idx := strings.Index(userAgent, "firefox/"); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Firefox " + version
} else {
clientInfo.Browser = "Firefox"
}
} else if strings.Contains(userAgent, "safari/") && !strings.Contains(userAgent, "chrome") && !strings.Contains(userAgent, "edg") {
// Extract Safari version
if idx := strings.Index(userAgent, "version/"); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Safari " + version
} else {
clientInfo.Browser = "Safari"
}
} else if strings.Contains(userAgent, "opera/") || strings.Contains(userAgent, "opr/") {
clientInfo.Browser = "Opera"
}
}

View File

@@ -0,0 +1,179 @@
package v1
import (
"context"
"testing"
"google.golang.org/grpc/metadata"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestParseUserAgent(t *testing.T) {
service := &APIV1Service{}
tests := []struct {
name string
userAgent string
expectedDevice string
expectedOS string
expectedBrowser string
}{
{
name: "Chrome on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Chrome 119.0.0.0",
},
{
name: "Safari on macOS",
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15",
expectedDevice: "desktop",
expectedOS: "macOS 10.15.7",
expectedBrowser: "Safari 17.0",
},
{
name: "Chrome on Android Mobile",
userAgent: "Mozilla/5.0 (Linux; Android 13; SM-G998B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Mobile Safari/537.36",
expectedDevice: "mobile",
expectedOS: "Android 13",
expectedBrowser: "Chrome 119.0.0.0",
},
{
name: "Safari on iPhone",
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
expectedDevice: "mobile",
expectedOS: "iOS 17.0",
expectedBrowser: "Safari 17.0",
},
{
name: "Firefox on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/119.0",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Firefox 119.0",
},
{
name: "Edge on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Edge 119.0.0.0",
},
{
name: "iPad Safari",
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
expectedDevice: "tablet",
expectedOS: "iOS 17.0",
expectedBrowser: "Safari 17.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
service.parseUserAgent(tt.userAgent, clientInfo)
if clientInfo.DeviceType != tt.expectedDevice {
t.Errorf("Expected device type %s, got %s", tt.expectedDevice, clientInfo.DeviceType)
}
if clientInfo.Os != tt.expectedOS {
t.Errorf("Expected OS %s, got %s", tt.expectedOS, clientInfo.Os)
}
if clientInfo.Browser != tt.expectedBrowser {
t.Errorf("Expected browser %s, got %s", tt.expectedBrowser, clientInfo.Browser)
}
})
}
}
func TestExtractClientInfo(t *testing.T) {
service := &APIV1Service{}
// Test with metadata containing user agent and IP
md := metadata.New(map[string]string{
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
"x-forwarded-for": "203.0.113.1, 198.51.100.1",
"x-real-ip": "203.0.113.1",
})
ctx := metadata.NewIncomingContext(context.Background(), md)
clientInfo := service.extractClientInfo(ctx)
if clientInfo.UserAgent == "" {
t.Error("Expected user agent to be set")
}
if clientInfo.IpAddress != "203.0.113.1" {
t.Errorf("Expected IP address to be 203.0.113.1, got %s", clientInfo.IpAddress)
}
if clientInfo.DeviceType != "desktop" {
t.Errorf("Expected device type to be desktop, got %s", clientInfo.DeviceType)
}
if clientInfo.Os != "Windows 10/11" {
t.Errorf("Expected OS to be Windows 10/11, got %s", clientInfo.Os)
}
if clientInfo.Browser != "Chrome 119.0.0.0" {
t.Errorf("Expected browser to be Chrome 119.0.0.0, got %s", clientInfo.Browser)
}
}
// TestClientInfoExamples demonstrates the enhanced client info extraction with various user agents.
func TestClientInfoExamples(t *testing.T) {
service := &APIV1Service{}
examples := []struct {
description string
userAgent string
}{
{
description: "Modern Chrome on Windows 11",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
},
{
description: "Safari on iPhone 15 Pro",
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
},
{
description: "Chrome on Samsung Galaxy",
userAgent: "Mozilla/5.0 (Linux; Android 14; SM-S918B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36",
},
{
description: "Firefox on Ubuntu",
userAgent: "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/120.0",
},
{
description: "Edge on Windows 10",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
},
{
description: "Safari on iPad Air",
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
},
}
for _, example := range examples {
t.Run(example.description, func(t *testing.T) {
clientInfo := &storepb.RefreshTokensUserSetting_ClientInfo{}
service.parseUserAgent(example.userAgent, clientInfo)
t.Logf("User Agent: %s", example.userAgent)
t.Logf("Device Type: %s", clientInfo.DeviceType)
t.Logf("Operating System: %s", clientInfo.Os)
t.Logf("Browser: %s", clientInfo.Browser)
t.Log("---")
// Ensure all fields are populated
if clientInfo.DeviceType == "" {
t.Error("Device type should not be empty")
}
if clientInfo.Os == "" {
t.Error("OS should not be empty")
}
if clientInfo.Browser == "" {
t.Error("Browser should not be empty")
}
})
}
}

View File

@@ -0,0 +1,68 @@
package v1
import (
"encoding/base64"
"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
const (
// DefaultPageSize is the default page size for requests.
DefaultPageSize = 10
// MaxPageSize is the maximum page size for requests.
MaxPageSize = 1000
)
func convertStateFromStore(rowStatus store.RowStatus) v1pb.State {
switch rowStatus {
case store.Normal:
return v1pb.State_NORMAL
case store.Archived:
return v1pb.State_ARCHIVED
default:
return v1pb.State_STATE_UNSPECIFIED
}
}
func convertStateToStore(state v1pb.State) store.RowStatus {
switch state {
case v1pb.State_ARCHIVED:
return store.Archived
default:
return store.Normal
}
}
func getPageToken(limit int, offset int) (string, error) {
return marshalPageToken(&v1pb.PageToken{
Limit: int32(limit),
Offset: int32(offset),
})
}
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
b, err := proto.Marshal(pageToken)
if err != nil {
return "", errors.Wrapf(err, "failed to marshal page token")
}
return base64.StdEncoding.EncodeToString(b), nil
}
func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return errors.Wrapf(err, "failed to decode page token")
}
if err := proto.Unmarshal(b, pageToken); err != nil {
return errors.Wrapf(err, "failed to unmarshal page token")
}
return nil
}
func isSuperUser(user *store.User) bool {
return user.Role == store.RoleAdmin
}

View File

@@ -0,0 +1,80 @@
package v1
import (
"net/http"
"connectrpc.com/connect"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/usememos/memos/proto/gen/api/v1/apiv1connect"
)
// ConnectServiceHandler wraps APIV1Service to implement Connect handler interfaces.
// It adapts the existing gRPC service implementations to work with Connect's
// request/response wrapper types.
//
// This wrapper pattern allows us to:
// - Reuse existing gRPC service implementations
// - Support both native gRPC and Connect protocols
// - Maintain a single source of truth for business logic.
type ConnectServiceHandler struct {
*APIV1Service
}
// NewConnectServiceHandler creates a new Connect service handler.
func NewConnectServiceHandler(svc *APIV1Service) *ConnectServiceHandler {
return &ConnectServiceHandler{APIV1Service: svc}
}
// RegisterConnectHandlers registers all Connect service handlers on the given mux.
func (s *ConnectServiceHandler) RegisterConnectHandlers(mux *http.ServeMux, opts ...connect.HandlerOption) {
// Register all service handlers
handlers := []struct {
path string
handler http.Handler
}{
wrap(apiv1connect.NewInstanceServiceHandler(s, opts...)),
wrap(apiv1connect.NewAuthServiceHandler(s, opts...)),
wrap(apiv1connect.NewUserServiceHandler(s, opts...)),
wrap(apiv1connect.NewMemoServiceHandler(s, opts...)),
wrap(apiv1connect.NewAttachmentServiceHandler(s, opts...)),
wrap(apiv1connect.NewShortcutServiceHandler(s, opts...)),
wrap(apiv1connect.NewActivityServiceHandler(s, opts...)),
wrap(apiv1connect.NewIdentityProviderServiceHandler(s, opts...)),
}
for _, h := range handlers {
mux.Handle(h.path, h.handler)
}
}
// wrap converts (path, handler) return value to a struct for cleaner iteration.
func wrap(path string, handler http.Handler) struct {
path string
handler http.Handler
} {
return struct {
path string
handler http.Handler
}{path, handler}
}
// convertGRPCError converts gRPC status errors to Connect errors.
// This preserves the error code semantics between the two protocols.
func convertGRPCError(err error) error {
if err == nil {
return nil
}
if st, ok := status.FromError(err); ok {
return connect.NewError(grpcCodeToConnectCode(st.Code()), err)
}
return connect.NewError(connect.CodeInternal, err)
}
// grpcCodeToConnectCode converts gRPC status codes to Connect error codes.
// gRPC and Connect use the same error code semantics, so this is a direct cast.
// See: https://connectrpc.com/docs/protocol/#error-codes
func grpcCodeToConnectCode(code codes.Code) connect.Code {
return connect.Code(code)
}

View File

@@ -0,0 +1,237 @@
package v1
import (
"context"
"errors"
"fmt"
"log/slog"
"reflect"
"runtime/debug"
"connectrpc.com/connect"
pkgerrors "github.com/pkg/errors"
"google.golang.org/grpc/metadata"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// MetadataInterceptor converts Connect HTTP headers to gRPC metadata.
//
// This ensures service methods can use metadata.FromIncomingContext() to access
// headers like User-Agent, X-Forwarded-For, etc., regardless of whether the
// request came via Connect RPC or gRPC-Gateway.
type MetadataInterceptor struct{}
// NewMetadataInterceptor creates a new metadata interceptor.
func NewMetadataInterceptor() *MetadataInterceptor {
return &MetadataInterceptor{}
}
func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
// Convert HTTP headers to gRPC metadata
header := req.Header()
md := metadata.MD{}
// Copy important headers for client info extraction
if ua := header.Get("User-Agent"); ua != "" {
md.Set("user-agent", ua)
}
if xff := header.Get("X-Forwarded-For"); xff != "" {
md.Set("x-forwarded-for", xff)
}
if xri := header.Get("X-Real-Ip"); xri != "" {
md.Set("x-real-ip", xri)
}
// Forward Cookie header for authentication methods that need it (e.g., RefreshToken)
if cookie := header.Get("Cookie"); cookie != "" {
md.Set("cookie", cookie)
}
// Set metadata in context so services can use metadata.FromIncomingContext()
ctx = metadata.NewIncomingContext(ctx, md)
// Execute the request
resp, err := next(ctx, req)
// Prevent browser caching of API responses to avoid stale data issues
// See: https://github.com/usememos/memos/issues/5470
if !isNilAnyResponse(resp) && resp.Header() != nil {
resp.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
resp.Header().Set("Pragma", "no-cache")
resp.Header().Set("Expires", "0")
}
return resp, err
}
}
func isNilAnyResponse(resp connect.AnyResponse) bool {
if resp == nil {
return true
}
val := reflect.ValueOf(resp)
return val.Kind() == reflect.Ptr && val.IsNil()
}
func (*MetadataInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}
func (*MetadataInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}
// LoggingInterceptor logs Connect RPC requests with appropriate log levels.
//
// Log levels:
// - INFO: Successful requests and expected client errors (not found, permission denied, etc.)
// - ERROR: Server errors (internal, unavailable, etc.)
type LoggingInterceptor struct {
logStacktrace bool
}
// NewLoggingInterceptor creates a new logging interceptor.
func NewLoggingInterceptor(logStacktrace bool) *LoggingInterceptor {
return &LoggingInterceptor{logStacktrace: logStacktrace}
}
func (in *LoggingInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
resp, err := next(ctx, req)
in.log(req.Spec().Procedure, err)
return resp, err
}
}
func (*LoggingInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next // No-op for server-side interceptor
}
func (*LoggingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next // Streaming not used in this service
}
func (in *LoggingInterceptor) log(procedure string, err error) {
level, msg := in.classifyError(err)
attrs := []slog.Attr{slog.String("method", procedure)}
if err != nil {
attrs = append(attrs, slog.String("error", err.Error()))
if in.logStacktrace {
attrs = append(attrs, slog.String("stacktrace", fmt.Sprintf("%+v", err)))
}
}
slog.LogAttrs(context.Background(), level, msg, attrs...)
}
func (*LoggingInterceptor) classifyError(err error) (slog.Level, string) {
if err == nil {
return slog.LevelInfo, "OK"
}
var connectErr *connect.Error
if !pkgerrors.As(err, &connectErr) {
return slog.LevelError, "unknown error"
}
// Client errors (expected, log at INFO)
switch connectErr.Code() {
case connect.CodeCanceled,
connect.CodeInvalidArgument,
connect.CodeNotFound,
connect.CodeAlreadyExists,
connect.CodePermissionDenied,
connect.CodeUnauthenticated,
connect.CodeResourceExhausted,
connect.CodeFailedPrecondition,
connect.CodeAborted,
connect.CodeOutOfRange:
return slog.LevelInfo, "client error"
default:
// Server errors
return slog.LevelError, "server error"
}
}
// RecoveryInterceptor recovers from panics in Connect handlers and returns an internal error.
type RecoveryInterceptor struct {
logStacktrace bool
}
// NewRecoveryInterceptor creates a new recovery interceptor.
func NewRecoveryInterceptor(logStacktrace bool) *RecoveryInterceptor {
return &RecoveryInterceptor{logStacktrace: logStacktrace}
}
func (in *RecoveryInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (resp connect.AnyResponse, err error) {
defer func() {
if r := recover(); r != nil {
in.logPanic(req.Spec().Procedure, r)
err = connect.NewError(connect.CodeInternal, pkgerrors.New("internal server error"))
}
}()
return next(ctx, req)
}
}
func (*RecoveryInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}
func (*RecoveryInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}
func (in *RecoveryInterceptor) logPanic(procedure string, panicValue any) {
attrs := []slog.Attr{
slog.String("method", procedure),
slog.Any("panic", panicValue),
}
if in.logStacktrace {
attrs = append(attrs, slog.String("stacktrace", string(debug.Stack())))
}
slog.LogAttrs(context.Background(), slog.LevelError, "panic recovered in Connect handler", attrs...)
}
// AuthInterceptor handles authentication for Connect handlers.
//
// It enforces authentication for all endpoints except those listed in PublicMethods.
// Role-based authorization (admin checks) remains in the service layer.
type AuthInterceptor struct {
authenticator *auth.Authenticator
}
// NewAuthInterceptor creates a new auth interceptor.
func NewAuthInterceptor(store *store.Store, secret string) *AuthInterceptor {
return &AuthInterceptor{
authenticator: auth.NewAuthenticator(store, secret),
}
}
func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
header := req.Header()
authHeader := header.Get("Authorization")
result := in.authenticator.Authenticate(ctx, authHeader)
// Enforce authentication for non-public methods
if result == nil && !IsPublicMethod(req.Spec().Procedure) {
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required"))
}
ctx = auth.ApplyToContext(ctx, result)
return next(ctx, req)
}
}
func (*AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}
func (*AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}

View File

@@ -0,0 +1,490 @@
package v1
import (
"context"
"connectrpc.com/connect"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
// This file contains all Connect service handler method implementations.
// Each method delegates to the underlying gRPC service implementation,
// converting between Connect and gRPC request/response types.
// InstanceService
func (s *ConnectServiceHandler) GetInstanceProfile(ctx context.Context, req *connect.Request[v1pb.GetInstanceProfileRequest]) (*connect.Response[v1pb.InstanceProfile], error) {
resp, err := s.APIV1Service.GetInstanceProfile(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetInstanceSetting(ctx context.Context, req *connect.Request[v1pb.GetInstanceSettingRequest]) (*connect.Response[v1pb.InstanceSetting], error) {
resp, err := s.APIV1Service.GetInstanceSetting(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateInstanceSetting(ctx context.Context, req *connect.Request[v1pb.UpdateInstanceSettingRequest]) (*connect.Response[v1pb.InstanceSetting], error) {
resp, err := s.APIV1Service.UpdateInstanceSetting(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// AuthService
//
// Auth service methods need special handling for response headers (cookies).
// We use connectWithHeaderCarrier helper to inject a header carrier into the context,
// which allows the service to set headers in a protocol-agnostic way.
func (s *ConnectServiceHandler) GetCurrentUser(ctx context.Context, req *connect.Request[v1pb.GetCurrentUserRequest]) (*connect.Response[v1pb.GetCurrentUserResponse], error) {
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.GetCurrentUserResponse, error) {
return s.APIV1Service.GetCurrentUser(ctx, req.Msg)
})
}
func (s *ConnectServiceHandler) SignIn(ctx context.Context, req *connect.Request[v1pb.SignInRequest]) (*connect.Response[v1pb.SignInResponse], error) {
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.SignInResponse, error) {
return s.APIV1Service.SignIn(ctx, req.Msg)
})
}
func (s *ConnectServiceHandler) SignOut(ctx context.Context, req *connect.Request[v1pb.SignOutRequest]) (*connect.Response[emptypb.Empty], error) {
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*emptypb.Empty, error) {
return s.APIV1Service.SignOut(ctx, req.Msg)
})
}
func (s *ConnectServiceHandler) RefreshToken(ctx context.Context, req *connect.Request[v1pb.RefreshTokenRequest]) (*connect.Response[v1pb.RefreshTokenResponse], error) {
return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.RefreshTokenResponse, error) {
return s.APIV1Service.RefreshToken(ctx, req.Msg)
})
}
// UserService
func (s *ConnectServiceHandler) ListUsers(ctx context.Context, req *connect.Request[v1pb.ListUsersRequest]) (*connect.Response[v1pb.ListUsersResponse], error) {
resp, err := s.APIV1Service.ListUsers(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetUser(ctx context.Context, req *connect.Request[v1pb.GetUserRequest]) (*connect.Response[v1pb.User], error) {
resp, err := s.APIV1Service.GetUser(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateUser(ctx context.Context, req *connect.Request[v1pb.CreateUserRequest]) (*connect.Response[v1pb.User], error) {
resp, err := s.APIV1Service.CreateUser(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateUser(ctx context.Context, req *connect.Request[v1pb.UpdateUserRequest]) (*connect.Response[v1pb.User], error) {
resp, err := s.APIV1Service.UpdateUser(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteUser(ctx context.Context, req *connect.Request[v1pb.DeleteUserRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteUser(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListAllUserStats(ctx context.Context, req *connect.Request[v1pb.ListAllUserStatsRequest]) (*connect.Response[v1pb.ListAllUserStatsResponse], error) {
resp, err := s.APIV1Service.ListAllUserStats(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetUserStats(ctx context.Context, req *connect.Request[v1pb.GetUserStatsRequest]) (*connect.Response[v1pb.UserStats], error) {
resp, err := s.APIV1Service.GetUserStats(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetUserSetting(ctx context.Context, req *connect.Request[v1pb.GetUserSettingRequest]) (*connect.Response[v1pb.UserSetting], error) {
resp, err := s.APIV1Service.GetUserSetting(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateUserSetting(ctx context.Context, req *connect.Request[v1pb.UpdateUserSettingRequest]) (*connect.Response[v1pb.UserSetting], error) {
resp, err := s.APIV1Service.UpdateUserSetting(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListUserSettings(ctx context.Context, req *connect.Request[v1pb.ListUserSettingsRequest]) (*connect.Response[v1pb.ListUserSettingsResponse], error) {
resp, err := s.APIV1Service.ListUserSettings(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListPersonalAccessTokens(ctx context.Context, req *connect.Request[v1pb.ListPersonalAccessTokensRequest]) (*connect.Response[v1pb.ListPersonalAccessTokensResponse], error) {
resp, err := s.APIV1Service.ListPersonalAccessTokens(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreatePersonalAccessToken(ctx context.Context, req *connect.Request[v1pb.CreatePersonalAccessTokenRequest]) (*connect.Response[v1pb.CreatePersonalAccessTokenResponse], error) {
resp, err := s.APIV1Service.CreatePersonalAccessToken(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeletePersonalAccessToken(ctx context.Context, req *connect.Request[v1pb.DeletePersonalAccessTokenRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeletePersonalAccessToken(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListUserWebhooks(ctx context.Context, req *connect.Request[v1pb.ListUserWebhooksRequest]) (*connect.Response[v1pb.ListUserWebhooksResponse], error) {
resp, err := s.APIV1Service.ListUserWebhooks(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateUserWebhook(ctx context.Context, req *connect.Request[v1pb.CreateUserWebhookRequest]) (*connect.Response[v1pb.UserWebhook], error) {
resp, err := s.APIV1Service.CreateUserWebhook(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateUserWebhook(ctx context.Context, req *connect.Request[v1pb.UpdateUserWebhookRequest]) (*connect.Response[v1pb.UserWebhook], error) {
resp, err := s.APIV1Service.UpdateUserWebhook(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteUserWebhook(ctx context.Context, req *connect.Request[v1pb.DeleteUserWebhookRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteUserWebhook(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListUserNotifications(ctx context.Context, req *connect.Request[v1pb.ListUserNotificationsRequest]) (*connect.Response[v1pb.ListUserNotificationsResponse], error) {
resp, err := s.APIV1Service.ListUserNotifications(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateUserNotification(ctx context.Context, req *connect.Request[v1pb.UpdateUserNotificationRequest]) (*connect.Response[v1pb.UserNotification], error) {
resp, err := s.APIV1Service.UpdateUserNotification(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteUserNotification(ctx context.Context, req *connect.Request[v1pb.DeleteUserNotificationRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteUserNotification(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// MemoService
func (s *ConnectServiceHandler) CreateMemo(ctx context.Context, req *connect.Request[v1pb.CreateMemoRequest]) (*connect.Response[v1pb.Memo], error) {
resp, err := s.APIV1Service.CreateMemo(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemos(ctx context.Context, req *connect.Request[v1pb.ListMemosRequest]) (*connect.Response[v1pb.ListMemosResponse], error) {
resp, err := s.APIV1Service.ListMemos(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetMemo(ctx context.Context, req *connect.Request[v1pb.GetMemoRequest]) (*connect.Response[v1pb.Memo], error) {
resp, err := s.APIV1Service.GetMemo(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateMemo(ctx context.Context, req *connect.Request[v1pb.UpdateMemoRequest]) (*connect.Response[v1pb.Memo], error) {
resp, err := s.APIV1Service.UpdateMemo(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteMemo(ctx context.Context, req *connect.Request[v1pb.DeleteMemoRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteMemo(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) SetMemoAttachments(ctx context.Context, req *connect.Request[v1pb.SetMemoAttachmentsRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.SetMemoAttachments(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemoAttachments(ctx context.Context, req *connect.Request[v1pb.ListMemoAttachmentsRequest]) (*connect.Response[v1pb.ListMemoAttachmentsResponse], error) {
resp, err := s.APIV1Service.ListMemoAttachments(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) SetMemoRelations(ctx context.Context, req *connect.Request[v1pb.SetMemoRelationsRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.SetMemoRelations(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemoRelations(ctx context.Context, req *connect.Request[v1pb.ListMemoRelationsRequest]) (*connect.Response[v1pb.ListMemoRelationsResponse], error) {
resp, err := s.APIV1Service.ListMemoRelations(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateMemoComment(ctx context.Context, req *connect.Request[v1pb.CreateMemoCommentRequest]) (*connect.Response[v1pb.Memo], error) {
resp, err := s.APIV1Service.CreateMemoComment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemoComments(ctx context.Context, req *connect.Request[v1pb.ListMemoCommentsRequest]) (*connect.Response[v1pb.ListMemoCommentsResponse], error) {
resp, err := s.APIV1Service.ListMemoComments(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListMemoReactions(ctx context.Context, req *connect.Request[v1pb.ListMemoReactionsRequest]) (*connect.Response[v1pb.ListMemoReactionsResponse], error) {
resp, err := s.APIV1Service.ListMemoReactions(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpsertMemoReaction(ctx context.Context, req *connect.Request[v1pb.UpsertMemoReactionRequest]) (*connect.Response[v1pb.Reaction], error) {
resp, err := s.APIV1Service.UpsertMemoReaction(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteMemoReaction(ctx context.Context, req *connect.Request[v1pb.DeleteMemoReactionRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteMemoReaction(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// AttachmentService
func (s *ConnectServiceHandler) CreateAttachment(ctx context.Context, req *connect.Request[v1pb.CreateAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
resp, err := s.APIV1Service.CreateAttachment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListAttachments(ctx context.Context, req *connect.Request[v1pb.ListAttachmentsRequest]) (*connect.Response[v1pb.ListAttachmentsResponse], error) {
resp, err := s.APIV1Service.ListAttachments(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetAttachment(ctx context.Context, req *connect.Request[v1pb.GetAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
resp, err := s.APIV1Service.GetAttachment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateAttachment(ctx context.Context, req *connect.Request[v1pb.UpdateAttachmentRequest]) (*connect.Response[v1pb.Attachment], error) {
resp, err := s.APIV1Service.UpdateAttachment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteAttachment(ctx context.Context, req *connect.Request[v1pb.DeleteAttachmentRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteAttachment(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// ShortcutService
func (s *ConnectServiceHandler) ListShortcuts(ctx context.Context, req *connect.Request[v1pb.ListShortcutsRequest]) (*connect.Response[v1pb.ListShortcutsResponse], error) {
resp, err := s.APIV1Service.ListShortcuts(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetShortcut(ctx context.Context, req *connect.Request[v1pb.GetShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
resp, err := s.APIV1Service.GetShortcut(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateShortcut(ctx context.Context, req *connect.Request[v1pb.CreateShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
resp, err := s.APIV1Service.CreateShortcut(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateShortcut(ctx context.Context, req *connect.Request[v1pb.UpdateShortcutRequest]) (*connect.Response[v1pb.Shortcut], error) {
resp, err := s.APIV1Service.UpdateShortcut(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteShortcut(ctx context.Context, req *connect.Request[v1pb.DeleteShortcutRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteShortcut(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// ActivityService
func (s *ConnectServiceHandler) ListActivities(ctx context.Context, req *connect.Request[v1pb.ListActivitiesRequest]) (*connect.Response[v1pb.ListActivitiesResponse], error) {
resp, err := s.APIV1Service.ListActivities(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetActivity(ctx context.Context, req *connect.Request[v1pb.GetActivityRequest]) (*connect.Response[v1pb.Activity], error) {
resp, err := s.APIV1Service.GetActivity(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
// IdentityProviderService
func (s *ConnectServiceHandler) ListIdentityProviders(ctx context.Context, req *connect.Request[v1pb.ListIdentityProvidersRequest]) (*connect.Response[v1pb.ListIdentityProvidersResponse], error) {
resp, err := s.APIV1Service.ListIdentityProviders(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetIdentityProvider(ctx context.Context, req *connect.Request[v1pb.GetIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
resp, err := s.APIV1Service.GetIdentityProvider(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateIdentityProvider(ctx context.Context, req *connect.Request[v1pb.CreateIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
resp, err := s.APIV1Service.CreateIdentityProvider(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) UpdateIdentityProvider(ctx context.Context, req *connect.Request[v1pb.UpdateIdentityProviderRequest]) (*connect.Response[v1pb.IdentityProvider], error) {
resp, err := s.APIV1Service.UpdateIdentityProvider(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteIdentityProvider(ctx context.Context, req *connect.Request[v1pb.DeleteIdentityProviderRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteIdentityProvider(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}

View File

@@ -0,0 +1,124 @@
package v1
import (
"context"
"connectrpc.com/connect"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// headerCarrierKey is the context key for storing headers to be set in the response.
type headerCarrierKey struct{}
// HeaderCarrier stores headers that need to be set in the response.
//
// Problem: The codebase supports two protocols simultaneously:
// - Native gRPC: Uses grpc.SetHeader() to set response headers
// - Connect-RPC: Uses connect.Response.Header().Set() to set response headers
//
// Solution: HeaderCarrier provides a protocol-agnostic way to set headers.
// - Service methods call SetResponseHeader() regardless of protocol
// - For gRPC requests: SetResponseHeader uses grpc.SetHeader directly
// - For Connect requests: SetResponseHeader stores headers in HeaderCarrier
// - Connect wrappers extract headers from HeaderCarrier and apply to response
//
// This allows service methods to work with both protocols without knowing which one is being used.
type HeaderCarrier struct {
headers map[string]string
}
// newHeaderCarrier creates a new header carrier.
func newHeaderCarrier() *HeaderCarrier {
return &HeaderCarrier{
headers: make(map[string]string),
}
}
// Set adds a header to the carrier.
func (h *HeaderCarrier) Set(key, value string) {
h.headers[key] = value
}
// Get retrieves a header from the carrier.
func (h *HeaderCarrier) Get(key string) string {
return h.headers[key]
}
// All returns all headers.
func (h *HeaderCarrier) All() map[string]string {
return h.headers
}
// WithHeaderCarrier adds a header carrier to the context.
func WithHeaderCarrier(ctx context.Context) context.Context {
return context.WithValue(ctx, headerCarrierKey{}, newHeaderCarrier())
}
// GetHeaderCarrier retrieves the header carrier from the context.
// Returns nil if no carrier is present.
func GetHeaderCarrier(ctx context.Context) *HeaderCarrier {
if carrier, ok := ctx.Value(headerCarrierKey{}).(*HeaderCarrier); ok {
return carrier
}
return nil
}
// SetResponseHeader sets a header in the response.
//
// This function works for both gRPC and Connect protocols:
// - For gRPC: Uses grpc.SetHeader to set headers in gRPC metadata
// - For Connect: Stores in HeaderCarrier for Connect wrapper to apply later
//
// The protocol is automatically detected based on whether a HeaderCarrier
// exists in the context (injected by Connect wrappers).
func SetResponseHeader(ctx context.Context, key, value string) error {
// Try Connect first (check if we have a header carrier)
if carrier := GetHeaderCarrier(ctx); carrier != nil {
carrier.Set(key, value)
return nil
}
// Fall back to gRPC
return grpc.SetHeader(ctx, metadata.New(map[string]string{
key: value,
}))
}
// connectWithHeaderCarrier is a helper for Connect service wrappers that need to set response headers.
//
// It injects a HeaderCarrier into the context, calls the service method,
// and applies any headers from the carrier to the Connect response.
//
// The generic parameter T is the non-pointer protobuf message type (e.g., v1pb.CreateSessionResponse),
// while fn returns *T (the pointer type) as is standard for protobuf messages.
//
// Usage in Connect wrappers:
//
// func (s *ConnectServiceHandler) CreateSession(ctx context.Context, req *connect.Request[v1pb.CreateSessionRequest]) (*connect.Response[v1pb.CreateSessionResponse], error) {
// return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.CreateSessionResponse, error) {
// return s.APIV1Service.CreateSession(ctx, req.Msg)
// })
// }
func connectWithHeaderCarrier[T any](ctx context.Context, fn func(context.Context) (*T, error)) (*connect.Response[T], error) {
// Inject header carrier for Connect protocol
ctx = WithHeaderCarrier(ctx)
// Call the service method
resp, err := fn(ctx)
if err != nil {
return nil, convertGRPCError(err)
}
// Create Connect response
connectResp := connect.NewResponse(resp)
// Apply any headers set via the header carrier
if carrier := GetHeaderCarrier(ctx); carrier != nil {
for key, value := range carrier.All() {
connectResp.Header().Set(key, value)
}
}
return connectResp, nil
}

View File

@@ -0,0 +1,25 @@
package v1
import (
"context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
)
func (s *APIV1Service) Check(ctx context.Context,
_ *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) {
// Check if database is initialized by verifying instance basic setting exists
instanceBasicSetting, err := s.Store.GetInstanceBasicSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Unavailable, "database not initialized: %v", err)
}
// Verify schema version is set (empty means database not properly initialized)
if instanceBasicSetting.SchemaVersion == "" {
return nil, status.Errorf(codes.Unavailable, "schema version not set")
}
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil
}

View File

@@ -0,0 +1,238 @@
package v1
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb.CreateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListIdentityProvidersRequest) (*v1pb.ListIdentityProvidersResponse, error) {
identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
}
response := &v1pb.ListIdentityProvidersResponse{
IdentityProviders: []*v1pb.IdentityProvider{},
}
// Default to lowest-privilege role, update later based on real role
currentUserRole := store.RoleUser
currentUser, err := s.fetchCurrentUser(ctx)
if err == nil && currentUser != nil {
currentUserRole = currentUser.Role
}
for _, identityProvider := range identityProviders {
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole))
}
return response, nil
}
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
// Default to lowest-privilege role, update later based on real role
currentUserRole := store.RoleUser
currentUser, err := s.fetchCurrentUser(ctx)
if err == nil && currentUser != nil {
currentUserRole = currentUser.Role
}
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil
}
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
update := &store.UpdateIdentityProviderV1{
ID: id,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
}
for _, field := range request.UpdateMask.Paths {
switch field {
case "title":
update.Name = &request.IdentityProvider.Title
case "identifier_filter":
update.IdentifierFilter = &request.IdentityProvider.IdentifierFilter
case "config":
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
default:
// Ignore unsupported fields
}
}
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
// Check if the identity provider exists before trying to delete it
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
}
return &emptypb.Empty{}, nil
}
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
temp := &v1pb.IdentityProvider{
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
Title: identityProvider.Name,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
}
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2Config := identityProvider.Config.GetOauth2Config()
temp.Config = &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &v1pb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
},
},
},
}
}
return temp
}
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
temp := &storepb.IdentityProvider{
Id: id,
Name: identityProvider.Title,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
}
return temp
}
func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProvider_Type, config *v1pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
if identityProviderType == v1pb.IdentityProvider_OAUTH2 {
oauth2Config := config.GetOauth2Config()
return &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &storepb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
},
},
},
}
}
return nil
}
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
if userRole != store.RoleAdmin {
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
identityProvider.Config.GetOauth2Config().ClientSecret = ""
}
}
return identityProvider
}

View File

@@ -0,0 +1,285 @@
package v1
import (
"context"
"fmt"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// GetInstanceProfile returns the instance profile.
func (s *APIV1Service) GetInstanceProfile(ctx context.Context, _ *v1pb.GetInstanceProfileRequest) (*v1pb.InstanceProfile, error) {
admin, err := s.GetInstanceAdmin(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance admin: %v", err)
}
instanceProfile := &v1pb.InstanceProfile{
Version: s.Profile.Version,
Demo: s.Profile.Demo,
InstanceUrl: s.Profile.InstanceURL,
Admin: admin, // nil when not initialized
}
return instanceProfile, nil
}
func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.GetInstanceSettingRequest) (*v1pb.InstanceSetting, error) {
instanceSettingKeyString, err := ExtractInstanceSettingKeyFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid instance setting name: %v", err)
}
instanceSettingKey := storepb.InstanceSettingKey(storepb.InstanceSettingKey_value[instanceSettingKeyString])
// Get instance setting from store with default value.
switch instanceSettingKey {
case storepb.InstanceSettingKey_BASIC:
_, err = s.Store.GetInstanceBasicSetting(ctx)
case storepb.InstanceSettingKey_GENERAL:
_, err = s.Store.GetInstanceGeneralSetting(ctx)
case storepb.InstanceSettingKey_MEMO_RELATED:
_, err = s.Store.GetInstanceMemoRelatedSetting(ctx)
case storepb.InstanceSettingKey_STORAGE:
_, err = s.Store.GetInstanceStorageSetting(ctx)
default:
return nil, status.Errorf(codes.InvalidArgument, "unsupported instance setting key: %v", instanceSettingKey)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance setting: %v", err)
}
instanceSetting, err := s.Store.GetInstanceSetting(ctx, &store.FindInstanceSetting{
Name: instanceSettingKey.String(),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance setting: %v", err)
}
if instanceSetting == nil {
return nil, status.Errorf(codes.NotFound, "instance setting not found")
}
// For storage setting, only admin can get it.
if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if user.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
return convertInstanceSettingFromStore(instanceSetting), nil
}
func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb.UpdateInstanceSettingRequest) (*v1pb.InstanceSetting, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if user.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// TODO: Apply update_mask if specified
_ = request.UpdateMask
updateSetting := convertInstanceSettingToStore(request.Setting)
instanceSetting, err := s.Store.UpsertInstanceSetting(ctx, updateSetting)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert instance setting: %v", err)
}
return convertInstanceSettingFromStore(instanceSetting), nil
}
func convertInstanceSettingFromStore(setting *storepb.InstanceSetting) *v1pb.InstanceSetting {
instanceSetting := &v1pb.InstanceSetting{
Name: fmt.Sprintf("instance/settings/%s", setting.Key.String()),
}
switch setting.Value.(type) {
case *storepb.InstanceSetting_GeneralSetting:
instanceSetting.Value = &v1pb.InstanceSetting_GeneralSetting_{
GeneralSetting: convertInstanceGeneralSettingFromStore(setting.GetGeneralSetting()),
}
case *storepb.InstanceSetting_StorageSetting:
instanceSetting.Value = &v1pb.InstanceSetting_StorageSetting_{
StorageSetting: convertInstanceStorageSettingFromStore(setting.GetStorageSetting()),
}
case *storepb.InstanceSetting_MemoRelatedSetting:
instanceSetting.Value = &v1pb.InstanceSetting_MemoRelatedSetting_{
MemoRelatedSetting: convertInstanceMemoRelatedSettingFromStore(setting.GetMemoRelatedSetting()),
}
}
return instanceSetting
}
func convertInstanceSettingToStore(setting *v1pb.InstanceSetting) *storepb.InstanceSetting {
settingKeyString, _ := ExtractInstanceSettingKeyFromName(setting.Name)
instanceSetting := &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey(storepb.InstanceSettingKey_value[settingKeyString]),
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: convertInstanceGeneralSettingToStore(setting.GetGeneralSetting()),
},
}
switch instanceSetting.Key {
case storepb.InstanceSettingKey_GENERAL:
instanceSetting.Value = &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: convertInstanceGeneralSettingToStore(setting.GetGeneralSetting()),
}
case storepb.InstanceSettingKey_STORAGE:
instanceSetting.Value = &storepb.InstanceSetting_StorageSetting{
StorageSetting: convertInstanceStorageSettingToStore(setting.GetStorageSetting()),
}
case storepb.InstanceSettingKey_MEMO_RELATED:
instanceSetting.Value = &storepb.InstanceSetting_MemoRelatedSetting{
MemoRelatedSetting: convertInstanceMemoRelatedSettingToStore(setting.GetMemoRelatedSetting()),
}
default:
// Keep the default GeneralSetting value
}
return instanceSetting
}
func convertInstanceGeneralSettingFromStore(setting *storepb.InstanceGeneralSetting) *v1pb.InstanceSetting_GeneralSetting {
if setting == nil {
return nil
}
generalSetting := &v1pb.InstanceSetting_GeneralSetting{
DisallowUserRegistration: setting.DisallowUserRegistration,
DisallowPasswordAuth: setting.DisallowPasswordAuth,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
WeekStartDayOffset: setting.WeekStartDayOffset,
DisallowChangeUsername: setting.DisallowChangeUsername,
DisallowChangeNickname: setting.DisallowChangeNickname,
}
if setting.CustomProfile != nil {
generalSetting.CustomProfile = &v1pb.InstanceSetting_GeneralSetting_CustomProfile{
Title: setting.CustomProfile.Title,
Description: setting.CustomProfile.Description,
LogoUrl: setting.CustomProfile.LogoUrl,
}
}
return generalSetting
}
func convertInstanceGeneralSettingToStore(setting *v1pb.InstanceSetting_GeneralSetting) *storepb.InstanceGeneralSetting {
if setting == nil {
return nil
}
generalSetting := &storepb.InstanceGeneralSetting{
DisallowUserRegistration: setting.DisallowUserRegistration,
DisallowPasswordAuth: setting.DisallowPasswordAuth,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
WeekStartDayOffset: setting.WeekStartDayOffset,
DisallowChangeUsername: setting.DisallowChangeUsername,
DisallowChangeNickname: setting.DisallowChangeNickname,
}
if setting.CustomProfile != nil {
generalSetting.CustomProfile = &storepb.InstanceCustomProfile{
Title: setting.CustomProfile.Title,
Description: setting.CustomProfile.Description,
LogoUrl: setting.CustomProfile.LogoUrl,
}
}
return generalSetting
}
func convertInstanceStorageSettingFromStore(settingpb *storepb.InstanceStorageSetting) *v1pb.InstanceSetting_StorageSetting {
if settingpb == nil {
return nil
}
setting := &v1pb.InstanceSetting_StorageSetting{
StorageType: v1pb.InstanceSetting_StorageSetting_StorageType(settingpb.StorageType),
FilepathTemplate: settingpb.FilepathTemplate,
UploadSizeLimitMb: settingpb.UploadSizeLimitMb,
}
if settingpb.S3Config != nil {
setting.S3Config = &v1pb.InstanceSetting_StorageSetting_S3Config{
AccessKeyId: settingpb.S3Config.AccessKeyId,
AccessKeySecret: settingpb.S3Config.AccessKeySecret,
Endpoint: settingpb.S3Config.Endpoint,
Region: settingpb.S3Config.Region,
Bucket: settingpb.S3Config.Bucket,
UsePathStyle: settingpb.S3Config.UsePathStyle,
}
}
return setting
}
func convertInstanceStorageSettingToStore(setting *v1pb.InstanceSetting_StorageSetting) *storepb.InstanceStorageSetting {
if setting == nil {
return nil
}
settingpb := &storepb.InstanceStorageSetting{
StorageType: storepb.InstanceStorageSetting_StorageType(setting.StorageType),
FilepathTemplate: setting.FilepathTemplate,
UploadSizeLimitMb: setting.UploadSizeLimitMb,
}
if setting.S3Config != nil {
settingpb.S3Config = &storepb.StorageS3Config{
AccessKeyId: setting.S3Config.AccessKeyId,
AccessKeySecret: setting.S3Config.AccessKeySecret,
Endpoint: setting.S3Config.Endpoint,
Region: setting.S3Config.Region,
Bucket: setting.S3Config.Bucket,
UsePathStyle: setting.S3Config.UsePathStyle,
}
}
return settingpb
}
func convertInstanceMemoRelatedSettingFromStore(setting *storepb.InstanceMemoRelatedSetting) *v1pb.InstanceSetting_MemoRelatedSetting {
if setting == nil {
return nil
}
return &v1pb.InstanceSetting_MemoRelatedSetting{
DisallowPublicVisibility: setting.DisallowPublicVisibility,
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
ContentLengthLimit: setting.ContentLengthLimit,
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
Reactions: setting.Reactions,
}
}
func convertInstanceMemoRelatedSettingToStore(setting *v1pb.InstanceSetting_MemoRelatedSetting) *storepb.InstanceMemoRelatedSetting {
if setting == nil {
return nil
}
return &storepb.InstanceMemoRelatedSetting{
DisallowPublicVisibility: setting.DisallowPublicVisibility,
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
ContentLengthLimit: setting.ContentLengthLimit,
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
Reactions: setting.Reactions,
}
}
func (s *APIV1Service) GetInstanceAdmin(ctx context.Context) (*v1pb.User, error) {
adminUserType := store.RoleAdmin
user, err := s.Store.GetUser(ctx, &store.FindUser{
Role: &adminUserType,
})
if err != nil {
return nil, errors.Wrapf(err, "failed to find admin")
}
if user == nil {
return nil, nil
}
return convertUserFromStore(user), nil
}

View File

@@ -0,0 +1,136 @@
package v1
import (
"context"
"slices"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
// Delete attachments that are not in the request.
for _, attachment := range attachments {
found := false
for _, requestAttachment := range request.Attachments {
requestAttachmentUID, err := ExtractAttachmentUIDFromName(requestAttachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
}
if attachment.UID == requestAttachmentUID {
found = true
break
}
}
if !found {
if err = s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
ID: int32(attachment.ID),
MemoID: &memo.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
}
}
}
slices.Reverse(request.Attachments)
// Update attachments' memo_id in the request.
for index, attachment := range request.Attachments {
attachmentUID, err := ExtractAttachmentUIDFromName(attachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
}
tempAttachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if tempAttachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID)
}
updatedTs := time.Now().Unix() + int64(index)
if err := s.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
ID: tempAttachment.ID,
MemoID: &memo.ID,
UpdatedTs: &updatedTs,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.ListMemoAttachmentsRequest) (*v1pb.ListMemoAttachmentsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
}
response := &v1pb.ListMemoAttachmentsResponse{
Attachments: []*v1pb.Attachment{},
}
for _, attachment := range attachments {
response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment))
}
return response, nil
}

View File

@@ -0,0 +1,181 @@
package v1
import (
"context"
"fmt"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
referenceType := store.MemoRelationReference
// Delete all reference relations first.
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &memo.ID,
Type: &referenceType,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
}
for _, relation := range request.Relations {
// Ignore reflexive relations.
if request.Name == relation.RelatedMemo.Name {
continue
}
// Ignore comment relations as there's no need to update a comment's relation.
// Inserting/Deleting a comment is handled elsewhere.
if relation.Type == v1pb.MemoRelation_COMMENT {
continue
}
relatedMemoUID, err := ExtractMemoUIDFromName(relation.RelatedMemo.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &relatedMemoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get related memo")
}
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: relatedMemo.ID,
Type: convertMemoRelationTypeToStore(relation.Type),
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
var memoFilter string
if currentUser == nil {
memoFilter = `visibility == "PUBLIC"`
} else {
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
}
relationList := []*v1pb.MemoRelation{}
tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo.ID,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations: %v", err)
}
for _, raw := range tempList {
relation, err := s.convertMemoRelationFromStore(ctx, raw)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
}
relationList = append(relationList, relation)
}
tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &memo.ID,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list related memo relations: %v", err)
}
for _, raw := range tempList {
relation, err := s.convertMemoRelationFromStore(ctx, raw)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
}
relationList = append(relationList, relation)
}
response := &v1pb.ListMemoRelationsResponse{
Relations: relationList,
}
return response, nil
}
func (s *APIV1Service) convertMemoRelationFromStore(ctx context.Context, memoRelation *store.MemoRelation) (*v1pb.MemoRelation, error) {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.MemoID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
memoSnippet, err := s.getMemoContentSnippet(memo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get memo content snippet")
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.RelatedMemoID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get related memo: %v", err)
}
relatedMemoSnippet, err := s.getMemoContentSnippet(relatedMemo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get related memo content snippet")
}
return &v1pb.MemoRelation{
Memo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Snippet: memoSnippet,
},
RelatedMemo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
Snippet: relatedMemoSnippet,
},
Type: convertMemoRelationTypeFromStore(memoRelation.Type),
}, nil
}
func convertMemoRelationTypeFromStore(relationType store.MemoRelationType) v1pb.MemoRelation_Type {
switch relationType {
case store.MemoRelationReference:
return v1pb.MemoRelation_REFERENCE
case store.MemoRelationComment:
return v1pb.MemoRelation_COMMENT
default:
return v1pb.MemoRelation_TYPE_UNSPECIFIED
}
}
func convertMemoRelationTypeToStore(relationType v1pb.MemoRelation_Type) store.MemoRelationType {
switch relationType {
case v1pb.MemoRelation_COMMENT:
return store.MemoRelationComment
default:
return store.MemoRelationReference
}
}

View File

@@ -0,0 +1,866 @@
package v1
import (
"context"
"fmt"
"log/slog"
"strings"
"time"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/base"
"github.com/usememos/memos/plugin/webhook"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/runner/memopayload"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Use custom memo_id if provided, otherwise generate a new UUID
memoUID := strings.TrimSpace(request.MemoId)
if memoUID == "" {
memoUID = shortuuid.New()
} else if !base.UIDMatcher.MatchString(memoUID) {
// Validate custom memo ID format
return nil, status.Errorf(codes.InvalidArgument, "invalid memo_id format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
}
create := &store.Memo{
UID: memoUID,
CreatorID: user.ID,
Content: request.Memo.Content,
Visibility: convertVisibilityToStore(request.Memo.Visibility),
}
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
// Handle display_time first: if provided, use it to set the appropriate timestamp
// based on the instance setting (similar to UpdateMemo logic)
// Note: explicit create_time/update_time below will override this if provided
if request.Memo.DisplayTime != nil && request.Memo.DisplayTime.IsValid() {
displayTs := request.Memo.DisplayTime.AsTime().Unix()
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
create.UpdatedTs = displayTs
} else {
create.CreatedTs = displayTs
}
}
// Set custom timestamps if provided in the request
// These take precedence over display_time
if request.Memo.CreateTime != nil && request.Memo.CreateTime.IsValid() {
createdTs := request.Memo.CreateTime.AsTime().Unix()
create.CreatedTs = createdTs
}
if request.Memo.UpdateTime != nil && request.Memo.UpdateTime.IsValid() {
updatedTs := request.Memo.UpdateTime.AsTime().Unix()
create.UpdatedTs = updatedTs
}
if instanceMemoRelatedSetting.DisallowPublicVisibility && create.Visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
}
if len(create.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
if err := memopayload.RebuildMemoPayload(create, s.MarkdownService); err != nil {
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
}
if request.Memo.Location != nil {
create.Payload.Location = convertLocationToStore(request.Memo.Location)
}
memo, err := s.Store.CreateMemo(ctx, create)
if err != nil {
// Check for unique constraint violation (AIP-133 compliance)
errMsg := err.Error()
if strings.Contains(errMsg, "UNIQUE constraint failed") ||
strings.Contains(errMsg, "duplicate key") ||
strings.Contains(errMsg, "Duplicate entry") {
return nil, status.Errorf(codes.AlreadyExists, "memo with ID %q already exists", memoUID)
}
return nil, err
}
attachments := []*store.Attachment{}
if len(request.Memo.Attachments) > 0 {
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Attachments: request.Memo.Attachments,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo attachments")
}
a, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo attachments")
}
attachments = a
}
if len(request.Memo.Relations) > 0 {
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Relations: request.Memo.Relations,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo relations")
}
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is created.
if err := s.DispatchMemoCreatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
}
return memoMessage, nil
}
func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosRequest) (*v1pb.ListMemosResponse, error) {
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
}
if request.State == v1pb.State_ARCHIVED {
state := store.Archived
memoFind.RowStatus = &state
} else {
state := store.Normal
memoFind.RowStatus = &state
}
// Parse order_by field (replaces the old sort and direction fields)
if request.OrderBy != "" {
if err := s.parseMemoOrderBy(request.OrderBy, memoFind); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid order_by: %v", err)
}
} else {
// Default ordering by display_time desc
memoFind.OrderByTimeAsc = false
}
if request.Filter != "" {
if err := s.validateFilter(ctx, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
memoFind.Filters = append(memoFind.Filters, request.Filter)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else {
if memoFind.CreatorID == nil {
filter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
memoFind.Filters = append(memoFind.Filters, filter)
} else if *memoFind.CreatorID != currentUser.ID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
}
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
memoFind.OrderByUpdatedTs = true
}
var limit, offset int
if request.PageToken != "" {
var pageToken v1pb.PageToken
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = int(pageToken.Limit)
offset = int(pageToken.Offset)
} else {
limit = int(request.PageSize)
}
if limit <= 0 {
limit = DefaultPageSize
}
limitPlusOne := limit + 1
memoFind.Limit = &limitPlusOne
memoFind.Offset = &offset
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
memoMessages := []*v1pb.Memo{}
nextPageToken := ""
if len(memos) == limitPlusOne {
memos = memos[:limit]
nextPageToken, err = getPageToken(limit, offset+limit)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
}
}
if len(memos) == 0 {
response := &v1pb.ListMemosResponse{
Memos: memoMessages,
NextPageToken: nextPageToken,
}
return response, nil
}
reactionMap := make(map[string][]*store.Reaction)
contentIDs := make([]string, 0, len(memos))
attachmentMap := make(map[int32][]*store.Attachment)
memoIDs := make([]int32, 0, len(memos))
for _, m := range memos {
contentIDs = append(contentIDs, fmt.Sprintf("%s%s", MemoNamePrefix, m.UID))
memoIDs = append(memoIDs, m.ID)
}
// REACTIONS
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
for _, reaction := range reactions {
reactionMap[reaction.ContentID] = append(reactionMap[reaction.ContentID], reaction)
}
// ATTACHMENTS
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
for _, attachment := range attachments {
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
}
for _, memo := range memos {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
reactions := reactionMap[memoName]
attachments := attachmentMap[memo.ID]
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memoMessages = append(memoMessages, memoMessage)
}
response := &v1pb.ListMemosResponse{
Memos: memoMessages,
NextPageToken: nextPageToken,
}
return response, nil
}
func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
UID: &memoUID,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
return memoMessage, nil
}
func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Memo.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Only the creator or admin can update the memo.
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
update := &store.UpdateMemo{
ID: memo.ID,
}
for _, path := range request.UpdateMask.Paths {
if path == "content" {
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
}
if len(request.Memo.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
memo.Content = request.Memo.Content
if err := memopayload.RebuildMemoPayload(memo, s.MarkdownService); err != nil {
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
}
update.Content = &memo.Content
update.Payload = memo.Payload
} else if path == "visibility" {
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
visibility := convertVisibilityToStore(request.Memo.Visibility)
if instanceMemoRelatedSetting.DisallowPublicVisibility && visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
update.Visibility = &visibility
} else if path == "pinned" {
update.Pinned = &request.Memo.Pinned
} else if path == "state" {
rowStatus := convertStateToStore(request.Memo.State)
update.RowStatus = &rowStatus
} else if path == "create_time" {
createdTs := request.Memo.CreateTime.AsTime().Unix()
update.CreatedTs = &createdTs
} else if path == "update_time" {
updatedTs := time.Now().Unix()
if request.Memo.UpdateTime != nil {
updatedTs = request.Memo.UpdateTime.AsTime().Unix()
}
update.UpdatedTs = &updatedTs
} else if path == "display_time" {
displayTs := request.Memo.DisplayTime.AsTime().Unix()
memoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
if memoRelatedSetting.DisplayWithUpdateTime {
update.UpdatedTs = &displayTs
} else {
update.CreatedTs = &displayTs
}
} else if path == "location" {
payload := memo.Payload
payload.Location = convertLocationToStore(request.Memo.Location)
update.Payload = payload
} else if path == "attachments" {
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: request.Memo.Name,
Attachments: request.Memo.Attachments,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo attachments")
}
} else if path == "relations" {
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: request.Memo.Name,
Relations: request.Memo.Relations,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo relations")
}
}
}
if err = s.Store.UpdateMemo(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
}
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
ID: &memo.ID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Memo.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is updated.
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
}
return memoMessage, nil
}
func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoRequest) (*emptypb.Empty, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
UID: &memoUID,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Only the creator or admin can update the memo.
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments); err == nil {
// Try to dispatch webhook when memo is deleted.
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
}
}
// Delete memo comments first (store.DeleteMemo handles their relations and attachments)
commentType := store.MemoRelationComment
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{RelatedMemoID: &memo.ID, Type: &commentType})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo comments")
}
for _, relation := range relations {
if err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: relation.MemoID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo comment")
}
}
// Delete the memo (store.DeleteMemo handles relation and attachment cleanup)
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.CreateMemoCommentRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if relatedMemo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility before allowing comment.
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if relatedMemo.Visibility == store.Private && relatedMemo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Create the memo comment first.
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{
Memo: request.Comment,
MemoId: request.CommentId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo")
}
memoUID, err = ExtractMemoUIDFromName(memoComment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
// Build the relation between the comment memo and the original memo.
_, err = s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationComment,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
}
creatorID, err := ExtractUserIDFromName(memoComment.Creator)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo creator")
}
if memoComment.Visibility != v1pb.Visibility_PRIVATE && creatorID != relatedMemo.CreatorID {
activity, err := s.Store.CreateActivity(ctx, &store.Activity{
CreatorID: creatorID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memo.ID,
RelatedMemoId: relatedMemo.ID,
},
},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create activity")
}
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
SenderID: creatorID,
ReceiverID: relatedMemo.CreatorID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
ActivityId: &activity.ID,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to create inbox")
}
}
if err := s.DispatchMemoCommentCreatedWebhook(ctx, memoComment, relatedMemo.CreatorID); err != nil {
slog.Warn("Failed to dispatch memo comment created webhook", slog.Any("err", err))
}
return memoComment, nil
}
func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListMemoCommentsRequest) (*v1pb.ListMemoCommentsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
var memoFilter string
if currentUser == nil {
memoFilter = `visibility == "PUBLIC"`
} else {
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
}
memoRelationComment := store.MemoRelationComment
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &memo.ID,
Type: &memoRelationComment,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
}
if len(memoRelations) == 0 {
response := &v1pb.ListMemoCommentsResponse{
Memos: []*v1pb.Memo{},
}
return response, nil
}
memoRelationIDs := make([]int32, 0, len(memoRelations))
for _, m := range memoRelations {
memoRelationIDs = append(memoRelationIDs, m.MemoID)
}
memos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: memoRelationIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
memoIDToNameMap := make(map[int32]string)
contentIDs := make([]string, 0, len(memos))
memoIDsForAttachments := make([]int32, 0, len(memos))
for _, memo := range memos {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
memoIDToNameMap[memo.ID] = memoName
contentIDs = append(contentIDs, memoName)
memoIDsForAttachments = append(memoIDsForAttachments, memo.ID)
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
memoReactionsMap := make(map[string][]*store.Reaction)
for _, reaction := range reactions {
memoReactionsMap[reaction.ContentID] = append(memoReactionsMap[reaction.ContentID], reaction)
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDsForAttachments})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
attachmentMap := make(map[int32][]*store.Attachment)
for _, attachment := range attachments {
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
}
var memosResponse []*v1pb.Memo
for _, m := range memos {
memoName := memoIDToNameMap[m.ID]
reactions := memoReactionsMap[memoName]
attachments := attachmentMap[m.ID]
memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memosResponse = append(memosResponse, memoMessage)
}
response := &v1pb.ListMemoCommentsResponse{
Memos: memosResponse,
}
return response, nil
}
func (s *APIV1Service) getContentLengthLimit(ctx context.Context) (int, error) {
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return 0, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
return int(instanceMemoRelatedSetting.ContentLengthLimit), nil
}
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
func (s *APIV1Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")
}
// DispatchMemoUpdatedWebhook dispatches webhook when memo is updated.
func (s *APIV1Service) DispatchMemoUpdatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.updated")
}
// DispatchMemoDeletedWebhook dispatches webhook when memo is deleted.
func (s *APIV1Service) DispatchMemoDeletedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.deleted")
}
// DispatchMemoCommentCreatedWebhook dispatches webhook to the related memo owner when a comment is created.
func (s *APIV1Service) DispatchMemoCommentCreatedWebhook(ctx context.Context, commentMemo *v1pb.Memo, relatedMemoCreatorID int32) error {
webhooks, err := s.Store.GetUserWebhooks(ctx, relatedMemoCreatorID)
if err != nil {
return err
}
for _, hook := range webhooks {
payload, err := convertMemoToWebhookPayload(commentMemo)
if err != nil {
return errors.Wrap(err, "failed to convert memo to webhook payload")
}
payload.ActivityType = "memos.memo.comment.created"
payload.URL = hook.Url
webhook.PostAsync(payload)
}
return nil
}
func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1pb.Memo, activityType string) error {
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid memo creator")
}
webhooks, err := s.Store.GetUserWebhooks(ctx, creatorID)
if err != nil {
return err
}
for _, hook := range webhooks {
payload, err := convertMemoToWebhookPayload(memo)
if err != nil {
return errors.Wrap(err, "failed to convert memo to webhook payload")
}
payload.ActivityType = activityType
payload.URL = hook.Url
// Use asynchronous webhook dispatch
webhook.PostAsync(payload)
}
return nil
}
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookRequestPayload, error) {
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return nil, errors.Wrap(err, "invalid memo creator")
}
return &webhook.WebhookRequestPayload{
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creatorID),
Memo: memo,
}, nil
}
func (s *APIV1Service) getMemoContentSnippet(content string) (string, error) {
// Use goldmark service for snippet generation
snippet, err := s.MarkdownService.GenerateSnippet([]byte(content), 64)
if err != nil {
return "", errors.Wrap(err, "failed to generate snippet")
}
return snippet, nil
}
// parseMemoOrderBy parses the order_by field and sets the appropriate ordering in memoFind.
// Follows AIP-132: supports comma-separated list of fields with optional "desc" suffix.
// Example: "pinned desc, display_time desc" or "create_time asc".
func (*APIV1Service) parseMemoOrderBy(orderBy string, memoFind *store.FindMemo) error {
if strings.TrimSpace(orderBy) == "" {
return errors.New("empty order_by")
}
// Split by comma to support multiple sort fields per AIP-132.
fields := strings.Split(orderBy, ",")
// Track if we've seen pinned field.
hasPinned := false
for _, field := range fields {
parts := strings.Fields(strings.TrimSpace(field))
if len(parts) == 0 {
continue
}
fieldName := parts[0]
fieldDirection := "desc" // default per AIP-132 (we use desc as default for time fields)
if len(parts) > 1 {
fieldDirection = strings.ToLower(parts[1])
if fieldDirection != "asc" && fieldDirection != "desc" {
return errors.Errorf("invalid order direction: %s, must be 'asc' or 'desc'", parts[1])
}
}
switch fieldName {
case "pinned":
hasPinned = true
memoFind.OrderByPinned = true
// Note: pinned is always DESC (true first) regardless of direction specified.
case "display_time", "create_time", "name":
// Only set if this is the first time field we encounter.
if !memoFind.OrderByUpdatedTs {
memoFind.OrderByTimeAsc = fieldDirection == "asc"
}
case "update_time":
memoFind.OrderByUpdatedTs = true
memoFind.OrderByTimeAsc = fieldDirection == "asc"
default:
return errors.Errorf("unsupported order field: %s, supported fields are: pinned, display_time, create_time, update_time, name", fieldName)
}
}
// If only pinned was specified, still need to set a default time ordering.
if hasPinned && !memoFind.OrderByUpdatedTs && len(fields) == 1 {
memoFind.OrderByTimeAsc = false // default to desc
}
return nil
}

View File

@@ -0,0 +1,134 @@
package v1
import (
"context"
"fmt"
"time"
"github.com/pkg/errors"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get instance memo related setting")
}
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
name := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
memoMessage := &v1pb.Memo{
Name: name,
State: convertStateFromStore(memo.RowStatus),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, memo.CreatorID),
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
Content: memo.Content,
Visibility: convertVisibilityFromStore(memo.Visibility),
Pinned: memo.Pinned,
}
if memo.Payload != nil {
memoMessage.Tags = memo.Payload.Tags
memoMessage.Property = convertMemoPropertyFromStore(memo.Payload.Property)
memoMessage.Location = convertLocationFromStore(memo.Payload.Location)
}
if memo.ParentUID != nil {
parentName := fmt.Sprintf("%s%s", MemoNamePrefix, *memo.ParentUID)
memoMessage.Parent = &parentName
}
memoMessage.Reactions = []*v1pb.Reaction{}
for _, reaction := range reactions {
reactionResponse := convertReactionFromStore(reaction)
memoMessage.Reactions = append(memoMessage.Reactions, reactionResponse)
}
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo relations")
}
memoMessage.Relations = listMemoRelationsResponse.Relations
memoMessage.Attachments = []*v1pb.Attachment{}
for _, attachment := range attachments {
attachmentResponse := convertAttachmentFromStore(attachment)
memoMessage.Attachments = append(memoMessage.Attachments, attachmentResponse)
}
snippet, err := s.getMemoContentSnippet(memo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get memo content snippet")
}
memoMessage.Snippet = snippet
return memoMessage, nil
}
func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property {
if property == nil {
return nil
}
return &v1pb.Memo_Property{
HasLink: property.HasLink,
HasTaskList: property.HasTaskList,
HasCode: property.HasCode,
HasIncompleteTasks: property.HasIncompleteTasks,
}
}
func convertLocationFromStore(location *storepb.MemoPayload_Location) *v1pb.Location {
if location == nil {
return nil
}
return &v1pb.Location{
Placeholder: location.Placeholder,
Latitude: location.Latitude,
Longitude: location.Longitude,
}
}
func convertLocationToStore(location *v1pb.Location) *storepb.MemoPayload_Location {
if location == nil {
return nil
}
return &storepb.MemoPayload_Location{
Placeholder: location.Placeholder,
Latitude: location.Latitude,
Longitude: location.Longitude,
}
}
func convertVisibilityFromStore(visibility store.Visibility) v1pb.Visibility {
switch visibility {
case store.Private:
return v1pb.Visibility_PRIVATE
case store.Protected:
return v1pb.Visibility_PROTECTED
case store.Public:
return v1pb.Visibility_PUBLIC
default:
return v1pb.Visibility_VISIBILITY_UNSPECIFIED
}
}
func convertVisibilityToStore(visibility v1pb.Visibility) store.Visibility {
switch visibility {
case v1pb.Visibility_PROTECTED:
return store.Protected
case v1pb.Visibility_PUBLIC:
return store.Public
default:
return store.Private
}
}

View File

@@ -0,0 +1 @@
package v1

View File

@@ -0,0 +1,153 @@
package v1
import (
"context"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
// Extract memo UID and check visibility.
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
response := &v1pb.ListMemoReactionsResponse{
Reactions: []*v1pb.Reaction{},
}
for _, reaction := range reactions {
reactionMessage := convertReactionFromStore(reaction)
response.Reactions = append(response.Reactions, reactionMessage)
}
return response, nil
}
func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.UpsertMemoReactionRequest) (*v1pb.Reaction, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Extract memo UID and check visibility before allowing reaction.
memoUID, err := ExtractMemoUIDFromName(request.Reaction.ContentId)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: request.Reaction.ContentId,
ReactionType: request.Reaction.ReactionType,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
}
reactionMessage := convertReactionFromStore(reaction)
return reactionMessage, nil
}
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
_, reactionID, err := ExtractMemoReactionIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
}
// Get reaction and check ownership.
reaction, err := s.Store.GetReaction(ctx, &store.FindReaction{
ID: &reactionID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get reaction")
}
if reaction == nil {
// Return permission denied to avoid revealing if reaction exists.
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if reaction.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
ID: reactionID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
}
return &emptypb.Empty{}, nil
}
func convertReactionFromStore(reaction *store.Reaction) *v1pb.Reaction {
reactionUID := fmt.Sprintf("%d", reaction.ID)
// Generate nested resource name: memos/{memo}/reactions/{reaction}
// reaction.ContentID already contains "memos/{memo}"
return &v1pb.Reaction{
Name: fmt.Sprintf("%s/%s%s", reaction.ContentID, ReactionNamePrefix, reactionUID),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, reaction.CreatorID),
ContentId: reaction.ContentID,
ReactionType: reaction.ReactionType,
CreateTime: timestamppb.New(time.Unix(reaction.CreatedTs, 0)),
}
}

View File

@@ -0,0 +1,158 @@
package v1
import (
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
const (
InstanceSettingNamePrefix = "instance/settings/"
UserNamePrefix = "users/"
MemoNamePrefix = "memos/"
AttachmentNamePrefix = "attachments/"
ReactionNamePrefix = "reactions/"
InboxNamePrefix = "inboxes/"
IdentityProviderNamePrefix = "identity-providers/"
ActivityNamePrefix = "activities/"
WebhookNamePrefix = "webhooks/"
)
// GetNameParentTokens returns the tokens from a resource name.
func GetNameParentTokens(name string, tokenPrefixes ...string) ([]string, error) {
parts := strings.Split(name, "/")
if len(parts) != 2*len(tokenPrefixes) {
return nil, errors.Errorf("invalid request %q", name)
}
var tokens []string
for i, tokenPrefix := range tokenPrefixes {
if fmt.Sprintf("%s/", parts[2*i]) != tokenPrefix {
return nil, errors.Errorf("invalid prefix %q in request %q", tokenPrefix, name)
}
if parts[2*i+1] == "" {
return nil, errors.Errorf("invalid request %q with empty prefix %q", name, tokenPrefix)
}
tokens = append(tokens, parts[2*i+1])
}
return tokens, nil
}
func ExtractInstanceSettingKeyFromName(name string) (string, error) {
const prefix = "instance/settings/"
if !strings.HasPrefix(name, prefix) {
return "", errors.Errorf("invalid instance setting name: expected prefix %q, got %q", prefix, name)
}
settingKey := strings.TrimPrefix(name, prefix)
if settingKey == "" {
return "", errors.Errorf("invalid instance setting name: empty setting key in %q", name)
}
// Ensure there are no additional path segments
if strings.Contains(settingKey, "/") {
return "", errors.Errorf("invalid instance setting name: setting key cannot contain '/' in %q", name)
}
return settingKey, nil
}
// ExtractUserIDFromName returns the uid from a resource name.
func ExtractUserIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, UserNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid user ID %q", tokens[0])
}
return id, nil
}
// extractUserIdentifierFromName extracts the identifier (ID or username) from a user resource name.
// Supports: "users/101" or "users/steven"
// Returns the identifier string (e.g., "101" or "steven").
func extractUserIdentifierFromName(name string) string {
tokens, err := GetNameParentTokens(name, UserNamePrefix)
if err != nil || len(tokens) == 0 {
return ""
}
return tokens[0]
}
// ExtractMemoUIDFromName returns the memo UID from a resource name.
// e.g., "memos/uuid" -> "uuid".
func ExtractMemoUIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, MemoNamePrefix)
if err != nil {
return "", err
}
id := tokens[0]
return id, nil
}
// ExtractAttachmentUIDFromName returns the attachment UID from a resource name.
func ExtractAttachmentUIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, AttachmentNamePrefix)
if err != nil {
return "", err
}
id := tokens[0]
return id, nil
}
// ExtractMemoReactionIDFromName returns the memo UID and reaction ID from a resource name.
// e.g., "memos/abc/reactions/123" -> ("abc", 123).
func ExtractMemoReactionIDFromName(name string) (string, int32, error) {
tokens, err := GetNameParentTokens(name, MemoNamePrefix, ReactionNamePrefix)
if err != nil {
return "", 0, err
}
memoUID := tokens[0]
reactionID, err := util.ConvertStringToInt32(tokens[1])
if err != nil {
return "", 0, errors.Errorf("invalid reaction ID %q", tokens[1])
}
return memoUID, reactionID, nil
}
// ExtractInboxIDFromName returns the inbox ID from a resource name.
func ExtractInboxIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, InboxNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid inbox ID %q", tokens[0])
}
return id, nil
}
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
}
return id, nil
}
func ExtractActivityIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, ActivityNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid activity ID %q", tokens[0])
}
return id, nil
}

View File

@@ -0,0 +1,346 @@
package v1
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/filter"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// Helper function to extract user ID and shortcut ID from shortcut resource name.
// Format: users/{user}/shortcuts/{shortcut}.
func extractUserAndShortcutIDFromName(name string) (int32, string, error) {
parts := strings.Split(name, "/")
if len(parts) != 4 || parts[0] != "users" || parts[2] != "shortcuts" {
return 0, "", errors.Errorf("invalid shortcut name format: %s", name)
}
userID, err := util.ConvertStringToInt32(parts[1])
if err != nil {
return 0, "", errors.Errorf("invalid user ID %q", parts[1])
}
shortcutID := parts[3]
if shortcutID == "" {
return 0, "", errors.Errorf("empty shortcut ID in name: %s", name)
}
return userID, shortcutID, nil
}
// Helper function to construct shortcut resource name.
func constructShortcutName(userID int32, shortcutID string) string {
return fmt.Sprintf("users/%d/shortcuts/%s", userID, shortcutID)
}
func (s *APIV1Service) ListShortcuts(ctx context.Context, request *v1pb.ListShortcutsRequest) (*v1pb.ListShortcutsResponse, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user setting: %v", err)
}
if userSetting == nil {
return &v1pb.ListShortcutsResponse{
Shortcuts: []*v1pb.Shortcut{},
}, nil
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := []*v1pb.Shortcut{}
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
shortcuts = append(shortcuts, &v1pb.Shortcut{
Name: constructShortcutName(userID, shortcut.GetId()),
Title: shortcut.GetTitle(),
Filter: shortcut.GetFilter(),
})
}
return &v1pb.ListShortcutsResponse{
Shortcuts: shortcuts,
}, nil
}
func (s *APIV1Service) GetShortcut(ctx context.Context, request *v1pb.GetShortcutRequest) (*v1pb.Shortcut, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
if shortcut.GetId() == shortcutID {
return &v1pb.Shortcut{
Name: constructShortcutName(userID, shortcut.GetId()),
Title: shortcut.GetTitle(),
Filter: shortcut.GetFilter(),
}, nil
}
}
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateShortcutRequest) (*v1pb.Shortcut, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
newShortcut := &storepb.ShortcutsUserSetting_Shortcut{
Id: util.GenUUID(),
Title: request.Shortcut.GetTitle(),
Filter: request.Shortcut.GetFilter(),
}
if newShortcut.Title == "" {
return nil, status.Errorf(codes.InvalidArgument, "title is required")
}
if err := s.validateFilter(ctx, newShortcut.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
if request.ValidateOnly {
return &v1pb.Shortcut{
Name: constructShortcutName(userID, newShortcut.GetId()),
Title: newShortcut.GetTitle(),
Filter: newShortcut.GetFilter(),
}, nil
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
userSetting = &storepb.UserSetting{
UserId: userID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{
Shortcuts: &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
},
},
}
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
shortcuts = append(shortcuts, newShortcut)
shortcutsUserSetting.Shortcuts = shortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &v1pb.Shortcut{
Name: constructShortcutName(userID, newShortcut.GetId()),
Title: newShortcut.GetTitle(),
Filter: newShortcut.GetFilter(),
}, nil
}
func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateShortcutRequest) (*v1pb.Shortcut, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Shortcut.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
var foundShortcut *storepb.ShortcutsUserSetting_Shortcut
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
for _, shortcut := range shortcuts {
if shortcut.GetId() == shortcutID {
foundShortcut = shortcut
for _, field := range request.UpdateMask.Paths {
if field == "title" {
if request.Shortcut.GetTitle() == "" {
return nil, status.Errorf(codes.InvalidArgument, "title is required")
}
shortcut.Title = request.Shortcut.GetTitle()
} else if field == "filter" {
if err := s.validateFilter(ctx, request.Shortcut.GetFilter()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
shortcut.Filter = request.Shortcut.GetFilter()
}
}
}
newShortcuts = append(newShortcuts, shortcut)
}
if foundShortcut == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting.Shortcuts = newShortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, err
}
return &v1pb.Shortcut{
Name: constructShortcutName(userID, foundShortcut.GetId()),
Title: foundShortcut.GetTitle(),
Filter: foundShortcut.GetFilter(),
}, nil
}
func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteShortcutRequest) (*emptypb.Empty, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
found := false
for _, shortcut := range shortcuts {
if shortcut.GetId() != shortcutID {
newShortcuts = append(newShortcuts, shortcut)
} else {
found = true
}
}
if !found {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting.Shortcuts = newShortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) validateFilter(ctx context.Context, filterStr string) error {
if filterStr == "" {
return errors.New("filter cannot be empty")
}
engine, err := filter.DefaultEngine()
if err != nil {
return err
}
var dialect filter.DialectName
switch s.Profile.Driver {
case "mysql":
dialect = filter.DialectMySQL
case "postgres":
dialect = filter.DialectPostgres
default:
dialect = filter.DialectSQLite
}
if _, err := engine.CompileToStatement(ctx, filterStr, filter.RenderOptions{Dialect: dialect}); err != nil {
return errors.Wrap(err, "failed to compile filter")
}
return nil
}

View File

@@ -0,0 +1,263 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
v1 "github.com/usememos/memos/server/router/api/v1" //nolint:revive
"github.com/usememos/memos/store"
)
// TestListActivitiesWithDeletedMemos verifies that ListActivities gracefully handles
// activities that reference deleted memos instead of crashing the entire request.
func TestListActivitiesWithDeletedMemos(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users - one to create memo, one to comment
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
require.NoError(t, err)
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
require.NoError(t, err)
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
// Create a memo by userOne
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Original memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo1)
// Create a comment on the memo by userTwo (this will create an activity for userOne)
comment, err := ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
Name: memo1.Name,
Comment: &apiv1.Memo{
Content: "This is a comment",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, comment)
// Verify activity was created for the comment (check from userOne's perspective - they receive the notification)
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
initialActivityCount := len(activities.Activities)
require.Greater(t, initialActivityCount, 0, "Should have at least one activity")
// Delete the original memo (this deletes the comment too)
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
Name: memo1.Name,
})
require.NoError(t, err)
// List activities again - should succeed even though the memo is deleted
activities, err = ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
// Activities list should be empty or not contain the deleted memo activity
for _, activity := range activities.Activities {
if activity.Payload != nil && activity.Payload.GetMemoComment() != nil {
require.NotEqual(t, memo1.Name, activity.Payload.GetMemoComment().Memo,
"Activity should not reference deleted memo")
}
}
// After deletion, there should be fewer activities
require.LessOrEqual(t, len(activities.Activities), initialActivityCount-1,
"Should have filtered out the activity for the deleted memo")
}
// TestGetActivityWithDeletedMemo verifies that GetActivity returns a proper error
// when trying to fetch an activity that references a deleted memo.
func TestGetActivityWithDeletedMemo(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
require.NoError(t, err)
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
require.NoError(t, err)
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
// Create a memo by userOne
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Original memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo1)
// Create a comment to trigger activity creation by userTwo
comment, err := ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
Name: memo1.Name,
Comment: &apiv1.Memo{
Content: "Comment",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, comment)
// Get the activity ID by listing activities from userOne's perspective
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
require.Greater(t, len(activities.Activities), 0)
activityName := activities.Activities[0].Name
// Delete the memo
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
Name: memo1.Name,
})
require.NoError(t, err)
// Try to get the specific activity - should return NotFound error
_, err = ts.Service.GetActivity(userOneCtx, &apiv1.GetActivityRequest{
Name: activityName,
})
require.Error(t, err)
require.Contains(t, err.Error(), "activity references deleted content")
}
// TestActivitiesWithPartiallyDeletedMemos verifies that when some memos are deleted,
// other valid activities are still returned.
func TestActivitiesWithPartiallyDeletedMemos(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
require.NoError(t, err)
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
require.NoError(t, err)
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
// Create two memos by userOne
memo1, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "First memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
memo2, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Second memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Create comments on both by userTwo (creates activities for userOne)
_, err = ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
Name: memo1.Name,
Comment: &apiv1.Memo{
Content: "Comment on first",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
_, err = ts.Service.CreateMemoComment(userTwoCtx, &apiv1.CreateMemoCommentRequest{
Name: memo2.Name,
Comment: &apiv1.Memo{
Content: "Comment on second",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Should have 2 activities from userOne's perspective
activities, err := ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
require.Equal(t, 2, len(activities.Activities))
// Delete first memo
_, err = ts.Service.DeleteMemo(userOneCtx, &apiv1.DeleteMemoRequest{
Name: memo1.Name,
})
require.NoError(t, err)
// List activities - should still work and return only the second memo's activity
activities, err = ts.Service.ListActivities(userOneCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
require.Equal(t, 1, len(activities.Activities), "Should have 1 activity remaining")
// Verify the remaining activity relates to a valid memo
require.NotNil(t, activities.Activities[0].Payload.GetMemoComment())
require.Contains(t, activities.Activities[0].Payload.GetMemoComment().RelatedMemo, "memos/")
}
// TestActivityStoreDirectDeletion tests the scenario where a memo is deleted directly
// from the store (simulating database-level deletion or migration).
func TestActivityStoreDirectDeletion(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "test-user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a memo
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Create a comment
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
Name: memo1.Name,
Comment: &apiv1.Memo{
Content: "Test comment",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Extract memo UID from the comment name
commentMemoUID, err := v1.ExtractMemoUIDFromName(comment.Name)
require.NoError(t, err)
commentMemo, err := ts.Store.GetMemo(ctx, &store.FindMemo{
UID: &commentMemoUID,
})
require.NoError(t, err)
require.NotNil(t, commentMemo)
// Delete the comment memo directly from store (simulating orphaned activity)
err = ts.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: commentMemo.ID})
require.NoError(t, err)
// List activities should still succeed even with orphaned activity
activities, err := ts.Service.ListActivities(userCtx, &apiv1.ListActivitiesRequest{})
require.NoError(t, err)
// Activities should be empty or not include the orphaned one
for _, activity := range activities.Activities {
if activity.Payload != nil && activity.Payload.GetMemoComment() != nil {
require.NotEqual(t, comment.Name, activity.Payload.GetMemoComment().Memo,
"Should not return activity with deleted memo")
}
}
}

View File

@@ -0,0 +1 @@
fake png content

View File

@@ -0,0 +1 @@
fake png content

View File

@@ -0,0 +1,2 @@
‰PNG


After

Width:  |  Height:  |  Size: 8 B

Binary file not shown.

View File

@@ -0,0 +1,59 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestCreateAttachment(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
ctx := context.Background()
user, err := ts.CreateRegularUser(ctx, "test_user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test case 1: Create attachment with empty type but known extension
t.Run("EmptyType_KnownExtension", func(t *testing.T) {
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "test.png",
Content: []byte("fake png content"),
},
})
require.NoError(t, err)
require.Equal(t, "image/png", attachment.Type)
})
// Test case 2: Create attachment with empty type and unknown extension, but detectable content
t.Run("EmptyType_UnknownExtension_ContentSniffing", func(t *testing.T) {
// PNG magic header: 89 50 4E 47 0D 0A 1A 0A
pngContent := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "test.unknown",
Content: pngContent,
},
})
require.NoError(t, err)
require.Equal(t, "image/png", attachment.Type)
})
// Test case 3: Empty type, unknown extension, random content -> fallback to application/octet-stream
t.Run("EmptyType_Fallback", func(t *testing.T) {
randomContent := []byte{0x00, 0x01, 0x02, 0x03}
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "test.data",
Content: randomContent,
},
})
require.NoError(t, err)
require.Equal(t, "application/octet-stream", attachment.Type)
})
}

View File

@@ -0,0 +1,655 @@
package test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func TestAuthenticatorAccessTokenV2(t *testing.T) {
ctx := context.Background()
t.Run("authenticates valid access token v2", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate access token v2
token, _, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte(ts.Secret),
)
require.NoError(t, err)
// Authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
claims, err := authenticator.AuthenticateByAccessTokenV2(token)
require.NoError(t, err)
assert.NotNil(t, claims)
assert.Equal(t, user.ID, claims.UserID)
assert.Equal(t, user.Username, claims.Username)
assert.Equal(t, string(user.Role), claims.Role)
assert.Equal(t, string(user.RowStatus), claims.Status)
})
t.Run("fails with invalid token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, err := authenticator.AuthenticateByAccessTokenV2("invalid-token")
assert.Error(t, err)
})
t.Run("fails with wrong secret", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate token with one secret
token, _, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte("secret-1"),
)
require.NoError(t, err)
// Try to authenticate with different secret
authenticator := auth.NewAuthenticator(ts.Store, "secret-2")
_, err = authenticator.AuthenticateByAccessTokenV2(token)
assert.Error(t, err)
})
}
func TestAuthenticatorRefreshToken(t *testing.T) {
ctx := context.Background()
t.Run("authenticates valid refresh token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create refresh token record in store
tokenID := util.GenUUID()
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
require.NoError(t, err)
// Generate refresh token JWT
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
require.NoError(t, err)
// Authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
authenticatedUser, returnedTokenID, err := authenticator.AuthenticateByRefreshToken(ctx, token)
require.NoError(t, err)
assert.NotNil(t, authenticatedUser)
assert.Equal(t, user.ID, authenticatedUser.ID)
assert.Equal(t, tokenID, returnedTokenID)
})
t.Run("fails with revoked token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
tokenID := util.GenUUID()
// Generate refresh token JWT but don't store it in database (simulates revocation)
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "revoked")
})
t.Run("fails with expired token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create expired refresh token record in store
tokenID := util.GenUUID()
expiredToken := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, expiredToken)
require.NoError(t, err)
// Generate refresh token JWT (JWT itself isn't expired yet)
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
})
t.Run("fails with archived user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create valid refresh token
tokenID := util.GenUUID()
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
require.NoError(t, err)
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
require.NoError(t, err)
// Archive the user
archivedStatus := store.Archived
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
RowStatus: &archivedStatus,
})
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "archived")
})
}
func TestAuthenticatorPAT(t *testing.T) {
ctx := context.Background()
t.Run("authenticates valid PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate PAT
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
// Store PAT in database
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
require.NoError(t, err)
// Authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
require.NoError(t, err)
assert.NotNil(t, authenticatedUser)
assert.NotNil(t, pat)
assert.Equal(t, user.ID, authenticatedUser.ID)
assert.Equal(t, tokenID, pat.TokenId)
})
t.Run("fails with invalid PAT format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err := authenticator.AuthenticateByPAT(ctx, "invalid-token-without-prefix")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid PAT format")
})
t.Run("fails with non-existent PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Generate a PAT but don't store it
token := auth.GeneratePersonalAccessToken()
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err := authenticator.AuthenticateByPAT(ctx, token)
assert.Error(t, err)
})
t.Run("fails with expired PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate and store expired PAT
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
expiredPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Expired PAT",
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, expiredPAT)
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
})
t.Run("succeeds with non-expiring PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate and store PAT without expiration
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Never-expiring PAT",
ExpiresAt: nil, // No expiration
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
require.NoError(t, err)
// Authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
require.NoError(t, err)
assert.NotNil(t, authenticatedUser)
assert.NotNil(t, pat)
assert.Nil(t, pat.ExpiresAt)
})
t.Run("fails with archived user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Generate and store PAT
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
require.NoError(t, err)
// Archive the user
archivedStatus := store.Archived
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
RowStatus: &archivedStatus,
})
require.NoError(t, err)
// Try to authenticate
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
assert.Error(t, err)
assert.Contains(t, err.Error(), "archived")
})
}
func TestStoreRefreshTokenMethods(t *testing.T) {
ctx := context.Background()
t.Run("adds and retrieves refresh token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
tokenID := util.GenUUID()
token := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
require.NoError(t, err)
// Retrieve tokens
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, tokens, 1)
assert.Equal(t, tokenID, tokens[0].TokenId)
})
t.Run("retrieves specific refresh token by ID", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
tokenID := util.GenUUID()
token := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
require.NoError(t, err)
// Retrieve specific token
retrievedToken, err := ts.Store.GetUserRefreshTokenByID(ctx, user.ID, tokenID)
require.NoError(t, err)
assert.NotNil(t, retrievedToken)
assert.Equal(t, tokenID, retrievedToken.TokenId)
})
t.Run("removes refresh token", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
tokenID := util.GenUUID()
token := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
require.NoError(t, err)
// Remove token
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID)
require.NoError(t, err)
// Verify removal
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, tokens, 0)
})
t.Run("handles multiple refresh tokens", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Add multiple tokens
tokenID1 := util.GenUUID()
tokenID2 := util.GenUUID()
token1 := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID1,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
token2 := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: tokenID2,
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token1)
require.NoError(t, err)
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token2)
require.NoError(t, err)
// Retrieve all tokens
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, tokens, 2)
// Remove one token
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID1)
require.NoError(t, err)
// Verify only one token remains
tokens, err = ts.Store.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, tokens, 1)
assert.Equal(t, tokenID2, tokens[0].TokenId)
})
}
func TestStorePersonalAccessTokenMethods(t *testing.T) {
ctx := context.Background()
t.Run("adds and retrieves PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Retrieve PATs
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 1)
assert.Equal(t, tokenID, pats[0].TokenId)
assert.Equal(t, tokenHash, pats[0].TokenHash)
})
t.Run("removes PAT", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Remove PAT
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID)
require.NoError(t, err)
// Verify removal
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 0)
})
t.Run("updates PAT last used time", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Update last used time
lastUsed := timestamppb.Now()
err = ts.Store.UpdatePATLastUsed(ctx, user.ID, tokenID, lastUsed)
require.NoError(t, err)
// Verify update
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 1)
assert.NotNil(t, pats[0].LastUsedAt)
})
t.Run("handles multiple PATs", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Add multiple PATs
token1 := auth.GeneratePersonalAccessToken()
tokenHash1 := auth.HashPersonalAccessToken(token1)
tokenID1 := util.GenUUID()
token2 := auth.GeneratePersonalAccessToken()
tokenHash2 := auth.HashPersonalAccessToken(token2)
tokenID2 := util.GenUUID()
pat1 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID1,
TokenHash: tokenHash1,
Description: "PAT 1",
CreatedAt: timestamppb.Now(),
}
pat2 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID2,
TokenHash: tokenHash2,
Description: "PAT 2",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat1)
require.NoError(t, err)
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat2)
require.NoError(t, err)
// Retrieve all PATs
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 2)
// Remove one PAT
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID1)
require.NoError(t, err)
// Verify only one PAT remains
pats, err = ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
assert.Len(t, pats, 1)
assert.Equal(t, tokenID2, pats[0].TokenId)
})
t.Run("finds user by PAT hash", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
token := auth.GeneratePersonalAccessToken()
tokenHash := auth.HashPersonalAccessToken(token)
tokenID := util.GenUUID()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tokenID,
TokenHash: tokenHash,
Description: "Test PAT",
CreatedAt: timestamppb.Now(),
}
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Find user by PAT hash
result, err := ts.Store.GetUserByPATHash(ctx, tokenHash)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, user.ID, result.UserID)
assert.NotNil(t, result.User)
assert.Equal(t, user.Username, result.User.Username)
assert.NotNil(t, result.PAT)
assert.Equal(t, tokenID, result.PAT.TokenId)
})
}

View File

@@ -0,0 +1,552 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestCreateIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("CreateIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
ctx := ts.CreateUserContext(ctx, hostUser.ID)
// Create OAuth2 identity provider
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test OAuth2 Provider",
IdentifierFilter: "",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "https://example.com/oauth/authorize",
TokenUrl: "https://example.com/oauth/token",
UserInfoUrl: "https://example.com/oauth/userinfo",
Scopes: []string{"openid", "profile", "email"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
DisplayName: "name",
Email: "email",
AvatarUrl: "avatar_url",
},
},
},
},
},
}
resp, err := ts.Service.CreateIdentityProvider(ctx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "Test OAuth2 Provider", resp.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
require.Contains(t, resp.Name, "identity-providers/")
require.NotNil(t, resp.Config.GetOauth2Config())
require.Equal(t, "test-client-id", resp.Config.GetOauth2Config().ClientId)
})
t.Run("CreateIdentityProvider permission denied for non-host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
// Set user context
ctx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err = ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("CreateIdentityProvider unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err := ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "user not authenticated")
})
}
func TestListIdentityProviders(t *testing.T) {
ctx := context.Background()
t.Run("ListIdentityProviders empty", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.ListIdentityProvidersRequest{}
resp, err := ts.Service.ListIdentityProviders(ctx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Empty(t, resp.IdentityProviders)
})
t.Run("ListIdentityProviders with providers", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a couple of identity providers
createReq1 := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider 1",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client1",
AuthUrl: "https://example1.com/auth",
TokenUrl: "https://example1.com/token",
UserInfoUrl: "https://example1.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
createReq2 := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider 2",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client2",
AuthUrl: "https://example2.com/auth",
TokenUrl: "https://example2.com/token",
UserInfoUrl: "https://example2.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq1)
require.NoError(t, err)
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq2)
require.NoError(t, err)
// List providers
listReq := &v1pb.ListIdentityProvidersRequest{}
resp, err := ts.Service.ListIdentityProviders(ctx, listReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.IdentityProviders, 2)
// Verify response contains expected providers
titles := []string{resp.IdentityProviders[0].Title, resp.IdentityProviders[1].Title}
require.Contains(t, titles, "Provider 1")
require.Contains(t, titles, "Provider 2")
})
}
func TestGetIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("GetIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "test-client",
ClientSecret: "test-secret",
AuthUrl: "https://example.com/auth",
TokenUrl: "https://example.com/token",
UserInfoUrl: "https://example.com/user",
Scopes: []string{"openid", "profile"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
DisplayName: "name",
Email: "email",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Get identity provider
getReq := &v1pb.GetIdentityProviderRequest{
Name: created.Name,
}
// Test unauthenticated, should not contain client secret
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, created.Name, resp.Name)
require.Equal(t, "Test Provider", resp.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
require.NotNil(t, resp.Config.GetOauth2Config())
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)
// Test as host user, should contain client secret
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
require.NoError(t, err)
require.NotNil(t, respHostUser)
require.Equal(t, created.Name, respHostUser.Name)
require.Equal(t, "Test Provider", respHostUser.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
require.NotNil(t, respHostUser.Config.GetOauth2Config())
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
})
t.Run("GetIdentityProvider not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.GetIdentityProviderRequest{
Name: "identity-providers/999",
}
_, err := ts.Service.GetIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("GetIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.GetIdentityProviderRequest{
Name: "invalid-name",
}
_, err := ts.Service.GetIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
}
func TestUpdateIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("UpdateIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Original Provider",
IdentifierFilter: "",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "original-client",
AuthUrl: "https://original.com/auth",
TokenUrl: "https://original.com/token",
UserInfoUrl: "https://original.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Update identity provider
updateReq := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: created.Name,
Title: "Updated Provider",
IdentifierFilter: "test@example.com",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "updated-client",
ClientSecret: "updated-secret",
AuthUrl: "https://updated.com/auth",
TokenUrl: "https://updated.com/token",
UserInfoUrl: "https://updated.com/user",
Scopes: []string{"openid", "profile", "email"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "sub",
DisplayName: "given_name",
Email: "email",
AvatarUrl: "picture",
},
},
},
},
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "identifier_filter", "config"},
},
}
updated, err := ts.Service.UpdateIdentityProvider(userCtx, updateReq)
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "Updated Provider", updated.Title)
require.Equal(t, "test@example.com", updated.IdentifierFilter)
require.Equal(t, "updated-client", updated.Config.GetOauth2Config().ClientId)
})
t.Run("UpdateIdentityProvider missing update mask", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: "identity-providers/1",
Title: "Updated Provider",
},
}
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "update_mask is required")
})
t.Run("UpdateIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: "invalid-name",
Title: "Updated Provider",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title"},
},
}
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
}
func TestDeleteIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("DeleteIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider to Delete",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client-to-delete",
AuthUrl: "https://example.com/auth",
TokenUrl: "https://example.com/token",
UserInfoUrl: "https://example.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Delete identity provider
deleteReq := &v1pb.DeleteIdentityProviderRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteIdentityProvider(userCtx, deleteReq)
require.NoError(t, err)
// Verify deletion
getReq := &v1pb.GetIdentityProviderRequest{
Name: created.Name,
}
_, err = ts.Service.GetIdentityProvider(ctx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("DeleteIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.DeleteIdentityProviderRequest{
Name: "invalid-name",
}
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
t.Run("DeleteIdentityProvider not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.DeleteIdentityProviderRequest{
Name: "identity-providers/999",
}
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
require.Error(t, err)
// Note: Delete might succeed even if item doesn't exist, depending on store implementation
})
}
func TestIdentityProviderPermissions(t *testing.T) {
ctx := context.Background()
t.Run("Only host users can create identity providers", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err = ts.Service.CreateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("Authentication required", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err := ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "user not authenticated")
})
}

View File

@@ -0,0 +1,54 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestInstanceAdminRetrieval(t *testing.T) {
ctx := context.Background()
t.Run("Instance becomes initialized after first admin user is created", func(t *testing.T) {
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Verify instance is not initialized initially
profile1, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.Nil(t, profile1.Admin, "Instance should not be initialized before first admin user")
// Create the first admin user
user, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
require.NotNil(t, user)
// Verify instance is now initialized
profile2, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.NotNil(t, profile2.Admin, "Instance should be initialized after first admin user is created")
require.Equal(t, user.Username, profile2.Admin.Username)
})
t.Run("Admin retrieval is cached by Store layer", func(t *testing.T) {
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Create admin user
user, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Multiple calls should return consistent admin user (from cache)
for i := 0; i < 5; i++ {
profile, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.NotNil(t, profile.Admin)
require.Equal(t, user.Username, profile.Admin.Username)
}
})
}

View File

@@ -0,0 +1,204 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestGetInstanceProfile(t *testing.T) {
ctx := context.Background()
t.Run("GetInstanceProfile returns instance profile", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetInstanceProfile directly
req := &v1pb.GetInstanceProfileRequest{}
resp, err := ts.Service.GetInstanceProfile(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
// Verify the response contains expected data
require.Equal(t, "test-1.0.0", resp.Version)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
// Instance should not be initialized since no admin users are created
require.Nil(t, resp.Admin)
})
t.Run("GetInstanceProfile with initialized instance", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user in the store
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
require.NotNil(t, hostUser)
// Call GetInstanceProfile directly
req := &v1pb.GetInstanceProfileRequest{}
resp, err := ts.Service.GetInstanceProfile(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
// Verify the response contains expected data with initialized flag
require.Equal(t, "test-1.0.0", resp.Version)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
// Instance should be initialized since an admin user exists
require.NotNil(t, resp.Admin)
require.Equal(t, hostUser.Username, resp.Admin.Username)
})
}
func TestGetInstanceProfile_Concurrency(t *testing.T) {
ctx := context.Background()
t.Run("Concurrent access to service", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user
_, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Make concurrent requests
numGoroutines := 10
results := make(chan *v1pb.InstanceProfile, numGoroutines)
errors := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
req := &v1pb.GetInstanceProfileRequest{}
resp, err := ts.Service.GetInstanceProfile(ctx, req)
if err != nil {
errors <- err
return
}
results <- resp
}()
}
// Collect all results
for i := 0; i < numGoroutines; i++ {
select {
case err := <-errors:
t.Fatalf("Goroutine returned error: %v", err)
case resp := <-results:
require.NotNil(t, resp)
require.Equal(t, "test-1.0.0", resp.Version)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
require.NotNil(t, resp.Admin)
}
}
})
}
func TestGetInstanceSetting(t *testing.T) {
ctx := context.Background()
t.Run("GetInstanceSetting - general setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetInstanceSetting for general setting
req := &v1pb.GetInstanceSettingRequest{
Name: "instance/settings/GENERAL",
}
resp, err := ts.Service.GetInstanceSetting(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "instance/settings/GENERAL", resp.Name)
// The general setting should have a general_setting field
generalSetting := resp.GetGeneralSetting()
require.NotNil(t, generalSetting)
// General setting should have default values
require.False(t, generalSetting.DisallowUserRegistration)
require.False(t, generalSetting.DisallowPasswordAuth)
require.Empty(t, generalSetting.AdditionalScript)
})
t.Run("GetInstanceSetting - storage setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user for storage setting access
hostUser, err := ts.CreateHostUser(ctx, "testhost")
require.NoError(t, err)
// Add user to context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Call GetInstanceSetting for storage setting
req := &v1pb.GetInstanceSettingRequest{
Name: "instance/settings/STORAGE",
}
resp, err := ts.Service.GetInstanceSetting(userCtx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "instance/settings/STORAGE", resp.Name)
// The storage setting should have a storage_setting field
storageSetting := resp.GetStorageSetting()
require.NotNil(t, storageSetting)
})
t.Run("GetInstanceSetting - memo related setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetInstanceSetting for memo related setting
req := &v1pb.GetInstanceSettingRequest{
Name: "instance/settings/MEMO_RELATED",
}
resp, err := ts.Service.GetInstanceSetting(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "instance/settings/MEMO_RELATED", resp.Name)
// The memo related setting should have a memo_related_setting field
memoRelatedSetting := resp.GetMemoRelatedSetting()
require.NotNil(t, memoRelatedSetting)
})
t.Run("GetInstanceSetting - invalid setting name", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetInstanceSetting with invalid name
req := &v1pb.GetInstanceSettingRequest{
Name: "invalid/setting/name",
}
_, err := ts.Service.GetInstanceSetting(ctx, req)
// Should return an error
require.Error(t, err)
require.Contains(t, err.Error(), "invalid instance setting name")
})
}

View File

@@ -0,0 +1,166 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestSetMemoAttachments(t *testing.T) {
ctx := context.Background()
t.Run("SetMemoAttachments success by memo owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create attachment
attachment, err := ts.Service.CreateAttachment(userCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Filename: "test.txt",
Size: 5,
Type: "text/plain",
Content: []byte("hello"),
},
})
require.NoError(t, err)
require.NotNil(t, attachment)
// Set memo attachments - should succeed
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{
{Name: attachment.Name},
},
})
require.NoError(t, err)
})
t.Run("SetMemoAttachments success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Host user can modify attachments - should succeed
_, err = ts.Service.SetMemoAttachments(hostCtx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.NoError(t, err)
})
t.Run("SetMemoAttachments permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// User2 tries to modify attachments - should fail
_, err = ts.Service.SetMemoAttachments(user2Ctx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("SetMemoAttachments unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Unauthenticated user tries to modify attachments - should fail
_, err = ts.Service.SetMemoAttachments(ctx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("SetMemoAttachments memo not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to set attachments on non-existent memo - should fail
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
Name: "memos/nonexistent-uid-12345",
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,169 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestSetMemoRelations(t *testing.T) {
ctx := context.Background()
t.Run("SetMemoRelations success by memo owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo1
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo 1",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo1)
// Create memo2
memo2, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo 2",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo2)
// Set memo relations - should succeed
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
Name: memo1.Name,
Relations: []*apiv1.MemoRelation{
{
RelatedMemo: &apiv1.MemoRelation_Memo{
Name: memo2.Name,
},
Type: apiv1.MemoRelation_REFERENCE,
},
},
})
require.NoError(t, err)
})
t.Run("SetMemoRelations success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Host user can modify relations - should succeed
_, err = ts.Service.SetMemoRelations(hostCtx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.NoError(t, err)
})
t.Run("SetMemoRelations permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// User2 tries to modify relations - should fail
_, err = ts.Service.SetMemoRelations(user2Ctx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("SetMemoRelations unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Unauthenticated user tries to modify relations - should fail
_, err = ts.Service.SetMemoRelations(ctx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("SetMemoRelations memo not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to set relations on non-existent memo - should fail
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
Name: "memos/nonexistent-uid-12345",
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,369 @@
package test
import (
"context"
"fmt"
"slices"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestListMemos(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create userOne
userOne, err := ts.CreateRegularUser(ctx, "test-user-1")
require.NoError(t, err)
require.NotNil(t, userOne)
// Create userOne context
userOneCtx := ts.CreateUserContext(ctx, userOne.ID)
// Create userTwo
userTwo, err := ts.CreateRegularUser(ctx, "test-user-2")
require.NoError(t, err)
require.NotNil(t, userTwo)
// Create userTwo context
userTwoCtx := ts.CreateUserContext(ctx, userTwo.ID)
// Create attachmentOne by userOne
attachmentOne, err := ts.Service.CreateAttachment(userOneCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Name: "",
Filename: "hello.txt",
Size: 5,
Type: "text/plain",
Content: []byte{
104, 101, 108, 108, 111,
},
},
})
require.NoError(t, err)
require.NotNil(t, attachmentOne)
// Create attachmentTwo by userOne
attachmentTwo, err := ts.Service.CreateAttachment(userOneCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Name: "",
Filename: "world.txt",
Size: 5,
Type: "text/plain",
Content: []byte{
119, 111, 114, 108, 100,
},
},
})
require.NoError(t, err)
require.NotNil(t, attachmentTwo)
// Create memoOne with two attachments by userOne
memoOne, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Hellooo, any words after this sentence won't be in the snippet. This is the next sentence. And I also have two attachments.",
Visibility: apiv1.Visibility_PROTECTED,
Attachments: []*apiv1.Attachment{
&apiv1.Attachment{
Name: attachmentOne.Name,
},
&apiv1.Attachment{
Name: attachmentTwo.Name,
},
},
},
})
require.NoError(t, err)
require.NotNil(t, memoOne)
// Create memoTwo by userTwo referencing memoOne
memoTwo, err := ts.Service.CreateMemo(userTwoCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This is a memo reminding you to check the attachment attached to memoOne. I have referenced the memo below.⬇️",
Visibility: apiv1.Visibility_PROTECTED,
Relations: []*apiv1.MemoRelation{
&apiv1.MemoRelation{
RelatedMemo: &apiv1.MemoRelation_Memo{
Name: memoOne.Name,
},
},
},
},
})
require.NoError(t, err)
require.NotNil(t, memoTwo)
// Create memoThree by userOne
memoThree, err := ts.Service.CreateMemo(userOneCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This is a very popular memo. I have 2 reactions!",
Visibility: apiv1.Visibility_PROTECTED,
},
})
require.NoError(t, err)
require.NotNil(t, memoThree)
// Create reaction from userOne on memoThree
reactionOne, err := ts.Service.UpsertMemoReaction(userOneCtx, &apiv1.UpsertMemoReactionRequest{
Name: memoThree.Name,
Reaction: &apiv1.Reaction{
ContentId: memoThree.Name,
ReactionType: "❤️",
},
})
require.NoError(t, err)
require.NotNil(t, reactionOne)
// Create reaction from userTwo on memoThree
reactionTwo, err := ts.Service.UpsertMemoReaction(userTwoCtx, &apiv1.UpsertMemoReactionRequest{
Name: memoThree.Name,
Reaction: &apiv1.Reaction{
ContentId: memoThree.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reactionTwo)
memos, err := ts.Service.ListMemos(userOneCtx, &apiv1.ListMemosRequest{PageSize: 10})
require.NoError(t, err)
require.NotNil(t, memos)
require.Equal(t, 3, len(memos.Memos))
// ///////////////
// VERIFY MEMO ONE
// ///////////////
memoOneResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoOne.GetName() })
require.NotEqual(t, memoOneResIdx, -1)
memoOneRes := memos.Memos[memoOneResIdx]
require.NotNil(t, memoOneRes)
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoOneRes.GetCreator())
require.Equal(t, apiv1.Visibility_PROTECTED, memoOneRes.GetVisibility())
require.Equal(t, memoOne.Content, memoOneRes.GetContent())
require.Equal(t, memoOne.Content[:64]+"...", memoOneRes.GetSnippet(), "memoOne's content is snipped past the 64 char limit")
require.Len(t, memoOneRes.Attachments, 2)
require.Len(t, memoOneRes.Relations, 1)
require.Empty(t, memoOneRes.Reactions)
// verify memoOne's attachments
// attachment one
attachmentOneResIdx := slices.IndexFunc(memoOneRes.Attachments, func(a *apiv1.Attachment) bool { return a.GetName() == attachmentOne.GetName() })
require.NotEqual(t, attachmentOneResIdx, -1)
attachmentOneRes := memoOneRes.Attachments[attachmentOneResIdx]
require.NotNil(t, attachmentOneRes)
require.Equal(t, attachmentOne.GetName(), attachmentOneRes.GetName())
require.Equal(t, attachmentOne.GetContent(), attachmentOneRes.GetContent())
// attachment two
attachmentTwoResIdx := slices.IndexFunc(memoOneRes.Attachments, func(a *apiv1.Attachment) bool { return a.GetName() == attachmentTwo.GetName() })
require.NotEqual(t, attachmentTwoResIdx, -1)
attachmentTwoRes := memoOneRes.Attachments[attachmentTwoResIdx]
require.NotNil(t, attachmentTwoRes)
require.Equal(t, attachmentTwo.GetName(), attachmentTwoRes.GetName())
require.Equal(t, attachmentTwo.GetName(), attachmentTwoRes.GetName())
require.Equal(t, attachmentTwo.GetContent(), attachmentTwoRes.GetContent())
// verify memoOne's relations
require.Len(t, memoOneRes.Relations, 1)
memoOneExpectedRelation := &apiv1.MemoRelation{
Memo: &apiv1.MemoRelation_Memo{Name: memoTwo.GetName()},
RelatedMemo: &apiv1.MemoRelation_Memo{Name: memoOne.GetName()},
}
require.Equal(t, memoOneExpectedRelation.Memo.GetName(), memoOneRes.Relations[0].Memo.GetName())
require.Equal(t, memoOneExpectedRelation.RelatedMemo.GetName(), memoOneRes.Relations[0].RelatedMemo.GetName())
// ///////////////
// VERIFY MEMO TWO
// ///////////////
memoTwoResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoTwo.GetName() })
require.NotEqual(t, memoTwoResIdx, -1)
memoTwoRes := memos.Memos[memoTwoResIdx]
require.NotNil(t, memoTwoRes)
require.Equal(t, fmt.Sprintf("users/%d", userTwo.ID), memoTwoRes.GetCreator())
require.Equal(t, apiv1.Visibility_PROTECTED, memoTwoRes.GetVisibility())
require.Equal(t, memoTwo.Content, memoTwoRes.GetContent())
require.Empty(t, memoTwoRes.Attachments)
require.Len(t, memoTwoRes.Relations, 1)
require.Empty(t, memoTwoRes.Reactions)
// verify memoTwo's relations
require.Len(t, memoTwoRes.Relations, 1)
memoTwoExpectedRelation := &apiv1.MemoRelation{
Memo: &apiv1.MemoRelation_Memo{Name: memoTwo.GetName()},
RelatedMemo: &apiv1.MemoRelation_Memo{Name: memoOne.GetName()},
}
require.Equal(t, memoTwoExpectedRelation.Memo.GetName(), memoTwoRes.Relations[0].Memo.GetName())
require.Equal(t, memoTwoExpectedRelation.RelatedMemo.GetName(), memoTwoRes.Relations[0].RelatedMemo.GetName())
// ///////////////
// VERIFY MEMO THREE
// ///////////////
memoThreeResIdx := slices.IndexFunc(memos.Memos, func(m *apiv1.Memo) bool { return m.GetName() == memoThree.GetName() })
require.NotEqual(t, memoThreeResIdx, -1)
memoThreeRes := memos.Memos[memoThreeResIdx]
require.NotNil(t, memoThreeRes)
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoThreeRes.GetCreator())
require.Equal(t, apiv1.Visibility_PROTECTED, memoThreeRes.GetVisibility())
require.Equal(t, memoThree.Content, memoThreeRes.GetContent())
require.Empty(t, memoThreeRes.Attachments)
require.Empty(t, memoThreeRes.Relations)
require.Len(t, memoThreeRes.Reactions, 2)
// verify memoThree's reactions
require.Len(t, memoThreeRes.Reactions, 2)
// userOne's reaction
userOneReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userOne.ID) })
require.NotEqual(t, userOneReactionIdx, -1)
userOneReaction := memoThreeRes.Reactions[userOneReactionIdx]
require.NotNil(t, userOneReaction)
require.Equal(t, "❤️", userOneReaction.ReactionType)
// userTwo's reaction
userTwoReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userTwo.ID) })
require.NotEqual(t, userTwoReactionIdx, -1)
userTwoReaction := memoThreeRes.Reactions[userTwoReactionIdx]
require.NotNil(t, userTwoReaction)
require.Equal(t, "👍", userTwoReaction.ReactionType)
}
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483
func TestCreateMemoWithCustomTimestamps(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "test-user-timestamps")
require.NoError(t, err)
require.NotNil(t, user)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Define custom timestamps (January 1, 2020)
customCreateTime := time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
customUpdateTime := time.Date(2020, 1, 2, 12, 0, 0, 0, time.UTC)
customDisplayTime := time.Date(2020, 1, 3, 12, 0, 0, 0, time.UTC)
// Test 1: Create a memo with custom create_time
memoWithCreateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom creation time",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCreateTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithCreateTime)
require.Equal(t, customCreateTime.Unix(), memoWithCreateTime.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
// Test 2: Create a memo with custom update_time
memoWithUpdateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom update time",
Visibility: apiv1.Visibility_PRIVATE,
UpdateTime: timestamppb.New(customUpdateTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithUpdateTime)
require.Equal(t, customUpdateTime.Unix(), memoWithUpdateTime.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
// Test 3: Create a memo with custom display_time
// Note: display_time is computed from either created_ts or updated_ts based on instance setting
// Since DisplayWithUpdateTime defaults to false, display_time maps to created_ts
memoWithDisplayTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom display time",
Visibility: apiv1.Visibility_PRIVATE,
DisplayTime: timestamppb.New(customDisplayTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithDisplayTime)
// Since DisplayWithUpdateTime is false by default, display_time sets created_ts
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.DisplayTime.AsTime().Unix(), "display_time should match the custom timestamp")
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.CreateTime.AsTime().Unix(), "create_time should also match since display_time maps to created_ts")
// Test 4: Create a memo with all custom timestamps
// When both display_time and create_time are provided, create_time takes precedence
memoWithAllTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has all custom timestamps",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCreateTime),
UpdateTime: timestamppb.New(customUpdateTime),
DisplayTime: timestamppb.New(customDisplayTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithAllTimestamps)
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
require.Equal(t, customUpdateTime.Unix(), memoWithAllTimestamps.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
// display_time is computed from created_ts when DisplayWithUpdateTime is false
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.DisplayTime.AsTime().Unix(), "display_time should be derived from create_time")
// Test 5: Create a comment (memo relation) with custom timestamps
parentMemo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This is the parent memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, parentMemo)
customCommentCreateTime := time.Date(2021, 6, 15, 10, 30, 0, 0, time.UTC)
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
Name: parentMemo.Name,
Comment: &apiv1.Memo{
Content: "This is a comment with custom create time",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCommentCreateTime),
},
})
require.NoError(t, err)
require.NotNil(t, comment)
require.Equal(t, customCommentCreateTime.Unix(), comment.CreateTime.AsTime().Unix(), "comment create_time should match the custom timestamp")
// Test 6: Verify that memos without custom timestamps still get auto-generated ones
memoWithoutTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has auto-generated timestamps",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memoWithoutTimestamps)
require.NotNil(t, memoWithoutTimestamps.CreateTime, "create_time should be auto-generated")
require.NotNil(t, memoWithoutTimestamps.UpdateTime, "update_time should be auto-generated")
require.True(t, time.Now().Unix()-memoWithoutTimestamps.CreateTime.AsTime().Unix() < 5, "create_time should be recent (within 5 seconds)")
}

View File

@@ -0,0 +1,194 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestDeleteMemoReaction(t *testing.T) {
ctx := context.Background()
t.Run("DeleteMemoReaction success by reaction owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Delete reaction - should succeed
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.NoError(t, err)
})
t.Run("DeleteMemoReaction success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction by regular user
reaction, err := ts.Service.UpsertMemoReaction(regularUserCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Host user can delete reaction - should succeed
_, err = ts.Service.DeleteMemoReaction(hostCtx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.NoError(t, err)
})
t.Run("DeleteMemoReaction permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction by user1
reaction, err := ts.Service.UpsertMemoReaction(user1Ctx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// User2 tries to delete reaction - should fail with permission denied
_, err = ts.Service.DeleteMemoReaction(user2Ctx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("DeleteMemoReaction unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Unauthenticated user tries to delete reaction - should fail
_, err = ts.Service.DeleteMemoReaction(ctx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("DeleteMemoReaction not found returns permission denied", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to delete non-existent reaction - should fail with permission denied
// (not "not found" to avoid information disclosure)
// Use new nested resource format: memos/{memo}/reactions/{reaction}
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
Name: "memos/nonexistent/reactions/99999",
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
require.NotContains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,819 @@
package test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestListShortcuts(t *testing.T) {
ctx := context.Background()
t.Run("ListShortcuts success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// List shortcuts (should be empty initially)
req := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
resp, err := ts.Service.ListShortcuts(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Empty(t, resp.Shortcuts)
})
t.Run("ListShortcuts permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Set user1 context but try to list user2's shortcuts
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
}
_, err = ts.Service.ListShortcuts(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("ListShortcuts invalid parent format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.ListShortcutsRequest{
Parent: "invalid-parent-format",
}
_, err = ts.Service.ListShortcuts(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name")
})
t.Run("ListShortcuts unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.ListShortcutsRequest{
Parent: "users/1",
}
_, err := ts.Service.ListShortcuts(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
}
func TestGetShortcut(t *testing.T) {
ctx := context.Background()
t.Run("GetShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// First create a shortcut
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Now get the shortcut
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
resp, err := ts.Service.GetShortcut(userCtx, getReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, created.Name, resp.Name)
require.Equal(t, "Test Shortcut", resp.Title)
require.Equal(t, "tag in [\"test\"]", resp.Filter)
})
t.Run("GetShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to get shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.GetShortcut(user2Ctx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("GetShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.GetShortcutRequest{
Name: "invalid-shortcut-name",
}
_, err = ts.Service.GetShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("GetShortcut not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.GetShortcutRequest{
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
}
_, err = ts.Service.GetShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}
func TestCreateShortcut(t *testing.T) {
ctx := context.Background()
t.Run("CreateShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "My Shortcut",
Filter: "tag in [\"important\"]",
},
}
resp, err := ts.Service.CreateShortcut(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "My Shortcut", resp.Title)
require.Equal(t, "tag in [\"important\"]", resp.Filter)
require.Contains(t, resp.Name, fmt.Sprintf("users/%d/shortcuts/", user.ID))
// Verify the shortcut was created by listing
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, listResp.Shortcuts, 1)
require.Equal(t, "My Shortcut", listResp.Shortcuts[0].Title)
})
t.Run("CreateShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Set user1 context but try to create shortcut for user2
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
Shortcut: &v1pb.Shortcut{
Title: "Forbidden Shortcut",
Filter: "tag in [\"forbidden\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("CreateShortcut invalid parent format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: "invalid-parent",
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name")
})
t.Run("CreateShortcut invalid filter", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Invalid Filter Shortcut",
Filter: "invalid||filter))syntax",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid filter")
})
t.Run("CreateShortcut missing title", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Filter: "tag in [\"test\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "title is required")
})
}
func TestUpdateShortcut(t *testing.T) {
ctx := context.Background()
t.Run("UpdateShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Original Title",
Filter: "tag in [\"original\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Update the shortcut
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Title: "Updated Title",
Filter: "tag in [\"updated\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "Updated Title", updated.Title)
require.Equal(t, "tag in [\"updated\"]", updated.Filter)
require.Equal(t, created.Name, updated.Name)
})
t.Run("UpdateShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to update shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Title: "Hacked Title",
Filter: "tag in [\"hacked\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
_, err = ts.Service.UpdateShortcut(user2Ctx, updateReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("UpdateShortcut missing update mask", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user and context for authentication
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: fmt.Sprintf("users/%d/shortcuts/test", user.ID),
Title: "Updated Title",
},
}
_, err = ts.Service.UpdateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "update mask is required")
})
t.Run("UpdateShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: "invalid-shortcut-name",
Title: "Updated Title",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title"},
},
}
_, err := ts.Service.UpdateShortcut(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("UpdateShortcut invalid filter", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Try to update with invalid filter
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Filter: "invalid||filter))syntax",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"filter"},
},
}
_, err = ts.Service.UpdateShortcut(userCtx, updateReq)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid filter")
})
}
func TestDeleteShortcut(t *testing.T) {
ctx := context.Background()
t.Run("DeleteShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Shortcut to Delete",
Filter: "tag in [\"delete\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Delete the shortcut
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
require.NoError(t, err)
// Verify deletion by listing shortcuts
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Empty(t, listResp.Shortcuts)
// Also verify by trying to get the deleted shortcut
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.GetShortcut(userCtx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("DeleteShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to delete shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteShortcut(user2Ctx, deleteReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("DeleteShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.DeleteShortcutRequest{
Name: "invalid-shortcut-name",
}
_, err := ts.Service.DeleteShortcut(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("DeleteShortcut not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.DeleteShortcutRequest{
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
}
_, err = ts.Service.DeleteShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}
func TestShortcutFiltering(t *testing.T) {
ctx := context.Background()
t.Run("CreateShortcut with valid filters", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test various valid filter formats
validFilters := []string{
"tag in [\"work\"]",
"content.contains(\"meeting\")",
"tag in [\"work\"] && content.contains(\"meeting\")",
"tag in [\"work\"] || tag in [\"personal\"]",
"creator_id == 1",
"visibility == \"PUBLIC\"",
"has_task_list == true",
"has_task_list == false",
}
for i, filter := range validFilters {
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Valid Filter " + string(rune(i)),
Filter: filter,
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.NoError(t, err, "Filter should be valid: %s", filter)
}
})
t.Run("CreateShortcut with invalid filters", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test various invalid filter formats
invalidFilters := []string{
"tag in ", // incomplete expression
"invalid_field @in [\"value\"]", // unknown field
"tag in [\"work\"] &&", // incomplete expression
"tag in [\"work\"] || || tag in [\"test\"]", // double operator
"((tag in [\"work\"]", // unmatched parentheses
"tag in [\"work\"] && )", // mismatched parentheses
"tag == \"work\"", // wrong operator (== not supported for tags)
"tag in work", // missing brackets
}
for _, filter := range invalidFilters {
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Invalid Filter Test",
Filter: filter,
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err, "Filter should be invalid: %s", filter)
require.Contains(t, err.Error(), "invalid filter", "Error should mention invalid filter for: %s", filter)
}
})
}
func TestShortcutCRUDComplete(t *testing.T) {
ctx := context.Background()
t.Run("Complete CRUD lifecycle", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// 1. Create multiple shortcuts
shortcut1Req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Work Notes",
Filter: "tag in [\"work\"]",
},
}
shortcut2Req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Personal Notes",
Filter: "tag in [\"personal\"]",
},
}
created1, err := ts.Service.CreateShortcut(userCtx, shortcut1Req)
require.NoError(t, err)
require.Equal(t, "Work Notes", created1.Title)
created2, err := ts.Service.CreateShortcut(userCtx, shortcut2Req)
require.NoError(t, err)
require.Equal(t, "Personal Notes", created2.Title)
// 2. List shortcuts and verify both exist
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, listResp.Shortcuts, 2)
// 3. Get individual shortcuts
getReq1 := &v1pb.GetShortcutRequest{Name: created1.Name}
getResp1, err := ts.Service.GetShortcut(userCtx, getReq1)
require.NoError(t, err)
require.Equal(t, created1.Name, getResp1.Name)
require.Equal(t, "Work Notes", getResp1.Title)
getReq2 := &v1pb.GetShortcutRequest{Name: created2.Name}
getResp2, err := ts.Service.GetShortcut(userCtx, getReq2)
require.NoError(t, err)
require.Equal(t, created2.Name, getResp2.Name)
require.Equal(t, "Personal Notes", getResp2.Title)
// 4. Update one shortcut
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created1.Name,
Title: "Work & Meeting Notes",
Filter: "tag in [\"work\"] || tag in [\"meeting\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
require.NoError(t, err)
require.Equal(t, "Work & Meeting Notes", updated.Title)
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", updated.Filter)
// 5. Verify update by getting it again
getUpdatedReq := &v1pb.GetShortcutRequest{Name: created1.Name}
getUpdatedResp, err := ts.Service.GetShortcut(userCtx, getUpdatedReq)
require.NoError(t, err)
require.Equal(t, "Work & Meeting Notes", getUpdatedResp.Title)
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", getUpdatedResp.Filter)
// 6. Delete one shortcut
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created2.Name,
}
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
require.NoError(t, err)
// 7. Verify deletion by listing (should only have 1 left)
finalListResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, finalListResp.Shortcuts, 1)
require.Equal(t, "Work & Meeting Notes", finalListResp.Shortcuts[0].Title)
// 8. Verify deleted shortcut can't be accessed
getDeletedReq := &v1pb.GetShortcutRequest{Name: created2.Name}
_, err = ts.Service.GetShortcut(userCtx, getDeletedReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,86 @@
package test
import (
"context"
"testing"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/plugin/markdown"
"github.com/usememos/memos/server/auth"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
teststore "github.com/usememos/memos/store/test"
)
// TestService holds the test service setup for API v1 services.
type TestService struct {
Service *apiv1.APIV1Service
Store *store.Store
Profile *profile.Profile
Secret string
}
// NewTestService creates a new test service with SQLite database.
func NewTestService(t *testing.T) *TestService {
ctx := context.Background()
// Create a test store with SQLite
testStore := teststore.NewTestingStore(ctx, t)
// Create a test profile
testProfile := &profile.Profile{
Demo: true,
Version: "test-1.0.0",
InstanceURL: "http://localhost:8080",
Driver: "sqlite",
DSN: ":memory:",
}
// Create APIV1Service with nil grpcServer since we're testing direct calls
secret := "test-secret"
markdownService := markdown.NewService(
markdown.WithTagExtension(),
)
service := &apiv1.APIV1Service{
Secret: secret,
Profile: testProfile,
Store: testStore,
MarkdownService: markdownService,
}
return &TestService{
Service: service,
Store: testStore,
Profile: testProfile,
Secret: secret,
}
}
// Cleanup closes resources after test.
func (ts *TestService) Cleanup() {
ts.Store.Close()
}
// CreateHostUser creates an admin user for testing.
func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) {
return ts.Store.CreateUser(ctx, &store.User{
Username: username,
Role: store.RoleAdmin,
Email: username + "@example.com",
})
}
// CreateRegularUser creates a regular user for testing.
func (ts *TestService) CreateRegularUser(ctx context.Context, username string) (*store.User, error) {
return ts.Store.CreateUser(ctx, &store.User{
Username: username,
Role: store.RoleUser,
Email: username + "@example.com",
})
}
// CreateUserContext creates a context with the given user's ID for authentication.
func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context {
// Use the context key from the auth package
return context.WithValue(ctx, auth.UserIDContextKey, userID)
}

View File

@@ -0,0 +1,173 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestCreateUserRegistration(t *testing.T) {
ctx := context.Background()
t.Run("CreateUser success when registration enabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// User registration is enabled by default, no need to set it explicitly
// Create user without authentication - should succeed
_, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.NoError(t, err)
})
t.Run("CreateUser blocked when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user first so we're not in first-user setup mode
_, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Disable user registration
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Try to create user without authentication - should fail
_, err = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not allowed")
})
t.Run("CreateUser succeeds for superuser even when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Disable user registration
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Host user can create users even when registration is disabled - should succeed
_, err = ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.NoError(t, err)
})
t.Run("CreateUser regular user cannot create users when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Disable user registration
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Regular user tries to create user when registration is disabled - should fail
_, err = ts.Service.CreateUser(regularUserCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not allowed")
})
t.Run("CreateUser host can assign roles", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Host user can create user with specific role - should succeed
createdUser, err := ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newadmin",
Email: "newadmin@example.com",
Password: "password123",
Role: apiv1.User_ADMIN,
},
})
require.NoError(t, err)
require.NotNil(t, createdUser)
require.Equal(t, apiv1.User_ADMIN, createdUser.Role)
})
t.Run("CreateUser unauthenticated user can only create regular user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user first so we're not in first-user setup mode
_, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// User registration is enabled by default
// Unauthenticated user tries to create admin user - role should be ignored
createdUser, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "wannabeadmin",
Email: "wannabeadmin@example.com",
Password: "password123",
Role: apiv1.User_ADMIN, // This should be ignored
},
})
require.NoError(t, err)
require.NotNil(t, createdUser)
require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role")
})
}

View File

@@ -0,0 +1,105 @@
package test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestGetUserStats_TagCount(t *testing.T) {
ctx := context.Background()
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test host user
user, err := ts.CreateHostUser(ctx, "test_user")
require.NoError(t, err)
// Create user context for authentication
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a memo with a single tag
memo, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-1",
CreatorID: user.ID,
Content: "This is a test memo with #test tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"test"},
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Test GetUserStats
userName := fmt.Sprintf("users/%d", user.ID)
response, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response)
// Check that the tag count is exactly 1, not 2
require.Contains(t, response.TagCount, "test")
require.Equal(t, int32(1), response.TagCount["test"], "Tag count should be 1 for a single occurrence")
// Create another memo with the same tag
memo2, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-2",
CreatorID: user.ID,
Content: "Another memo with #test tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"test"},
},
})
require.NoError(t, err)
require.NotNil(t, memo2)
// Test GetUserStats again
response2, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response2)
// Check that the tag count is exactly 2, not 3
require.Contains(t, response2.TagCount, "test")
require.Equal(t, int32(2), response2.TagCount["test"], "Tag count should be 2 for two occurrences")
// Test with a new unique tag
memo3, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-3",
CreatorID: user.ID,
Content: "Memo with #unique tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"unique"},
},
})
require.NoError(t, err)
require.NotNil(t, memo3)
// Test GetUserStats for the new tag
response3, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response3)
// Check that the unique tag count is exactly 1
require.Contains(t, response3.TagCount, "unique")
require.Equal(t, int32(1), response3.TagCount["unique"], "New tag count should be 1 for first occurrence")
// The original test tag should still be 2
require.Contains(t, response3.TagCount, "test")
require.Equal(t, int32(2), response3.TagCount["test"], "Original tag count should remain 2")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,236 @@
package v1
import (
"context"
"fmt"
"time"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUserStatsRequest) (*v1pb.ListAllUserStatsResponse, error) {
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get instance memo related setting")
}
normalStatus := store.Normal
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &normalStatus,
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else {
if memoFind.CreatorID == nil {
filter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
memoFind.Filters = append(memoFind.Filters, filter)
} else if *memoFind.CreatorID != currentUser.ID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
}
userMemoStatMap := make(map[int32]*v1pb.UserStats)
limit := 1000
offset := 0
memoFind.Limit = &limit
memoFind.Offset = &offset
for {
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
if len(memos) == 0 {
break
}
for _, memo := range memos {
// Initialize user stats if not exists
if _, exists := userMemoStatMap[memo.CreatorID]; !exists {
userMemoStatMap[memo.CreatorID] = &v1pb.UserStats{
Name: fmt.Sprintf("users/%d/stats", memo.CreatorID),
TagCount: make(map[string]int32),
MemoDisplayTimestamps: []*timestamppb.Timestamp{},
PinnedMemos: []string{},
MemoTypeStats: &v1pb.UserStats_MemoTypeStats{
LinkCount: 0,
CodeCount: 0,
TodoCount: 0,
UndoCount: 0,
},
}
}
stats := userMemoStatMap[memo.CreatorID]
// Add display timestamp
displayTs := memo.CreatedTs
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
stats.MemoDisplayTimestamps = append(stats.MemoDisplayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
// Count memo stats
stats.TotalMemoCount++
// Count tags and other properties
if memo.Payload != nil {
for _, tag := range memo.Payload.Tags {
stats.TagCount[tag]++
}
if memo.Payload.Property != nil {
if memo.Payload.Property.HasLink {
stats.MemoTypeStats.LinkCount++
}
if memo.Payload.Property.HasCode {
stats.MemoTypeStats.CodeCount++
}
if memo.Payload.Property.HasTaskList {
stats.MemoTypeStats.TodoCount++
}
if memo.Payload.Property.HasIncompleteTasks {
stats.MemoTypeStats.UndoCount++
}
}
}
// Track pinned memos
if memo.Pinned {
stats.PinnedMemos = append(stats.PinnedMemos, fmt.Sprintf("users/%d/memos/%d", memo.CreatorID, memo.ID))
}
}
offset += limit
}
userMemoStats := []*v1pb.UserStats{}
for _, userMemoStat := range userMemoStatMap {
userMemoStats = append(userMemoStats, userMemoStat)
}
response := &v1pb.ListAllUserStatsResponse{
Stats: userMemoStats,
}
return response, nil
}
func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserStatsRequest) (*v1pb.UserStats, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
normalStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &userID,
// Exclude comments by default.
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &normalStatus,
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else if currentUser.ID != userID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get instance memo related setting")
}
displayTimestamps := []*timestamppb.Timestamp{}
tagCount := make(map[string]int32)
linkCount := int32(0)
codeCount := int32(0)
todoCount := int32(0)
undoCount := int32(0)
pinnedMemos := []string{}
totalMemoCount := int32(0)
limit := 1000
offset := 0
memoFind.Limit = &limit
memoFind.Offset = &offset
for {
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
if len(memos) == 0 {
break
}
totalMemoCount += int32(len(memos))
for _, memo := range memos {
displayTs := memo.CreatedTs
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
displayTimestamps = append(displayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
// Count different memo types based on content.
if memo.Payload != nil {
for _, tag := range memo.Payload.Tags {
tagCount[tag]++
}
if memo.Payload.Property != nil {
if memo.Payload.Property.HasLink {
linkCount++
}
if memo.Payload.Property.HasCode {
codeCount++
}
if memo.Payload.Property.HasTaskList {
todoCount++
}
if memo.Payload.Property.HasIncompleteTasks {
undoCount++
}
}
}
if memo.Pinned {
pinnedMemos = append(pinnedMemos, fmt.Sprintf("users/%d/memos/%d", userID, memo.ID))
}
}
offset += limit
}
userStats := &v1pb.UserStats{
Name: fmt.Sprintf("users/%d/stats", userID),
MemoDisplayTimestamps: displayTimestamps,
TagCount: tagCount,
PinnedMemos: pinnedMemos,
TotalMemoCount: totalMemoCount,
MemoTypeStats: &v1pb.UserStats_MemoTypeStats{
LinkCount: linkCount,
CodeCount: codeCount,
TodoCount: todoCount,
UndoCount: undoCount,
},
}
return userStats, nil
}

175
server/router/api/v1/v1.go Normal file
View File

@@ -0,0 +1,175 @@
package v1
import (
"context"
"net/http"
"connectrpc.com/connect"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"golang.org/x/sync/semaphore"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/plugin/markdown"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
type APIV1Service struct {
v1pb.UnimplementedInstanceServiceServer
v1pb.UnimplementedAuthServiceServer
v1pb.UnimplementedUserServiceServer
v1pb.UnimplementedMemoServiceServer
v1pb.UnimplementedAttachmentServiceServer
v1pb.UnimplementedShortcutServiceServer
v1pb.UnimplementedActivityServiceServer
v1pb.UnimplementedIdentityProviderServiceServer
Secret string
Profile *profile.Profile
Store *store.Store
MarkdownService markdown.Service
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
thumbnailSemaphore *semaphore.Weighted
}
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service {
markdownService := markdown.NewService(
markdown.WithTagExtension(),
)
return &APIV1Service{
Secret: secret,
Profile: profile,
Store: store,
MarkdownService: markdownService,
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
}
}
// RegisterGateway registers the gRPC-Gateway and Connect handlers with the given Echo instance.
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
// Auth middleware for gRPC-Gateway - runs after routing, has access to method name.
// Uses the same PublicMethods config as the Connect AuthInterceptor.
authenticator := auth.NewAuthenticator(s.Store, s.Secret)
gatewayAuthMiddleware := func(next runtime.HandlerFunc) runtime.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
ctx := r.Context()
// Get the RPC method name from context (set by grpc-gateway after routing)
rpcMethod, ok := runtime.RPCMethod(ctx)
// Extract credentials from HTTP headers
authHeader := r.Header.Get("Authorization")
result := authenticator.Authenticate(ctx, authHeader)
// Enforce authentication for non-public methods
// If rpcMethod cannot be determined, allow through, service layer will handle visibility checks
if result == nil && ok && !IsPublicMethod(rpcMethod) {
http.Error(w, `{"code": 16, "message": "authentication required"}`, http.StatusUnauthorized)
return
}
// Apply auth result to context (no-op when result is nil for public endpoints)
if result != nil {
ctx = auth.ApplyToContext(ctx, result)
r = r.WithContext(ctx)
}
next(w, r, pathParams)
}
}
// Create gRPC-Gateway mux with auth middleware.
gwMux := runtime.NewServeMux(
runtime.WithMiddlewares(gatewayAuthMiddleware),
)
if err := v1pb.RegisterInstanceServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterAuthServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterUserServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterMemoServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterAttachmentServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterShortcutServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterActivityServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterIdentityProviderServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
gwGroup := echoServer.Group("")
gwGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
handler := echo.WrapHandler(gwMux)
gwGroup.Any("/api/v1/*", handler)
gwGroup.Any("/file/*", handler)
// Connect handlers for browser clients (replaces grpc-web).
logStacktraces := s.Profile.Demo
connectInterceptors := connect.WithInterceptors(
NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first
NewLoggingInterceptor(logStacktraces),
NewRecoveryInterceptor(logStacktraces),
NewAuthInterceptor(s.Store, s.Secret),
)
connectMux := http.NewServeMux()
connectHandler := NewConnectServiceHandler(s)
connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors)
// Wrap with CORS for browser access
corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{
UnsafeAllowOriginFunc: func(_ *echo.Context, origin string) (string, bool, error) {
return origin, true, nil
},
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions},
AllowHeaders: []string{"*"},
AllowCredentials: true,
})
connectGroup := echoServer.Group("", corsHandler)
connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(connectMux))
// Register AI REST endpoints (direct HTTP, no Connect/gRPC required)
// Apply auth middleware so user context is populated (tries Bearer token then cookie)
aiAuthenticator := auth.NewAuthenticator(s.Store, s.Secret)
aiGroup := echoServer.Group("", func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
ctx := c.Request().Context()
authHeader := c.Request().Header.Get("Authorization")
cookieHeader := c.Request().Header.Get("Cookie")
result := aiAuthenticator.Authenticate(ctx, authHeader)
if result == nil && cookieHeader != "" {
// Try cookie-based auth (refresh token)
user, err := aiAuthenticator.AuthenticateToUser(ctx, authHeader, cookieHeader)
if err == nil && user != nil {
ctx = auth.SetUserInContext(ctx, user, "")
c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}
}
if result != nil {
ctx = auth.ApplyToContext(ctx, result)
c.SetRequest(c.Request().WithContext(ctx))
}
return next(c)
}
})
s.RegisterAIHTTPHandlers(aiGroup)
return nil
}