Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
398 lines
13 KiB
Go
398 lines
13 KiB
Go
package v1
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v5"
|
|
groqai "github.com/usememos/memos/plugin/ai/groq"
|
|
ollamaai "github.com/usememos/memos/plugin/ai/ollama"
|
|
"github.com/usememos/memos/server/auth"
|
|
"github.com/usememos/memos/store"
|
|
)
|
|
|
|
// aiSettingsStoreKey is the instance-setting name under which the AI settings
// JSON document is read and written. Settings are looked up by this exact key,
// so renaming it would leave previously saved settings unreachable.
const (
	aiSettingsStoreKey = "ai_settings_v1"
)
|
|
|
|
// RegisterAIHTTPHandlers registers plain HTTP JSON handlers for the AI service.
|
|
// Uses direct raw DB storage to avoid proto generation dependency.
|
|
func (s *APIV1Service) RegisterAIHTTPHandlers(g *echo.Group) {
|
|
g.POST("/api/ai/settings", s.handleGetAISettings)
|
|
g.POST("/api/ai/settings/update", s.handleUpdateAISettings)
|
|
g.POST("/api/ai/complete", s.handleGenerateCompletion)
|
|
g.POST("/api/ai/test", s.handleTestAIProvider)
|
|
g.POST("/api/ai/models/ollama", s.handleListOllamaModels)
|
|
}
|
|
|
|
// The AI settings are persisted as a single JSON document and exchanged with
// the frontend in the same shape. Provider entries are pointers so a missing
// provider can be represented as nil; handlers check for nil before use.
type (
	// aiSettingsJSON is the top-level settings document, one entry per provider.
	aiSettingsJSON struct {
		Groq   *groqSettingsJSON   `json:"groq,omitempty"`
		Ollama *ollamaSettingsJSON `json:"ollama,omitempty"`
	}

	// groqSettingsJSON configures the Groq provider.
	groqSettingsJSON struct {
		Enabled      bool   `json:"enabled"`
		APIKey       string `json:"apiKey"`
		DefaultModel string `json:"defaultModel"`
	}

	// ollamaSettingsJSON configures the Ollama provider reached at Host.
	// TimeoutSeconds is converted by the handlers into a time.Duration in
	// seconds; a zero value is omitted from the stored JSON.
	ollamaSettingsJSON struct {
		Enabled        bool   `json:"enabled"`
		Host           string `json:"host"`
		DefaultModel   string `json:"defaultModel"`
		TimeoutSeconds int    `json:"timeoutSeconds,omitempty"`
	}
)

// defaultAISettings builds a fresh settings value with both providers present
// but disabled, pre-filled with built-in default models, the local Ollama
// address and a 120-second Ollama timeout. Each call allocates new values, so
// callers may mutate the result freely.
func defaultAISettings() *aiSettingsJSON {
	groq := &groqSettingsJSON{
		DefaultModel: "llama-3.1-8b-instant",
	}
	ollama := &ollamaSettingsJSON{
		Host:           "http://localhost:11434",
		DefaultModel:   "llama3",
		TimeoutSeconds: 120,
	}
	return &aiSettingsJSON{Groq: groq, Ollama: ollama}
}
|
|
|
|
func (s *APIV1Service) loadAISettings(ctx context.Context) (*aiSettingsJSON, error) {
|
|
settings := defaultAISettings()
|
|
list, err := s.Store.GetDriver().ListInstanceSettings(ctx, &store.FindInstanceSetting{Name: aiSettingsStoreKey})
|
|
if err != nil || len(list) == 0 {
|
|
return settings, nil
|
|
}
|
|
_ = json.Unmarshal([]byte(list[0].Value), settings)
|
|
return settings, nil
|
|
}
|
|
|
|
func (s *APIV1Service) saveAISettings(ctx context.Context, settings *aiSettingsJSON) error {
|
|
b, err := json.Marshal(settings)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.Store.GetDriver().UpsertInstanceSetting(ctx, &store.InstanceSetting{
|
|
Name: aiSettingsStoreKey,
|
|
Value: string(b),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *APIV1Service) handleGetAISettings(c *echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
if auth.GetUserID(ctx) == 0 {
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
|
}
|
|
settings, err := s.loadAISettings(ctx)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
|
}
|
|
return c.JSON(http.StatusOK, settings)
|
|
}
|
|
|
|
func (s *APIV1Service) handleUpdateAISettings(c *echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
if auth.GetUserID(ctx) == 0 {
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
|
}
|
|
var req aiSettingsJSON
|
|
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
|
}
|
|
if err := s.saveAISettings(ctx, &req); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
|
}
|
|
return c.JSON(http.StatusOK, req)
|
|
}
|
|
|
|
// Request and response bodies for the completion endpoint handled by
// handleGenerateCompletion.
type (
	// generateCompletionRequest is the decoded JSON request body.
	generateCompletionRequest struct {
		// Provider selects the backend: 1 = Groq, 2 = Ollama.
		Provider int    `json:"provider"`
		Prompt   string `json:"prompt"`
		// Model overrides the provider's configured default when non-empty.
		Model string `json:"model"`
		// Temperature and MaxTokens fall back to handler defaults when <= 0.
		Temperature float64 `json:"temperature"`
		MaxTokens   int     `json:"maxTokens"`
		// AutoTag and SpellCheck request optional post-processing of the output.
		AutoTag    bool `json:"autoTag,omitempty"`
		SpellCheck bool `json:"spellCheck,omitempty"`
	}

	// generateCompletionResponse carries the (possibly post-processed) text,
	// the model reported by the provider, and the provider's token counts.
	generateCompletionResponse struct {
		Text             string `json:"text"`
		ModelUsed        string `json:"modelUsed"`
		PromptTokens     int    `json:"promptTokens"`
		CompletionTokens int    `json:"completionTokens"`
		TotalTokens      int    `json:"totalTokens"`
	}
)
|
|
|
|
func (s *APIV1Service) handleGenerateCompletion(c *echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
if auth.GetUserID(ctx) == 0 {
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
|
}
|
|
var req generateCompletionRequest
|
|
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
|
}
|
|
if req.Prompt == "" {
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "prompt is required"})
|
|
}
|
|
|
|
settings, err := s.loadAISettings(ctx)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
|
}
|
|
|
|
temp := float32(0.7)
|
|
if req.Temperature > 0 {
|
|
temp = float32(req.Temperature)
|
|
}
|
|
maxTokens := 1000
|
|
if req.MaxTokens > 0 {
|
|
maxTokens = req.MaxTokens
|
|
}
|
|
|
|
switch req.Provider {
|
|
case 1: // Groq
|
|
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Groq is not enabled or API key is missing. Configure it in Settings > AI."})
|
|
}
|
|
model := req.Model
|
|
if model == "" {
|
|
model = settings.Groq.DefaultModel
|
|
}
|
|
if model == "" {
|
|
model = "llama-3.1-8b-instant"
|
|
}
|
|
client := groqai.NewGroqClient(groqai.GroqConfig{
|
|
APIKey: settings.Groq.APIKey,
|
|
DefaultModel: model,
|
|
})
|
|
resp, err := client.GenerateCompletion(ctx, groqai.CompletionRequest{
|
|
Model: model,
|
|
Prompt: req.Prompt,
|
|
Temperature: temp,
|
|
MaxTokens: maxTokens,
|
|
})
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
|
}
|
|
|
|
// Apply post-processing if requested
|
|
text := resp.Text
|
|
if req.SpellCheck {
|
|
text = s.correctSpelling(ctx, text)
|
|
}
|
|
if req.AutoTag {
|
|
text = s.addAutoTags(ctx, text)
|
|
}
|
|
|
|
return c.JSON(http.StatusOK, generateCompletionResponse{
|
|
Text: text, ModelUsed: resp.Model,
|
|
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
|
|
})
|
|
|
|
case 2: // Ollama
|
|
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
|
|
}
|
|
model := req.Model
|
|
if model == "" {
|
|
model = settings.Ollama.DefaultModel
|
|
}
|
|
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
|
Host: settings.Ollama.Host,
|
|
DefaultModel: model,
|
|
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
|
})
|
|
resp, err := client.GenerateCompletion(ctx, ollamaai.CompletionRequest{
|
|
Model: model,
|
|
Prompt: req.Prompt,
|
|
Temperature: temp,
|
|
MaxTokens: maxTokens,
|
|
})
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
|
}
|
|
|
|
// Apply post-processing if requested
|
|
text := resp.Text
|
|
if req.SpellCheck {
|
|
text = s.correctSpelling(ctx, text)
|
|
}
|
|
if req.AutoTag {
|
|
text = s.addAutoTags(ctx, text)
|
|
}
|
|
|
|
return c.JSON(http.StatusOK, generateCompletionResponse{
|
|
Text: text, ModelUsed: resp.Model,
|
|
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
|
|
})
|
|
|
|
default:
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider. Use 1 for Groq, 2 for Ollama"})
|
|
}
|
|
}
|
|
|
|
func (s *APIV1Service) handleTestAIProvider(c *echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
if auth.GetUserID(ctx) == 0 {
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
|
}
|
|
var req struct {
|
|
Provider int `json:"provider"`
|
|
}
|
|
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
|
}
|
|
settings, err := s.loadAISettings(ctx)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
|
}
|
|
|
|
switch req.Provider {
|
|
case 1:
|
|
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
|
|
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Groq not enabled or API key missing"})
|
|
}
|
|
client := groqai.NewGroqClient(groqai.GroqConfig{APIKey: settings.Groq.APIKey, DefaultModel: settings.Groq.DefaultModel})
|
|
if err := client.TestConnection(ctx); err != nil {
|
|
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Groq connected successfully"})
|
|
case 2:
|
|
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
|
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Ollama not enabled"})
|
|
}
|
|
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
|
Host: settings.Ollama.Host,
|
|
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
|
})
|
|
if err := client.TestConnection(ctx); err != nil {
|
|
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Ollama connected successfully"})
|
|
default:
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider"})
|
|
}
|
|
}
|
|
|
|
// correctSpelling applies basic spelling corrections to the text
|
|
func (s *APIV1Service) correctSpelling(ctx context.Context, text string) string {
|
|
// TODO: Implement actual spell checking using a spell checker library
|
|
// For now, return the text as-is
|
|
return text
|
|
}
|
|
|
|
// addAutoTags analyzes the text and adds relevant #tags
|
|
func (s *APIV1Service) addAutoTags(ctx context.Context, text string) string {
|
|
// Common programming languages and technologies
|
|
tagMap := map[string][]string{
|
|
"javascript": {"javascript", "js"},
|
|
"python": {"python", "py"},
|
|
"java": {"java"},
|
|
"go": {"go", "golang"},
|
|
"rust": {"rust"},
|
|
"dart": {"dart"},
|
|
"flutter": {"flutter"},
|
|
"react": {"react", "reactjs"},
|
|
"vue": {"vue", "vuejs"},
|
|
"angular": {"angular"},
|
|
"node": {"nodejs", "node"},
|
|
"express": {"express"},
|
|
"mongodb": {"mongodb"},
|
|
"postgresql": {"postgresql", "postgres"},
|
|
"mysql": {"mysql"},
|
|
"redis": {"redis"},
|
|
"docker": {"docker"},
|
|
"kubernetes": {"kubernetes", "k8s"},
|
|
"aws": {"aws", "amazon"},
|
|
"azure": {"azure"},
|
|
"gcp": {"gcp", "google-cloud"},
|
|
"git": {"git"},
|
|
"github": {"github"},
|
|
"gitlab": {"gitlab"},
|
|
"ci/cd": {"ci-cd"},
|
|
"testing": {"testing"},
|
|
"api": {"api"},
|
|
"rest": {"rest", "api"},
|
|
"graphql": {"graphql"},
|
|
"database": {"database", "db"},
|
|
"frontend": {"frontend"},
|
|
"backend": {"backend"},
|
|
"fullstack": {"fullstack"},
|
|
"mobile": {"mobile"},
|
|
"web": {"web"},
|
|
"devops": {"devops"},
|
|
"security": {"security"},
|
|
"machine learning": {"ml", "machine-learning"},
|
|
"ai": {"ai", "artificial-intelligence"},
|
|
"blockchain": {"blockchain"},
|
|
"crypto": {"crypto", "cryptocurrency"},
|
|
}
|
|
|
|
// Convert text to lowercase for matching
|
|
lowerText := strings.ToLower(text)
|
|
var tags []string
|
|
tagSet := make(map[string]bool)
|
|
|
|
// Check for programming languages and technologies
|
|
for keyword, tagList := range tagMap {
|
|
if strings.Contains(lowerText, keyword) {
|
|
for _, tag := range tagList {
|
|
if !tagSet[tag] {
|
|
tags = append(tags, tag)
|
|
tagSet[tag] = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add tags to the end of the text if any were found
|
|
if len(tags) > 0 {
|
|
tagString := "\n\n"
|
|
for i, tag := range tags {
|
|
if i > 0 {
|
|
tagString += " "
|
|
}
|
|
tagString += "#" + tag
|
|
}
|
|
text += tagString
|
|
}
|
|
|
|
return text
|
|
}
|
|
|
|
func (s *APIV1Service) handleListOllamaModels(c *echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
if auth.GetUserID(ctx) == 0 {
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
|
}
|
|
|
|
settings, err := s.loadAISettings(ctx)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
|
}
|
|
|
|
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
|
|
}
|
|
|
|
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
|
Host: settings.Ollama.Host,
|
|
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
|
})
|
|
|
|
models, err := client.ListModels(ctx)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": fmt.Sprintf("failed to list Ollama models: %v", err)})
|
|
}
|
|
|
|
// Convert to the same format as frontend expects
|
|
var modelList []map[string]string
|
|
for _, model := range models {
|
|
modelList = append(modelList, map[string]string{
|
|
"id": model,
|
|
"name": model,
|
|
})
|
|
}
|
|
|
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
|
"models": modelList,
|
|
})
|
|
}
|