first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
199
plugin/ai/ollama/ollama.go
Normal file
199
plugin/ai/ollama/ollama.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package ai
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"time"
)
|
||||
|
||||
// OllamaClient is a client for an Ollama HTTP API server.
// Construct it with NewOllamaClient; the zero value has no HTTP client
// and is not usable.
type OllamaClient struct {
	httpClient *http.Client // shared client; its Timeout comes from config.Timeout
	config     OllamaConfig
}
|
||||
|
||||
// OllamaConfig holds the settings used to talk to an Ollama server.
type OllamaConfig struct {
	// Host is the base URL of the Ollama server (e.g. "http://localhost:11434");
	// endpoint paths such as /api/generate are appended to it.
	Host string
	// DefaultModel is used when a request does not name a model.
	DefaultModel string
	// Timeout bounds each HTTP request; NewOllamaClient substitutes a
	// default when it is zero.
	Timeout time.Duration
}
|
||||
|
||||
// CompletionRequest is the input for generating completions.
type CompletionRequest struct {
	// Model names the Ollama model to use; empty falls back to the
	// client's DefaultModel (and then to a built-in default).
	Model string
	// Prompt is the text sent to the model.
	Prompt string
	// Temperature is passed through as the "temperature" option.
	Temperature float32
	// MaxTokens, when positive, is passed through as "num_predict".
	MaxTokens int
}
|
||||
|
||||
// CompletionResponse is the output from generating completions.
type CompletionResponse struct {
	// Text is the generated completion.
	Text string
	// Model is the model reported by the server.
	Model string
	// PromptTokens mirrors Ollama's prompt_eval_count.
	PromptTokens int
	// CompletionTokens mirrors Ollama's eval_count.
	CompletionTokens int
	// TotalTokens is PromptTokens + CompletionTokens.
	TotalTokens int
}
|
||||
|
||||
// OllamaGenerateRequest is the JSON body for Ollama's /api/generate endpoint.
type OllamaGenerateRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	// Stream false requests a single JSON response instead of a stream
	// of chunks.
	Stream bool `json:"stream"`
	// Options carries model parameters such as "temperature" and
	// "num_predict".
	Options map[string]interface{} `json:"options,omitempty"`
}
|
||||
|
||||
// OllamaGenerateResponse is the JSON response from Ollama's /api/generate
// endpoint (non-streaming form). Duration fields are in nanoseconds per the
// Ollama API.
type OllamaGenerateResponse struct {
	Model    string `json:"model"`
	Response string `json:"response"`
	Done     bool   `json:"done"`
	// Context is an opaque token context that can be sent back to
	// continue a conversation.
	Context            []int `json:"context,omitempty"`
	TotalDuration      int64 `json:"total_duration,omitempty"`
	LoadDuration       int64 `json:"load_duration,omitempty"`
	PromptEvalCount    int   `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
	EvalCount          int   `json:"eval_count,omitempty"`
	EvalDuration       int64 `json:"eval_duration,omitempty"`
}
|
||||
|
||||
// OllamaTagsResponse is the JSON response from Ollama's /api/tags endpoint,
// listing the models installed on the server.
type OllamaTagsResponse struct {
	Models []struct {
		Name   string `json:"name"`
		Model  string `json:"model"`
		Size   int64  `json:"size"`
		Digest string `json:"digest"`
	} `json:"models"`
}
|
||||
|
||||
func NewOllamaClient(config OllamaConfig) *OllamaClient {
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 120 * time.Second // Increase to 2 minutes for generation
|
||||
}
|
||||
|
||||
return &OllamaClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: config.Timeout,
|
||||
},
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OllamaClient) TestConnection(ctx context.Context) error {
|
||||
url := fmt.Sprintf("%s/api/tags", o.config.Host)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Ollama: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OllamaClient) GenerateCompletion(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = o.config.DefaultModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "llama3"
|
||||
}
|
||||
|
||||
ollamaReq := OllamaGenerateRequest{
|
||||
Model: model,
|
||||
Prompt: req.Prompt,
|
||||
Stream: false,
|
||||
Options: map[string]interface{}{
|
||||
"temperature": req.Temperature,
|
||||
},
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
ollamaReq.Options["num_predict"] = req.MaxTokens
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(ollamaReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/generate", o.config.Host)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
// Check if it's a timeout error
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return nil, fmt.Errorf("Ollama request timed out after %.0f seconds. Try reducing the max tokens or using a smaller model", o.config.Timeout.Seconds())
|
||||
}
|
||||
return nil, fmt.Errorf("failed to call Ollama API: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var ollamaResp OllamaGenerateResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return &CompletionResponse{
|
||||
Text: ollamaResp.Response,
|
||||
Model: ollamaResp.Model,
|
||||
PromptTokens: ollamaResp.PromptEvalCount,
|
||||
CompletionTokens: ollamaResp.EvalCount,
|
||||
TotalTokens: ollamaResp.PromptEvalCount + ollamaResp.EvalCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *OllamaClient) ListModels(ctx context.Context) ([]string, error) {
|
||||
url := fmt.Sprintf("%s/api/tags", o.config.Host)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get models: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tagsResp OllamaTagsResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tagsResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode models response: %w", err)
|
||||
}
|
||||
|
||||
var modelNames []string
|
||||
for _, model := range tagsResp.Models {
|
||||
modelNames = append(modelNames, model.Name)
|
||||
}
|
||||
|
||||
return modelNames, nil
|
||||
}
|
||||
Reference in New Issue
Block a user