Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
110 lines
2.3 KiB
Go
110 lines
2.3 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
type GroqClient struct {
|
|
client *openai.Client
|
|
config GroqConfig
|
|
}
|
|
|
|
// GroqConfig holds the settings used to construct a GroqClient.
//
//   - APIKey is the credential passed to the go-openai client.
//   - BaseURL overrides the API endpoint; NewGroqClient substitutes
//     "https://api.groq.com/openai/v1" when it is empty.
//   - DefaultModel is used by GenerateCompletion when a request leaves
//     Model empty; if it is empty as well, "llama-3.1-8b-instant" is used.
type GroqConfig struct {
	APIKey, BaseURL, DefaultModel string
}
|
|
|
|
// CompletionRequest describes a single-prompt chat completion.
//
// Model may be empty, in which case GenerateCompletion falls back to the
// client's DefaultModel and then to a built-in default. Prompt is sent as a
// single user-role message. Temperature and MaxTokens are copied unchanged
// into the underlying go-openai chat completion request.
//
// NOTE(review): go-openai appears to tag temperature and max_tokens with
// `omitempty`, so a zero Temperature or MaxTokens is likely omitted from the
// request and the server-side default applies instead; confirm against the
// vendored go-openai version.
type CompletionRequest struct {
	Model, Prompt string
	Temperature   float32
	MaxTokens     int
}
|
|
|
|
// CompletionResponse carries the text of the first returned choice along
// with the model name and token usage reported by the API.
type CompletionResponse struct {
	Text  string // content of the first choice's message
	Model string // model name as reported in the API response

	// Token usage counts as reported by the API.
	PromptTokens, CompletionTokens, TotalTokens int
}
|
|
|
|
func NewGroqClient(config GroqConfig) *GroqClient {
|
|
if config.BaseURL == "" {
|
|
config.BaseURL = "https://api.groq.com/openai/v1"
|
|
}
|
|
|
|
clientConfig := openai.DefaultConfig(config.APIKey)
|
|
clientConfig.BaseURL = config.BaseURL
|
|
|
|
return &GroqClient{
|
|
client: openai.NewClientWithConfig(clientConfig),
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (g *GroqClient) TestConnection(ctx context.Context) error {
|
|
// Test by listing available models
|
|
_, err := g.client.ListModels(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to Groq API: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (g *GroqClient) GenerateCompletion(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
|
|
model := req.Model
|
|
if model == "" {
|
|
model = g.config.DefaultModel
|
|
}
|
|
if model == "" {
|
|
model = "llama-3.1-8b-instant"
|
|
}
|
|
|
|
chatReq := openai.ChatCompletionRequest{
|
|
Model: model,
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: req.Prompt,
|
|
},
|
|
},
|
|
Temperature: req.Temperature,
|
|
MaxTokens: req.MaxTokens,
|
|
}
|
|
|
|
resp, err := g.client.CreateChatCompletion(ctx, chatReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate completion: %w", err)
|
|
}
|
|
|
|
if len(resp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no completion choices returned")
|
|
}
|
|
|
|
return &CompletionResponse{
|
|
Text: resp.Choices[0].Message.Content,
|
|
Model: resp.Model,
|
|
PromptTokens: resp.Usage.PromptTokens,
|
|
CompletionTokens: resp.Usage.CompletionTokens,
|
|
TotalTokens: resp.Usage.TotalTokens,
|
|
}, nil
|
|
}
|
|
|
|
func (g *GroqClient) ListModels(ctx context.Context) ([]string, error) {
|
|
models, err := g.client.ListModels(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list models: %w", err)
|
|
}
|
|
|
|
var modelNames []string
|
|
for _, model := range models.Models {
|
|
modelNames = append(modelNames, model.ID)
|
|
}
|
|
|
|
return modelNames, nil
|
|
} |