package v1

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"regexp"
	"strings"
	"time"

	"github.com/labstack/echo/v5"
	groqai "github.com/usememos/memos/plugin/ai/groq"
	ollamaai "github.com/usememos/memos/plugin/ai/ollama"
	"github.com/usememos/memos/server/auth"
	"github.com/usememos/memos/store"
)

// aiSettingsStoreKey is the instance-setting name under which the AI settings
// JSON document is persisted.
const aiSettingsStoreKey = "ai_settings_v1"

// Provider identifiers accepted in the "provider" field of completion and
// connection-test requests.
const (
	aiProviderGroq   = 1
	aiProviderOllama = 2
)

// Fallback values used when the stored settings leave a field empty.
const (
	defaultGroqModel            = "llama-3.1-8b-instant"
	defaultOllamaHost           = "http://localhost:11434"
	defaultOllamaModel          = "llama3"
	defaultOllamaTimeoutSeconds = 120
)

// RegisterAIHTTPHandlers registers plain HTTP JSON handlers for the AI service.
// Uses direct raw DB storage to avoid proto generation dependency.
func (s *APIV1Service) RegisterAIHTTPHandlers(g *echo.Group) {
	g.POST("/api/ai/settings", s.handleGetAISettings)
	g.POST("/api/ai/settings/update", s.handleUpdateAISettings)
	g.POST("/api/ai/complete", s.handleGenerateCompletion)
	g.POST("/api/ai/test", s.handleTestAIProvider)
	g.POST("/api/ai/models/ollama", s.handleListOllamaModels)
}

// aiSettingsJSON is the persisted and wire representation of the AI settings.
type aiSettingsJSON struct {
	Groq   *groqSettingsJSON   `json:"groq,omitempty"`
	Ollama *ollamaSettingsJSON `json:"ollama,omitempty"`
}

// groqSettingsJSON configures the Groq provider.
type groqSettingsJSON struct {
	Enabled      bool   `json:"enabled"`
	APIKey       string `json:"apiKey"`
	DefaultModel string `json:"defaultModel"`
}

// ollamaSettingsJSON configures the Ollama provider.
type ollamaSettingsJSON struct {
	Enabled        bool   `json:"enabled"`
	Host           string `json:"host"`
	DefaultModel   string `json:"defaultModel"`
	TimeoutSeconds int    `json:"timeoutSeconds,omitempty"`
}

// defaultAISettings returns a fresh settings value with both providers
// disabled and their default model/host/timeout filled in.
func defaultAISettings() *aiSettingsJSON {
	return &aiSettingsJSON{
		Groq: &groqSettingsJSON{DefaultModel: defaultGroqModel},
		Ollama: &ollamaSettingsJSON{
			Host:           defaultOllamaHost,
			DefaultModel:   defaultOllamaModel,
			TimeoutSeconds: defaultOllamaTimeoutSeconds,
		},
	}
}

// loadAISettings reads the stored AI settings layered over defaultAISettings,
// so fields missing from the stored JSON keep their defaults. Both provider
// sections of the result are always non-nil.
//
// A store failure is returned as an error. A stored value that cannot be
// parsed is logged and replaced by the defaults.
func (s *APIV1Service) loadAISettings(ctx context.Context) (*aiSettingsJSON, error) {
	list, err := s.Store.GetDriver().ListInstanceSettings(ctx, &store.FindInstanceSetting{Name: aiSettingsStoreKey})
	if err != nil {
		// Do not fall back to defaults here: a caller that displays defaults and
		// saves them back would silently wipe the real settings (and API key).
		return nil, fmt.Errorf("loading AI settings: %w", err)
	}
	settings := defaultAISettings()
	if len(list) == 0 {
		return settings, nil
	}
	if err := json.Unmarshal([]byte(list[0].Value), settings); err != nil {
		// A corrupt row should not break every AI endpoint; fall back to the
		// defaults so the value can be overwritten from the settings page.
		slog.Warn("ignoring unparseable stored AI settings",
			slog.String("key", aiSettingsStoreKey), slog.Any("error", err))
		return defaultAISettings(), nil
	}
	// An explicit `null` section in the stored JSON clears the pointer;
	// restore the defaults so callers always see both sections.
	defaults := defaultAISettings()
	if settings.Groq == nil {
		settings.Groq = defaults.Groq
	}
	if settings.Ollama == nil {
		settings.Ollama = defaults.Ollama
	}
	return settings, nil
}

// saveAISettings serializes settings to JSON and upserts it under
// aiSettingsStoreKey.
func (s *APIV1Service) saveAISettings(ctx context.Context, settings *aiSettingsJSON) error {
	b, err := json.Marshal(settings)
	if err != nil {
		return fmt.Errorf("encoding AI settings: %w", err)
	}
	if _, err := s.Store.GetDriver().UpsertInstanceSetting(ctx, &store.InstanceSetting{
		Name:  aiSettingsStoreKey,
		Value: string(b),
	}); err != nil {
		return fmt.Errorf("saving AI settings: %w", err)
	}
	return nil
}

// handleGetAISettings returns the current AI settings to a signed-in user.
//
// NOTE(review): the response includes the Groq API key and is served to any
// authenticated user, not only admins. Consider masking the key or
// restricting this endpoint.
func (s *APIV1Service) handleGetAISettings(c *echo.Context) error {
	ctx := c.Request().Context()
	if auth.GetUserID(ctx) == 0 {
		return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
	}
	settings, err := s.loadAISettings(ctx)
	if err != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
	}
	return c.JSON(http.StatusOK, settings)
}

// handleUpdateAISettings applies the JSON request body on top of the stored
// settings, persists the result, and echoes it back.
//
// Fields and sections omitted from the request keep their stored values, so
// a partial update cannot accidentally erase the other provider's
// configuration.
//
// NOTE(review): any signed-in user may change these instance-wide settings,
// including the Ollama host the server will connect to. Consider
// restricting this to admins.
func (s *APIV1Service) handleUpdateAISettings(c *echo.Context) error {
	ctx := c.Request().Context()
	if auth.GetUserID(ctx) == 0 {
		return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
	}
	settings, err := s.loadAISettings(ctx)
	if err != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
	}
	// Decoding into the loaded value only overwrites fields present in the body.
	if err := json.NewDecoder(c.Request().Body).Decode(settings); err != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	if err := s.saveAISettings(ctx, settings); err != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
	}
	return c.JSON(http.StatusOK, settings)
}

// generateCompletionRequest is the body of POST /api/ai/complete.
// Provider is aiProviderGroq (1) or aiProviderOllama (2).
type generateCompletionRequest struct {
	Provider    int     `json:"provider"`
	Prompt      string  `json:"prompt"`
	Model       string  `json:"model"`
	Temperature float64 `json:"temperature"`
	MaxTokens   int     `json:"maxTokens"`
	AutoTag     bool    `json:"autoTag,omitempty"`
	SpellCheck  bool    `json:"spellCheck,omitempty"`
}

// generateCompletionResponse is the body returned by POST /api/ai/complete.
type generateCompletionResponse struct {
	Text             string `json:"text"`
	ModelUsed        string `json:"modelUsed"`
	PromptTokens     int    `json:"promptTokens"`
	CompletionTokens int    `json:"completionTokens"`
	TotalTokens      int    `json:"totalTokens"`
}

// aiResolveModel returns the first non-empty of the requested model, the
// configured default model, and the built-in fallback.
func aiResolveModel(requested, configured, fallback string) string {
	switch {
	case requested != "":
		return requested
	case configured != "":
		return configured
	default:
		return fallback
	}
}

// aiOllamaConfig builds the Ollama client configuration from the stored
// settings. It substitutes the default host when Host is blank and the
// default timeout when TimeoutSeconds is not positive. model may be empty.
func aiOllamaConfig(cfg *ollamaSettingsJSON, model string) ollamaai.OllamaConfig {
	host := strings.TrimSpace(cfg.Host)
	if host == "" {
		host = defaultOllamaHost
	}
	timeoutSeconds := cfg.TimeoutSeconds
	if timeoutSeconds <= 0 {
		timeoutSeconds = defaultOllamaTimeoutSeconds
	}
	return ollamaai.OllamaConfig{
		Host:         host,
		DefaultModel: model,
		Timeout:      time.Duration(timeoutSeconds) * time.Second,
	}
}

// handleGenerateCompletion runs a prompt against the requested provider using
// the stored settings. It optionally post-processes the generated text
// (spelling correction, then auto-tagging) before returning it.
func (s *APIV1Service) handleGenerateCompletion(c *echo.Context) error {
	ctx := c.Request().Context()
	if auth.GetUserID(ctx) == 0 {
		return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
	}
	var req generateCompletionRequest
	if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	if strings.TrimSpace(req.Prompt) == "" {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "prompt is required"})
	}
	settings, err := s.loadAISettings(ctx)
	if err != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
	}

	// Non-positive values are treated as "not provided".
	// NOTE(review): this also means an explicit temperature of 0 cannot be requested.
	temperature := float32(0.7)
	if req.Temperature > 0 {
		temperature = float32(req.Temperature)
	}
	maxTokens := 1000
	if req.MaxTokens > 0 {
		maxTokens = req.MaxTokens
	}

	var result generateCompletionResponse
	switch req.Provider {
	case aiProviderGroq:
		if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": "Groq is not enabled or API key is missing. Configure it in Settings > AI."})
		}
		model := aiResolveModel(req.Model, settings.Groq.DefaultModel, defaultGroqModel)
		client := groqai.NewGroqClient(groqai.GroqConfig{
			APIKey:       settings.Groq.APIKey,
			DefaultModel: model,
		})
		resp, err := client.GenerateCompletion(ctx, groqai.CompletionRequest{
			Model:       model,
			Prompt:      req.Prompt,
			Temperature: temperature,
			MaxTokens:   maxTokens,
		})
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		result = generateCompletionResponse{
			Text:             resp.Text,
			ModelUsed:        resp.Model,
			PromptTokens:     resp.PromptTokens,
			CompletionTokens: resp.CompletionTokens,
			TotalTokens:      resp.TotalTokens,
		}
	case aiProviderOllama:
		if settings.Ollama == nil || !settings.Ollama.Enabled {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
		}
		model := aiResolveModel(req.Model, settings.Ollama.DefaultModel, defaultOllamaModel)
		client := ollamaai.NewOllamaClient(aiOllamaConfig(settings.Ollama, model))
		resp, err := client.GenerateCompletion(ctx, ollamaai.CompletionRequest{
			Model:       model,
			Prompt:      req.Prompt,
			Temperature: temperature,
			MaxTokens:   maxTokens,
		})
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		result = generateCompletionResponse{
			Text:             resp.Text,
			ModelUsed:        resp.Model,
			PromptTokens:     resp.PromptTokens,
			CompletionTokens: resp.CompletionTokens,
			TotalTokens:      resp.TotalTokens,
		}
	default:
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider. Use 1 for Groq, 2 for Ollama"})
	}

	result.Text = s.postProcessCompletion(ctx, result.Text, &req)
	return c.JSON(http.StatusOK, result)
}

// postProcessCompletion applies the optional post-processing steps requested
// in req to the generated text: spelling correction first, then auto-tagging.
func (s *APIV1Service) postProcessCompletion(ctx context.Context, text string, req *generateCompletionRequest) string {
	if req.SpellCheck {
		text = s.correctSpelling(ctx, text)
	}
	if req.AutoTag {
		text = s.addAutoTags(ctx, text)
	}
	return text
}

// handleTestAIProvider checks connectivity to the requested provider using
// the stored settings. Configuration and connection problems are reported
// with HTTP 200 and {"success": false, "message": ...}; only an unknown
// provider yields HTTP 400.
func (s *APIV1Service) handleTestAIProvider(c *echo.Context) error {
	ctx := c.Request().Context()
	if auth.GetUserID(ctx) == 0 {
		return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
	}
	var req struct {
		Provider int `json:"provider"`
	}
	if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	settings, err := s.loadAISettings(ctx)
	if err != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
	}

	switch req.Provider {
	case aiProviderGroq:
		if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
			return c.JSON(http.StatusOK, map[string]any{"success": false, "message": "Groq not enabled or API key missing"})
		}
		client := groqai.NewGroqClient(groqai.GroqConfig{
			APIKey:       settings.Groq.APIKey,
			DefaultModel: aiResolveModel("", settings.Groq.DefaultModel, defaultGroqModel),
		})
		if err := client.TestConnection(ctx); err != nil {
			return c.JSON(http.StatusOK, map[string]any{"success": false, "message": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]any{"success": true, "message": "Groq connected successfully"})
	case aiProviderOllama:
		if settings.Ollama == nil || !settings.Ollama.Enabled {
			return c.JSON(http.StatusOK, map[string]any{"success": false, "message": "Ollama not enabled"})
		}
		client := ollamaai.NewOllamaClient(aiOllamaConfig(settings.Ollama, ""))
		if err := client.TestConnection(ctx); err != nil {
			return c.JSON(http.StatusOK, map[string]any{"success": false, "message": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]any{"success": true, "message": "Ollama connected successfully"})
	default:
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider"})
	}
}

// correctSpelling is a placeholder for spelling correction; it currently
// returns text unchanged.
func (s *APIV1Service) correctSpelling(ctx context.Context, text string) string {
	// TODO: Implement actual spell checking using a spell checker library.
	return text
}

// aiAutoTagRule adds tags when any of its keywords appears in the text as a
// whole word.
type aiAutoTagRule struct {
	pattern *regexp.Regexp
	tags    []string
}

// aiTagRule builds an aiAutoTagRule. keywords is a "|"-separated list of
// lowercase literal keywords; each is matched between ASCII word boundaries,
// so "java" does not fire on "javascript" and "go" does not fire on "good".
func aiTagRule(keywords string, tags ...string) aiAutoTagRule {
	alternatives := strings.Split(keywords, "|")
	for i, kw := range alternatives {
		alternatives[i] = regexp.QuoteMeta(kw)
	}
	return aiAutoTagRule{
		pattern: regexp.MustCompile(`\b(?:` + strings.Join(alternatives, "|") + `)\b`),
		tags:    tags,
	}
}

// aiAutoTagRules is evaluated in order, so the appended tags always come out
// in the same order. The "nodejs"/"reactjs"/"vuejs"/"golang" aliases keep
// matches that the previous substring search found.
// NOTE(review): "go" still matches the English verb.
var aiAutoTagRules = []aiAutoTagRule{
	aiTagRule("javascript", "javascript", "js"),
	aiTagRule("python", "python", "py"),
	aiTagRule("java", "java"),
	aiTagRule("go|golang", "go", "golang"),
	aiTagRule("rust", "rust"),
	aiTagRule("dart", "dart"),
	aiTagRule("flutter", "flutter"),
	aiTagRule("react|reactjs", "react", "reactjs"),
	aiTagRule("vue|vuejs", "vue", "vuejs"),
	aiTagRule("angular", "angular"),
	aiTagRule("node|nodejs", "nodejs", "node"),
	aiTagRule("express", "express"),
	aiTagRule("mongodb", "mongodb"),
	aiTagRule("postgresql", "postgresql", "postgres"),
	aiTagRule("mysql", "mysql"),
	aiTagRule("redis", "redis"),
	aiTagRule("docker", "docker"),
	aiTagRule("kubernetes", "kubernetes", "k8s"),
	aiTagRule("aws", "aws", "amazon"),
	aiTagRule("azure", "azure"),
	aiTagRule("gcp", "gcp", "google-cloud"),
	aiTagRule("git", "git"),
	aiTagRule("github", "github"),
	aiTagRule("gitlab", "gitlab"),
	aiTagRule("ci/cd", "ci-cd"),
	aiTagRule("testing", "testing"),
	aiTagRule("api", "api"),
	aiTagRule("rest", "rest", "api"),
	aiTagRule("graphql", "graphql"),
	aiTagRule("database", "database", "db"),
	aiTagRule("frontend", "frontend"),
	aiTagRule("backend", "backend"),
	aiTagRule("fullstack", "fullstack"),
	aiTagRule("mobile", "mobile"),
	aiTagRule("web", "web"),
	aiTagRule("devops", "devops"),
	aiTagRule("security", "security"),
	aiTagRule("machine learning", "ml", "machine-learning"),
	aiTagRule("ai", "ai", "artificial-intelligence"),
	aiTagRule("blockchain", "blockchain"),
	aiTagRule("crypto", "crypto", "cryptocurrency"),
}

// aiExistingHashtagPattern finds hashtags already present in the text so
// auto-tagging does not append duplicates of them.
var aiExistingHashtagPattern = regexp.MustCompile(`#([\p{L}\p{N}_-]+)`)

// addAutoTags scans text (case-insensitively) for the technology keywords in
// aiAutoTagRules. For each match it collects the rule's tags, skipping tags
// already collected or already present in the text as #hashtags. If any tags
// remain, it appends them to text as "\n\n#tag1 #tag2 ...". Otherwise text is
// returned unchanged.
func (s *APIV1Service) addAutoTags(ctx context.Context, text string) string {
	lowerText := strings.ToLower(text)

	seen := make(map[string]struct{})
	for _, m := range aiExistingHashtagPattern.FindAllStringSubmatch(lowerText, -1) {
		seen[m[1]] = struct{}{}
	}

	var tags []string
	for _, rule := range aiAutoTagRules {
		if !rule.pattern.MatchString(lowerText) {
			continue
		}
		for _, tag := range rule.tags {
			if _, dup := seen[tag]; dup {
				continue
			}
			seen[tag] = struct{}{}
			tags = append(tags, "#"+tag)
		}
	}

	if len(tags) == 0 {
		return text
	}
	return text + "\n\n" + strings.Join(tags, " ")
}

// handleListOllamaModels lists the models available on the configured Ollama
// host as {"models": [{"id": name, "name": name}, ...]}. The list is always a
// JSON array, never null, even when no models are installed.
func (s *APIV1Service) handleListOllamaModels(c *echo.Context) error {
	ctx := c.Request().Context()
	if auth.GetUserID(ctx) == 0 {
		return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
	}
	settings, err := s.loadAISettings(ctx)
	if err != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
	}
	if settings.Ollama == nil || !settings.Ollama.Enabled {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
	}

	client := ollamaai.NewOllamaClient(aiOllamaConfig(settings.Ollama, ""))
	models, err := client.ListModels(ctx)
	if err != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": fmt.Sprintf("failed to list Ollama models: %v", err)})
	}

	// A non-nil slice so an empty result encodes as [] rather than null.
	modelList := make([]map[string]string, 0, len(models))
	for _, model := range models {
		modelList = append(modelList, map[string]string{
			"id":   model,
			"name": model,
		})
	}
	return c.JSON(http.StatusOK, map[string]any{
		"models": modelList,
	})
}