package ai

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)

// OllamaClient talks to an Ollama server over its HTTP API.
type OllamaClient struct {
	httpClient *http.Client
	config     OllamaConfig
}

// OllamaConfig holds connection settings for an Ollama server.
type OllamaConfig struct {
	Host         string
	DefaultModel string
	Timeout      time.Duration
}

// CompletionRequest is the input for generating completions.
type CompletionRequest struct {
	Model       string
	Prompt      string
	Temperature float32
	MaxTokens   int
}

// CompletionResponse is the output from generating completions.
type CompletionResponse struct {
	Text             string
	Model            string
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// OllamaGenerateRequest mirrors the request body of Ollama's /api/generate endpoint.
type OllamaGenerateRequest struct {
	Model   string                 `json:"model"`
	Prompt  string                 `json:"prompt"`
	Stream  bool                   `json:"stream"`
	Options map[string]interface{} `json:"options,omitempty"`
}

// OllamaGenerateResponse mirrors the non-streaming response of /api/generate.
type OllamaGenerateResponse struct {
	Model              string `json:"model"`
	Response           string `json:"response"`
	Done               bool   `json:"done"`
	Context            []int  `json:"context,omitempty"`
	TotalDuration      int64  `json:"total_duration,omitempty"`
	LoadDuration       int64  `json:"load_duration,omitempty"`
	PromptEvalCount    int    `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration int64  `json:"prompt_eval_duration,omitempty"`
	EvalCount          int    `json:"eval_count,omitempty"`
	EvalDuration       int64  `json:"eval_duration,omitempty"`
}

// OllamaTagsResponse mirrors the response of Ollama's /api/tags endpoint.
type OllamaTagsResponse struct {
	Models []struct {
		Name   string `json:"name"`
		Model  string `json:"model"`
		Size   int64  `json:"size"`
		Digest string `json:"digest"`
	} `json:"models"`
}

// NewOllamaClient builds a client from the given config, applying a default
// timeout when none is set.
func NewOllamaClient(config OllamaConfig) *OllamaClient {
	if config.Timeout == 0 {
		config.Timeout = 120 * time.Second // Increase to 2 minutes for generation
	}
	return &OllamaClient{
		httpClient: &http.Client{
			Timeout: config.Timeout,
		},
		config: config,
	}
}

// TestConnection verifies the server is reachable by querying /api/tags.
func (o *OllamaClient) TestConnection(ctx context.Context) error {
	url := fmt.Sprintf("%s/api/tags", o.config.Host)

	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := o.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to connect to Ollama: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
	}

	return nil
}

// GenerateCompletion sends a non-streaming generate request and maps the
// result into a CompletionResponse.
func (o *OllamaClient) GenerateCompletion(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
	// Resolve the model: request value, then configured default, then "llama3".
	model := req.Model
	if model == "" {
		model = o.config.DefaultModel
	}
	if model == "" {
		model = "llama3"
	}

	ollamaReq := OllamaGenerateRequest{
		Model:  model,
		Prompt: req.Prompt,
		Stream: false,
		Options: map[string]interface{}{
			"temperature": req.Temperature,
		},
	}
	if req.MaxTokens > 0 {
		ollamaReq.Options["num_predict"] = req.MaxTokens
	}

	jsonData, err := json.Marshal(ollamaReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	url := fmt.Sprintf("%s/api/generate", o.config.Host)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := o.httpClient.Do(httpReq)
	if err != nil {
		// Check if it's a timeout error
		if ctx.Err() == context.DeadlineExceeded {
			return nil, fmt.Errorf("Ollama request timed out after %.0f seconds. Try reducing the max tokens or using a smaller model", o.config.Timeout.Seconds())
		}
		return nil, fmt.Errorf("failed to call Ollama API: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
	}

	var ollamaResp OllamaGenerateResponse
	if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	return &CompletionResponse{
		Text:             ollamaResp.Response,
		Model:            ollamaResp.Model,
		PromptTokens:     ollamaResp.PromptEvalCount,
		CompletionTokens: ollamaResp.EvalCount,
		TotalTokens:      ollamaResp.PromptEvalCount + ollamaResp.EvalCount,
	}, nil
}

// ListModels returns the names of all models available on the server.
func (o *OllamaClient) ListModels(ctx context.Context) ([]string, error) {
	url := fmt.Sprintf("%s/api/tags", o.config.Host)

	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := o.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to get models: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
	}

	var tagsResp OllamaTagsResponse
	if err := json.NewDecoder(resp.Body).Decode(&tagsResp); err != nil {
		return nil, fmt.Errorf("failed to decode models response: %w", err)
	}

	var modelNames []string
	for _, model := range tagsResp.Models {
		modelNames = append(modelNames, model.Name)
	}

	return modelNames, nil
}
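
// exampleCompletion is a usage sketch added for illustration; it is not part
// of the original client. It assumes an Ollama server reachable at
// http://localhost:11434 with a pulled "llama3" model; adjust OllamaConfig
// and the prompt for your environment.
func exampleCompletion(ctx context.Context) (string, error) {
	// Build a client; Timeout is left zero so NewOllamaClient applies its
	// 2-minute default.
	client := NewOllamaClient(OllamaConfig{
		Host:         "http://localhost:11434",
		DefaultModel: "llama3",
	})

	// Fail fast if the server is unreachable.
	if err := client.TestConnection(ctx); err != nil {
		return "", err
	}

	// Request a short, low-temperature completion.
	resp, err := client.GenerateCompletion(ctx, CompletionRequest{
		Prompt:      "Explain in one sentence what an HTTP client does.",
		Temperature: 0.2,
		MaxTokens:   128,
	})
	if err != nil {
		return "", err
	}
	return resp.Text, nil
}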