Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
199 lines
5.1 KiB
Go
199 lines
5.1 KiB
Go
package ai
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"time"
)
|
|
|
|
// OllamaClient talks to an Ollama server over HTTP.
type OllamaClient struct {
	// httpClient carries the request timeout configured in NewOllamaClient.
	httpClient *http.Client
	// config holds the host, default model, and timeout settings.
	config OllamaConfig
}
|
|
|
|
// OllamaConfig configures an OllamaClient.
type OllamaConfig struct {
	// Host is the base URL of the Ollama server, e.g. "http://localhost:11434".
	// API paths such as /api/generate are appended to it.
	Host string
	// DefaultModel is used when a request does not name a model.
	DefaultModel string
	// Timeout bounds each HTTP request; zero falls back to 120s in NewOllamaClient.
	Timeout time.Duration
}
|
|
|
|
// CompletionRequest is the input for generating completions.
type CompletionRequest struct {
	// Model to use; empty falls back to the client's DefaultModel, then "llama3".
	Model string
	// Prompt is the text sent to the model.
	Prompt string
	// Temperature is forwarded to Ollama's "temperature" option.
	Temperature float32
	// MaxTokens, when > 0, is forwarded as Ollama's "num_predict" option.
	MaxTokens int
}
|
|
|
|
// CompletionResponse is the output from generating completions.
type CompletionResponse struct {
	// Text is the generated completion.
	Text string
	// Model is the model name reported by Ollama.
	Model string
	// PromptTokens mirrors Ollama's prompt_eval_count.
	PromptTokens int
	// CompletionTokens mirrors Ollama's eval_count.
	CompletionTokens int
	// TotalTokens is the sum of PromptTokens and CompletionTokens.
	TotalTokens int
}
|
|
|
|
// OllamaGenerateRequest is the JSON body for Ollama's /api/generate endpoint.
type OllamaGenerateRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	// Stream false requests a single JSON response instead of a stream.
	Stream bool `json:"stream"`
	// Options carries generation parameters such as "temperature" and "num_predict".
	Options map[string]interface{} `json:"options,omitempty"`
}
|
|
|
|
// OllamaGenerateResponse is the JSON reply from Ollama's /api/generate endpoint
// (non-streaming form). Duration fields are nanoseconds as reported by Ollama.
type OllamaGenerateResponse struct {
	Model    string `json:"model"`
	Response string `json:"response"`
	Done     bool   `json:"done"`
	// Context is an opaque token context that can seed follow-up requests.
	Context       []int `json:"context,omitempty"`
	TotalDuration int64 `json:"total_duration,omitempty"`
	LoadDuration  int64 `json:"load_duration,omitempty"`
	// PromptEvalCount is the number of prompt tokens evaluated.
	PromptEvalCount    int   `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
	// EvalCount is the number of tokens generated.
	EvalCount    int   `json:"eval_count,omitempty"`
	EvalDuration int64 `json:"eval_duration,omitempty"`
}
|
|
|
|
// OllamaTagsResponse is the JSON reply from Ollama's /api/tags endpoint,
// listing the models available on the server.
type OllamaTagsResponse struct {
	Models []struct {
		Name   string `json:"name"`
		Model  string `json:"model"`
		Size   int64  `json:"size"`
		Digest string `json:"digest"`
	} `json:"models"`
}
|
|
|
|
func NewOllamaClient(config OllamaConfig) *OllamaClient {
|
|
if config.Timeout == 0 {
|
|
config.Timeout = 120 * time.Second // Increase to 2 minutes for generation
|
|
}
|
|
|
|
return &OllamaClient{
|
|
httpClient: &http.Client{
|
|
Timeout: config.Timeout,
|
|
},
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (o *OllamaClient) TestConnection(ctx context.Context) error {
|
|
url := fmt.Sprintf("%s/api/tags", o.config.Host)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
resp, err := o.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to Ollama: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (o *OllamaClient) GenerateCompletion(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
|
|
model := req.Model
|
|
if model == "" {
|
|
model = o.config.DefaultModel
|
|
}
|
|
if model == "" {
|
|
model = "llama3"
|
|
}
|
|
|
|
ollamaReq := OllamaGenerateRequest{
|
|
Model: model,
|
|
Prompt: req.Prompt,
|
|
Stream: false,
|
|
Options: map[string]interface{}{
|
|
"temperature": req.Temperature,
|
|
},
|
|
}
|
|
|
|
if req.MaxTokens > 0 {
|
|
ollamaReq.Options["num_predict"] = req.MaxTokens
|
|
}
|
|
|
|
jsonData, err := json.Marshal(ollamaReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/api/generate", o.config.Host)
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := o.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
// Check if it's a timeout error
|
|
if ctx.Err() == context.DeadlineExceeded {
|
|
return nil, fmt.Errorf("Ollama request timed out after %.0f seconds. Try reducing the max tokens or using a smaller model", o.config.Timeout.Seconds())
|
|
}
|
|
return nil, fmt.Errorf("failed to call Ollama API: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var ollamaResp OllamaGenerateResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
return &CompletionResponse{
|
|
Text: ollamaResp.Response,
|
|
Model: ollamaResp.Model,
|
|
PromptTokens: ollamaResp.PromptEvalCount,
|
|
CompletionTokens: ollamaResp.EvalCount,
|
|
TotalTokens: ollamaResp.PromptEvalCount + ollamaResp.EvalCount,
|
|
}, nil
|
|
}
|
|
|
|
func (o *OllamaClient) ListModels(ctx context.Context) ([]string, error) {
|
|
url := fmt.Sprintf("%s/api/tags", o.config.Host)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
resp, err := o.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get models: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var tagsResp OllamaTagsResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&tagsResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode models response: %w", err)
|
|
}
|
|
|
|
var modelNames []string
|
|
for _, model := range tagsResp.Models {
|
|
modelNames = append(modelNames, model.Name)
|
|
}
|
|
|
|
return modelNames, nil
|
|
} |