first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
397
server/router/api/v1/ai_http.go
Normal file
397
server/router/api/v1/ai_http.go
Normal file
@@ -0,0 +1,397 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
groqai "github.com/usememos/memos/plugin/ai/groq"
|
||||
ollamaai "github.com/usememos/memos/plugin/ai/ollama"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const aiSettingsStoreKey = "ai_settings_v1"
|
||||
|
||||
// RegisterAIHTTPHandlers registers plain HTTP JSON handlers for the AI service.
|
||||
// Uses direct raw DB storage to avoid proto generation dependency.
|
||||
func (s *APIV1Service) RegisterAIHTTPHandlers(g *echo.Group) {
|
||||
g.POST("/api/ai/settings", s.handleGetAISettings)
|
||||
g.POST("/api/ai/settings/update", s.handleUpdateAISettings)
|
||||
g.POST("/api/ai/complete", s.handleGenerateCompletion)
|
||||
g.POST("/api/ai/test", s.handleTestAIProvider)
|
||||
g.POST("/api/ai/models/ollama", s.handleListOllamaModels)
|
||||
}
|
||||
|
||||
type aiSettingsJSON struct {
|
||||
Groq *groqSettingsJSON `json:"groq,omitempty"`
|
||||
Ollama *ollamaSettingsJSON `json:"ollama,omitempty"`
|
||||
}
|
||||
|
||||
type groqSettingsJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
DefaultModel string `json:"defaultModel"`
|
||||
}
|
||||
|
||||
type ollamaSettingsJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Host string `json:"host"`
|
||||
DefaultModel string `json:"defaultModel"`
|
||||
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
|
||||
}
|
||||
|
||||
func defaultAISettings() *aiSettingsJSON {
|
||||
return &aiSettingsJSON{
|
||||
Groq: &groqSettingsJSON{DefaultModel: "llama-3.1-8b-instant"},
|
||||
Ollama: &ollamaSettingsJSON{Host: "http://localhost:11434", DefaultModel: "llama3", TimeoutSeconds: 120},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) loadAISettings(ctx context.Context) (*aiSettingsJSON, error) {
|
||||
settings := defaultAISettings()
|
||||
list, err := s.Store.GetDriver().ListInstanceSettings(ctx, &store.FindInstanceSetting{Name: aiSettingsStoreKey})
|
||||
if err != nil || len(list) == 0 {
|
||||
return settings, nil
|
||||
}
|
||||
_ = json.Unmarshal([]byte(list[0].Value), settings)
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) saveAISettings(ctx context.Context, settings *aiSettingsJSON) error {
|
||||
b, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.Store.GetDriver().UpsertInstanceSetting(ctx, &store.InstanceSetting{
|
||||
Name: aiSettingsStoreKey,
|
||||
Value: string(b),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleGetAISettings(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
settings, err := s.loadAISettings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, settings)
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleUpdateAISettings(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
var req aiSettingsJSON
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||
}
|
||||
if err := s.saveAISettings(ctx, &req); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, req)
|
||||
}
|
||||
|
||||
type generateCompletionRequest struct {
|
||||
Provider int `json:"provider"`
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
AutoTag bool `json:"autoTag,omitempty"`
|
||||
SpellCheck bool `json:"spellCheck,omitempty"`
|
||||
}
|
||||
|
||||
type generateCompletionResponse struct {
|
||||
Text string `json:"text"`
|
||||
ModelUsed string `json:"modelUsed"`
|
||||
PromptTokens int `json:"promptTokens"`
|
||||
CompletionTokens int `json:"completionTokens"`
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleGenerateCompletion(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
var req generateCompletionRequest
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "prompt is required"})
|
||||
}
|
||||
|
||||
settings, err := s.loadAISettings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
temp := float32(0.7)
|
||||
if req.Temperature > 0 {
|
||||
temp = float32(req.Temperature)
|
||||
}
|
||||
maxTokens := 1000
|
||||
if req.MaxTokens > 0 {
|
||||
maxTokens = req.MaxTokens
|
||||
}
|
||||
|
||||
switch req.Provider {
|
||||
case 1: // Groq
|
||||
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Groq is not enabled or API key is missing. Configure it in Settings > AI."})
|
||||
}
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = settings.Groq.DefaultModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "llama-3.1-8b-instant"
|
||||
}
|
||||
client := groqai.NewGroqClient(groqai.GroqConfig{
|
||||
APIKey: settings.Groq.APIKey,
|
||||
DefaultModel: model,
|
||||
})
|
||||
resp, err := client.GenerateCompletion(ctx, groqai.CompletionRequest{
|
||||
Model: model,
|
||||
Prompt: req.Prompt,
|
||||
Temperature: temp,
|
||||
MaxTokens: maxTokens,
|
||||
})
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Apply post-processing if requested
|
||||
text := resp.Text
|
||||
if req.SpellCheck {
|
||||
text = s.correctSpelling(ctx, text)
|
||||
}
|
||||
if req.AutoTag {
|
||||
text = s.addAutoTags(ctx, text)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, generateCompletionResponse{
|
||||
Text: text, ModelUsed: resp.Model,
|
||||
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
|
||||
})
|
||||
|
||||
case 2: // Ollama
|
||||
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
|
||||
}
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = settings.Ollama.DefaultModel
|
||||
}
|
||||
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
||||
Host: settings.Ollama.Host,
|
||||
DefaultModel: model,
|
||||
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
||||
})
|
||||
resp, err := client.GenerateCompletion(ctx, ollamaai.CompletionRequest{
|
||||
Model: model,
|
||||
Prompt: req.Prompt,
|
||||
Temperature: temp,
|
||||
MaxTokens: maxTokens,
|
||||
})
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Apply post-processing if requested
|
||||
text := resp.Text
|
||||
if req.SpellCheck {
|
||||
text = s.correctSpelling(ctx, text)
|
||||
}
|
||||
if req.AutoTag {
|
||||
text = s.addAutoTags(ctx, text)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, generateCompletionResponse{
|
||||
Text: text, ModelUsed: resp.Model,
|
||||
PromptTokens: resp.PromptTokens, CompletionTokens: resp.CompletionTokens, TotalTokens: resp.TotalTokens,
|
||||
})
|
||||
|
||||
default:
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider. Use 1 for Groq, 2 for Ollama"})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleTestAIProvider(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
var req struct {
|
||||
Provider int `json:"provider"`
|
||||
}
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||
}
|
||||
settings, err := s.loadAISettings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
switch req.Provider {
|
||||
case 1:
|
||||
if settings.Groq == nil || !settings.Groq.Enabled || settings.Groq.APIKey == "" {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Groq not enabled or API key missing"})
|
||||
}
|
||||
client := groqai.NewGroqClient(groqai.GroqConfig{APIKey: settings.Groq.APIKey, DefaultModel: settings.Groq.DefaultModel})
|
||||
if err := client.TestConnection(ctx); err != nil {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Groq connected successfully"})
|
||||
case 2:
|
||||
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": "Ollama not enabled"})
|
||||
}
|
||||
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
||||
Host: settings.Ollama.Host,
|
||||
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
||||
})
|
||||
if err := client.TestConnection(ctx); err != nil {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": false, "message": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"success": true, "message": "Ollama connected successfully"})
|
||||
default:
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "unsupported provider"})
|
||||
}
|
||||
}
|
||||
|
||||
// correctSpelling applies basic spelling corrections to the text
|
||||
func (s *APIV1Service) correctSpelling(ctx context.Context, text string) string {
|
||||
// TODO: Implement actual spell checking using a spell checker library
|
||||
// For now, return the text as-is
|
||||
return text
|
||||
}
|
||||
|
||||
// addAutoTags analyzes the text and adds relevant #tags
|
||||
func (s *APIV1Service) addAutoTags(ctx context.Context, text string) string {
|
||||
// Common programming languages and technologies
|
||||
tagMap := map[string][]string{
|
||||
"javascript": {"javascript", "js"},
|
||||
"python": {"python", "py"},
|
||||
"java": {"java"},
|
||||
"go": {"go", "golang"},
|
||||
"rust": {"rust"},
|
||||
"dart": {"dart"},
|
||||
"flutter": {"flutter"},
|
||||
"react": {"react", "reactjs"},
|
||||
"vue": {"vue", "vuejs"},
|
||||
"angular": {"angular"},
|
||||
"node": {"nodejs", "node"},
|
||||
"express": {"express"},
|
||||
"mongodb": {"mongodb"},
|
||||
"postgresql": {"postgresql", "postgres"},
|
||||
"mysql": {"mysql"},
|
||||
"redis": {"redis"},
|
||||
"docker": {"docker"},
|
||||
"kubernetes": {"kubernetes", "k8s"},
|
||||
"aws": {"aws", "amazon"},
|
||||
"azure": {"azure"},
|
||||
"gcp": {"gcp", "google-cloud"},
|
||||
"git": {"git"},
|
||||
"github": {"github"},
|
||||
"gitlab": {"gitlab"},
|
||||
"ci/cd": {"ci-cd"},
|
||||
"testing": {"testing"},
|
||||
"api": {"api"},
|
||||
"rest": {"rest", "api"},
|
||||
"graphql": {"graphql"},
|
||||
"database": {"database", "db"},
|
||||
"frontend": {"frontend"},
|
||||
"backend": {"backend"},
|
||||
"fullstack": {"fullstack"},
|
||||
"mobile": {"mobile"},
|
||||
"web": {"web"},
|
||||
"devops": {"devops"},
|
||||
"security": {"security"},
|
||||
"machine learning": {"ml", "machine-learning"},
|
||||
"ai": {"ai", "artificial-intelligence"},
|
||||
"blockchain": {"blockchain"},
|
||||
"crypto": {"crypto", "cryptocurrency"},
|
||||
}
|
||||
|
||||
// Convert text to lowercase for matching
|
||||
lowerText := strings.ToLower(text)
|
||||
var tags []string
|
||||
tagSet := make(map[string]bool)
|
||||
|
||||
// Check for programming languages and technologies
|
||||
for keyword, tagList := range tagMap {
|
||||
if strings.Contains(lowerText, keyword) {
|
||||
for _, tag := range tagList {
|
||||
if !tagSet[tag] {
|
||||
tags = append(tags, tag)
|
||||
tagSet[tag] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add tags to the end of the text if any were found
|
||||
if len(tags) > 0 {
|
||||
tagString := "\n\n"
|
||||
for i, tag := range tags {
|
||||
if i > 0 {
|
||||
tagString += " "
|
||||
}
|
||||
tagString += "#" + tag
|
||||
}
|
||||
text += tagString
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
func (s *APIV1Service) handleListOllamaModels(c *echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
if auth.GetUserID(ctx) == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
|
||||
settings, err := s.loadAISettings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
if settings.Ollama == nil || !settings.Ollama.Enabled {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Ollama is not enabled. Configure it in Settings > AI."})
|
||||
}
|
||||
|
||||
client := ollamaai.NewOllamaClient(ollamaai.OllamaConfig{
|
||||
Host: settings.Ollama.Host,
|
||||
Timeout: time.Duration(settings.Ollama.TimeoutSeconds) * time.Second,
|
||||
})
|
||||
|
||||
models, err := client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": fmt.Sprintf("failed to list Ollama models: %v", err)})
|
||||
}
|
||||
|
||||
// Convert to the same format as frontend expects
|
||||
var modelList []map[string]string
|
||||
for _, model := range models {
|
||||
modelList = append(modelList, map[string]string{
|
||||
"id": model,
|
||||
"name": model,
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"models": modelList,
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user