first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
110
plugin/ai/groq/groq.go
Normal file
110
plugin/ai/groq/groq.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package ai
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"

	"github.com/sashabaranov/go-openai"
)
|
||||
|
||||
// GroqClient talks to the Groq API through its OpenAI-compatible interface.
type GroqClient struct {
	client *openai.Client // OpenAI-compatible client pointed at the Groq endpoint
	config GroqConfig     // configuration used to build this client
}

// GroqConfig holds the settings needed to reach the Groq API.
type GroqConfig struct {
	APIKey       string // bearer token used for authentication
	BaseURL      string // API root; defaults to the public Groq endpoint when empty
	DefaultModel string // model used when a request does not name one
}

// CompletionRequest describes a single text-completion request.
type CompletionRequest struct {
	Model       string  // model to use; falls back to DefaultModel, then a built-in default
	Prompt      string  // user prompt, sent as a single chat message
	Temperature float32 // sampling temperature, forwarded to the API
	MaxTokens   int     // maximum tokens to generate; 0 leaves the API default
}

// CompletionResponse carries the generated text and token accounting.
type CompletionResponse struct {
	Text             string // content of the first returned choice
	Model            string // model the API reports having used
	PromptTokens     int    // tokens consumed by the prompt
	CompletionTokens int    // tokens generated in the completion
	TotalTokens      int    // prompt + completion tokens
}
|
||||
|
||||
func NewGroqClient(config GroqConfig) *GroqClient {
|
||||
if config.BaseURL == "" {
|
||||
config.BaseURL = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
|
||||
clientConfig := openai.DefaultConfig(config.APIKey)
|
||||
clientConfig.BaseURL = config.BaseURL
|
||||
|
||||
return &GroqClient{
|
||||
client: openai.NewClientWithConfig(clientConfig),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GroqClient) TestConnection(ctx context.Context) error {
|
||||
// Test by listing available models
|
||||
_, err := g.client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Groq API: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GroqClient) GenerateCompletion(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = g.config.DefaultModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "llama-3.1-8b-instant"
|
||||
}
|
||||
|
||||
chatReq := openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: req.Prompt,
|
||||
},
|
||||
},
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
}
|
||||
|
||||
resp, err := g.client.CreateChatCompletion(ctx, chatReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate completion: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("no completion choices returned")
|
||||
}
|
||||
|
||||
return &CompletionResponse{
|
||||
Text: resp.Choices[0].Message.Content,
|
||||
Model: resp.Model,
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *GroqClient) ListModels(ctx context.Context) ([]string, error) {
|
||||
models, err := g.client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list models: %w", err)
|
||||
}
|
||||
|
||||
var modelNames []string
|
||||
for _, model := range models.Models {
|
||||
modelNames = append(modelNames, model.ID)
|
||||
}
|
||||
|
||||
return modelNames, nil
|
||||
}
|
||||
199
plugin/ai/ollama/ollama.go
Normal file
199
plugin/ai/ollama/ollama.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package ai
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"time"
)
|
||||
|
||||
// OllamaClient talks to a local or remote Ollama server over its HTTP API.
type OllamaClient struct {
	httpClient *http.Client // shared client carrying the configured timeout
	config     OllamaConfig // configuration used to build this client
}

// OllamaConfig holds the settings needed to reach an Ollama server.
type OllamaConfig struct {
	Host         string        // base URL of the server, e.g. "http://localhost:11434"
	DefaultModel string        // model used when a request does not name one
	Timeout      time.Duration // per-request timeout; 0 defaults to 2 minutes
}

// CompletionRequest is the input for generating completions.
type CompletionRequest struct {
	Model       string  // model to use; falls back to DefaultModel, then "llama3"
	Prompt      string  // text prompt forwarded to /api/generate
	Temperature float32 // sampling temperature, passed via Ollama options
	MaxTokens   int     // maps to Ollama's "num_predict" option when > 0
}

// CompletionResponse is the output from generating completions.
type CompletionResponse struct {
	Text             string // generated text
	Model            string // model the server reports having used
	PromptTokens     int    // prompt_eval_count from the server
	CompletionTokens int    // eval_count from the server
	TotalTokens      int    // sum of the two counts above
}
|
||||
|
||||
// OllamaGenerateRequest is the JSON payload for POST /api/generate.
type OllamaGenerateRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	// Stream is always false here: a single JSON object is expected back
	// rather than a stream of chunks.
	Stream  bool                   `json:"stream"`
	Options map[string]interface{} `json:"options,omitempty"`
}

// OllamaGenerateResponse mirrors the JSON returned by /api/generate.
// Durations are nanoseconds as reported by the server.
type OllamaGenerateResponse struct {
	Model              string `json:"model"`
	Response           string `json:"response"`
	Done               bool   `json:"done"`
	Context            []int  `json:"context,omitempty"`
	TotalDuration      int64  `json:"total_duration,omitempty"`
	LoadDuration       int64  `json:"load_duration,omitempty"`
	PromptEvalCount    int    `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration int64  `json:"prompt_eval_duration,omitempty"`
	EvalCount          int    `json:"eval_count,omitempty"`
	EvalDuration       int64  `json:"eval_duration,omitempty"`
}

// OllamaTagsResponse mirrors the JSON returned by /api/tags, which lists
// the models installed on the server.
type OllamaTagsResponse struct {
	Models []struct {
		Name   string `json:"name"`
		Model  string `json:"model"`
		Size   int64  `json:"size"`
		Digest string `json:"digest"`
	} `json:"models"`
}
|
||||
|
||||
func NewOllamaClient(config OllamaConfig) *OllamaClient {
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 120 * time.Second // Increase to 2 minutes for generation
|
||||
}
|
||||
|
||||
return &OllamaClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: config.Timeout,
|
||||
},
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OllamaClient) TestConnection(ctx context.Context) error {
|
||||
url := fmt.Sprintf("%s/api/tags", o.config.Host)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Ollama: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OllamaClient) GenerateCompletion(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = o.config.DefaultModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "llama3"
|
||||
}
|
||||
|
||||
ollamaReq := OllamaGenerateRequest{
|
||||
Model: model,
|
||||
Prompt: req.Prompt,
|
||||
Stream: false,
|
||||
Options: map[string]interface{}{
|
||||
"temperature": req.Temperature,
|
||||
},
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
ollamaReq.Options["num_predict"] = req.MaxTokens
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(ollamaReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/generate", o.config.Host)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
// Check if it's a timeout error
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return nil, fmt.Errorf("Ollama request timed out after %.0f seconds. Try reducing the max tokens or using a smaller model", o.config.Timeout.Seconds())
|
||||
}
|
||||
return nil, fmt.Errorf("failed to call Ollama API: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var ollamaResp OllamaGenerateResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return &CompletionResponse{
|
||||
Text: ollamaResp.Response,
|
||||
Model: ollamaResp.Model,
|
||||
PromptTokens: ollamaResp.PromptEvalCount,
|
||||
CompletionTokens: ollamaResp.EvalCount,
|
||||
TotalTokens: ollamaResp.PromptEvalCount + ollamaResp.EvalCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *OllamaClient) ListModels(ctx context.Context) ([]string, error) {
|
||||
url := fmt.Sprintf("%s/api/tags", o.config.Host)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get models: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tagsResp OllamaTagsResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tagsResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode models response: %w", err)
|
||||
}
|
||||
|
||||
var modelNames []string
|
||||
for _, model := range tagsResp.Models {
|
||||
modelNames = append(modelNames, model.Name)
|
||||
}
|
||||
|
||||
return modelNames, nil
|
||||
}
|
||||
1
plugin/cron/README.md
Normal file
1
plugin/cron/README.md
Normal file
@@ -0,0 +1 @@
|
||||
Forked from https://github.com/robfig/cron.
|
||||
96
plugin/cron/chain.go
Normal file
96
plugin/cron/chain.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JobWrapper decorates the given Job with some behavior.
type JobWrapper func(Job) Job

// Chain is a sequence of JobWrappers that decorates submitted jobs with
// cross-cutting behaviors like logging or synchronization.
type Chain struct {
	wrappers []JobWrapper // applied by Then so that the first wrapper is outermost
}
|
||||
|
||||
// NewChain returns a Chain consisting of the given JobWrappers.
|
||||
func NewChain(c ...JobWrapper) Chain {
|
||||
return Chain{c}
|
||||
}
|
||||
|
||||
// Then decorates the given job with all JobWrappers in the chain.
|
||||
//
|
||||
// This:
|
||||
//
|
||||
// NewChain(m1, m2, m3).Then(job)
|
||||
//
|
||||
// is equivalent to:
|
||||
//
|
||||
// m1(m2(m3(job)))
|
||||
func (c Chain) Then(j Job) Job {
|
||||
for i := range c.wrappers {
|
||||
j = c.wrappers[len(c.wrappers)-i-1](j)
|
||||
}
|
||||
return j
|
||||
}
|
||||
|
||||
// Recover panics in wrapped jobs and log them with the provided logger.
func Recover(logger Logger) JobWrapper {
	return func(j Job) Job {
		return FuncJob(func() {
			defer func() {
				if r := recover(); r != nil {
					// Capture at most 64 KiB of this goroutine's stack for the log.
					const size = 64 << 10
					buf := make([]byte, size)
					buf = buf[:runtime.Stack(buf, false)]
					// Keep the panic value if it is already an error;
					// otherwise wrap its printed form.
					err, ok := r.(error)
					if !ok {
						err = errors.New("panic: " + fmt.Sprint(r))
					}
					logger.Error(err, "panic", "stack", "...\n"+string(buf))
				}
			}()
			j.Run()
		})
	}
}
|
||||
|
||||
// DelayIfStillRunning serializes jobs, delaying subsequent runs until the
// previous one is complete. Jobs running after a delay of more than a minute
// have the delay logged at Info.
func DelayIfStillRunning(logger Logger) JobWrapper {
	return func(j Job) Job {
		// One mutex per wrapped job serializes its invocations.
		var mu sync.Mutex
		return FuncJob(func() {
			// start is taken before acquiring the lock on purpose: the
			// measured duration is the time spent waiting on a previous run.
			start := time.Now()
			mu.Lock()
			defer mu.Unlock()
			if dur := time.Since(start); dur > time.Minute {
				logger.Info("delay", "duration", dur)
			}
			j.Run()
		})
	}
}
|
||||
|
||||
// SkipIfStillRunning skips an invocation of the Job if a previous invocation is
// still running. It logs skips to the given logger at Info level.
func SkipIfStillRunning(logger Logger) JobWrapper {
	return func(j Job) Job {
		// A one-slot channel acts as a non-blocking semaphore: the token is
		// held while the job runs and put back when it finishes.
		var ch = make(chan struct{}, 1)
		ch <- struct{}{}
		return FuncJob(func() {
			select {
			case v := <-ch:
				defer func() { ch <- v }()
				j.Run()
			default:
				// Token unavailable: a previous invocation is still running.
				logger.Info("skip")
			}
		})
	}
}
|
||||
239
plugin/cron/chain_test.go
Normal file
239
plugin/cron/chain_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// appendingJob returns a Job that appends value to slice; a mutex makes the
// append safe when jobs run concurrently.
func appendingJob(slice *[]int, value int) Job {
	var m sync.Mutex
	return FuncJob(func() {
		m.Lock()
		*slice = append(*slice, value)
		m.Unlock()
	})
}

// appendingWrapper returns a JobWrapper that appends value before running the
// wrapped job, so the call order of a Chain can be observed.
func appendingWrapper(slice *[]int, value int) JobWrapper {
	return func(j Job) Job {
		return FuncJob(func() {
			appendingJob(slice, value).Run()
			j.Run()
		})
	}
}
|
||||
|
||||
// TestChain verifies that wrappers run outermost-first: NewChain(m1, m2, m3)
// applied to a job behaves like m1(m2(m3(job))).
func TestChain(t *testing.T) {
	var nums []int
	var (
		append1 = appendingWrapper(&nums, 1)
		append2 = appendingWrapper(&nums, 2)
		append3 = appendingWrapper(&nums, 3)
		append4 = appendingJob(&nums, 4)
	)
	NewChain(append1, append2, append3).Then(append4).Run()
	if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) {
		t.Error("unexpected order of calls:", nums)
	}
}
|
||||
|
||||
// TestChainRecover verifies that a panic propagates out of a bare chain but is
// swallowed (and logged) when the Recover wrapper is installed.
func TestChainRecover(t *testing.T) {
	panickingJob := FuncJob(func() {
		panic("panickingJob panics")
	})

	t.Run("panic exits job by default", func(*testing.T) {
		defer func() {
			if err := recover(); err == nil {
				t.Errorf("panic expected, but none received")
			}
		}()
		NewChain().Then(panickingJob).
			Run()
	})

	t.Run("Recovering JobWrapper recovers", func(*testing.T) {
		NewChain(Recover(PrintfLogger(log.New(io.Discard, "", 0)))).
			Then(panickingJob).
			Run()
	})

	// NOTE(review): despite its name, this subtest is identical to the one
	// above — it does not actually compose the *IfStillRunning wrappers.
	t.Run("composed with the *IfStillRunning wrappers", func(*testing.T) {
		NewChain(Recover(PrintfLogger(log.New(io.Discard, "", 0)))).
			Then(panickingJob).
			Run()
	})
}
|
||||
|
||||
// countJob counts how many times it has been started and completed, sleeping
// for delay in between to simulate a long-running job.
type countJob struct {
	m       sync.Mutex // guards started and done
	started int
	done    int
	delay   time.Duration
}

// Run records a start, sleeps for the configured delay, then records
// completion. The sleep itself is not under the lock, so concurrent runs
// can overlap.
func (j *countJob) Run() {
	j.m.Lock()
	j.started++
	j.m.Unlock()
	time.Sleep(j.delay)
	j.m.Lock()
	j.done++
	j.m.Unlock()
}

// Started returns the number of times Run has begun.
// (Idiom fix: the original registered `defer Unlock` before calling Lock,
// which works but reads as if the mutex were released first.)
func (j *countJob) Started() int {
	j.m.Lock()
	defer j.m.Unlock()
	return j.started
}

// Done returns the number of times Run has completed.
func (j *countJob) Done() int {
	j.m.Lock()
	defer j.m.Unlock()
	return j.done
}
|
||||
|
||||
// TestChainDelayIfStillRunning exercises DelayIfStillRunning: an immediate
// run, back-to-back runs, and a second run that must wait for the first.
// NOTE(review): these subtests rely on real sleeps and may flake on slow CI.
func TestChainDelayIfStillRunning(t *testing.T) {
	t.Run("runs immediately", func(*testing.T) {
		var j countJob
		wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
		go wrappedJob.Run()
		time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
		if c := j.Done(); c != 1 {
			t.Errorf("expected job run once, immediately, got %d", c)
		}
	})

	t.Run("second run immediate if first done", func(*testing.T) {
		var j countJob
		wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
		go func() {
			go wrappedJob.Run()
			time.Sleep(time.Millisecond)
			go wrappedJob.Run()
		}()
		time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
		if c := j.Done(); c != 2 {
			t.Errorf("expected job run twice, immediately, got %d", c)
		}
	})

	t.Run("second run delayed if first not done", func(*testing.T) {
		var j countJob
		j.delay = 10 * time.Millisecond
		wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
		go func() {
			go wrappedJob.Run()
			time.Sleep(time.Millisecond)
			go wrappedJob.Run()
		}()

		// After 5ms, the first job is still in progress, and the second job was
		// run but should be waiting for it to finish.
		time.Sleep(5 * time.Millisecond)
		started, done := j.Started(), j.Done()
		if started != 1 || done != 0 {
			t.Error("expected first job started, but not finished, got", started, done)
		}

		// Verify that the second job completes.
		time.Sleep(25 * time.Millisecond)
		started, done = j.Started(), j.Done()
		if started != 2 || done != 2 {
			t.Error("expected both jobs done, got", started, done)
		}
	})
}
|
||||
|
||||
// TestChainSkipIfStillRunning exercises SkipIfStillRunning: immediate runs,
// skipped overlapping runs, rapid-fire drops, and independence of distinct
// wrapped jobs. NOTE(review): timing-sensitive; may flake on slow CI.
func TestChainSkipIfStillRunning(t *testing.T) {
	t.Run("runs immediately", func(*testing.T) {
		var j countJob
		wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
		go wrappedJob.Run()
		time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
		if c := j.Done(); c != 1 {
			t.Errorf("expected job run once, immediately, got %d", c)
		}
	})

	t.Run("second run immediate if first done", func(*testing.T) {
		var j countJob
		wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
		go func() {
			go wrappedJob.Run()
			time.Sleep(time.Millisecond)
			go wrappedJob.Run()
		}()
		time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
		if c := j.Done(); c != 2 {
			t.Errorf("expected job run twice, immediately, got %d", c)
		}
	})

	t.Run("second run skipped if first not done", func(*testing.T) {
		var j countJob
		j.delay = 10 * time.Millisecond
		wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
		go func() {
			go wrappedJob.Run()
			time.Sleep(time.Millisecond)
			go wrappedJob.Run()
		}()

		// After 5ms, the first job is still in progress, and the second job was
		// already skipped.
		time.Sleep(5 * time.Millisecond)
		started, done := j.Started(), j.Done()
		if started != 1 || done != 0 {
			t.Error("expected first job started, but not finished, got", started, done)
		}

		// Verify that the first job completes and second does not run.
		time.Sleep(25 * time.Millisecond)
		started, done = j.Started(), j.Done()
		if started != 1 || done != 1 {
			t.Error("expected second job skipped, got", started, done)
		}
	})

	t.Run("skip 10 jobs on rapid fire", func(*testing.T) {
		var j countJob
		j.delay = 10 * time.Millisecond
		wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
		for i := 0; i < 11; i++ {
			go wrappedJob.Run()
		}
		time.Sleep(200 * time.Millisecond)
		done := j.Done()
		if done != 1 {
			t.Error("expected 1 jobs executed, 10 jobs dropped, got", done)
		}
	})

	t.Run("different jobs independent", func(*testing.T) {
		var j1, j2 countJob
		j1.delay = 10 * time.Millisecond
		j2.delay = 10 * time.Millisecond
		// Each Then() call creates a fresh semaphore, so the two wrapped
		// jobs must not block or skip each other.
		chain := NewChain(SkipIfStillRunning(DiscardLogger))
		wrappedJob1 := chain.Then(&j1)
		wrappedJob2 := chain.Then(&j2)
		for i := 0; i < 11; i++ {
			go wrappedJob1.Run()
			go wrappedJob2.Run()
		}
		time.Sleep(100 * time.Millisecond)
		var (
			done1 = j1.Done()
			done2 = j2.Done()
		)
		if done1 != 1 || done2 != 1 {
			t.Error("expected both jobs executed once, got", done1, "and", done2)
		}
	})
}
|
||||
27
plugin/cron/constantdelay.go
Normal file
27
plugin/cron/constantdelay.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package cron
|
||||
|
||||
import "time"
|
||||
|
||||
// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes".
|
||||
// It does not support jobs more frequent than once a second.
|
||||
type ConstantDelaySchedule struct {
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
// Every returns a crontab Schedule that activates once every duration.
|
||||
// Delays of less than a second are not supported (will round up to 1 second).
|
||||
// Any fields less than a Second are truncated.
|
||||
func Every(duration time.Duration) ConstantDelaySchedule {
|
||||
if duration < time.Second {
|
||||
duration = time.Second
|
||||
}
|
||||
return ConstantDelaySchedule{
|
||||
Delay: duration - time.Duration(duration.Nanoseconds())%time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Next returns the next time this should be run.
|
||||
// This rounds so that the next activation time will be on the second.
|
||||
func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time {
|
||||
return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond)
|
||||
}
|
||||
55
plugin/cron/constantdelay_test.go
Normal file
55
plugin/cron/constantdelay_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestConstantDelayNext is a table-driven check of Every().Next() covering
// boundary wrap-around (hour/day/month/year) and second-rounding behavior.
// Times are parsed by the shared getTime helper defined elsewhere in the package.
func TestConstantDelayNext(t *testing.T) {
	tests := []struct {
		time     string
		delay    time.Duration
		expected string
	}{
		// Simple cases
		{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
		{"Mon Jul 9 14:59 2012", 15 * time.Minute, "Mon Jul 9 15:14 2012"},
		{"Mon Jul 9 14:59:59 2012", 15 * time.Minute, "Mon Jul 9 15:14:59 2012"},

		// Wrap around hours
		{"Mon Jul 9 15:45 2012", 35 * time.Minute, "Mon Jul 9 16:20 2012"},

		// Wrap around days
		{"Mon Jul 9 23:46 2012", 14 * time.Minute, "Tue Jul 10 00:00 2012"},
		{"Mon Jul 9 23:45 2012", 35 * time.Minute, "Tue Jul 10 00:20 2012"},
		{"Mon Jul 9 23:35:51 2012", 44*time.Minute + 24*time.Second, "Tue Jul 10 00:20:15 2012"},
		{"Mon Jul 9 23:35:51 2012", 25*time.Hour + 44*time.Minute + 24*time.Second, "Thu Jul 11 01:20:15 2012"},

		// Wrap around months
		{"Mon Jul 9 23:35 2012", 91*24*time.Hour + 25*time.Minute, "Thu Oct 9 00:00 2012"},

		// Wrap around minute, hour, day, month, and year
		{"Mon Dec 31 23:59:45 2012", 15 * time.Second, "Tue Jan 1 00:00:00 2013"},

		// Round to nearest second on the delay
		{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},

		// Round up to 1 second if the duration is less.
		{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:01 2012"},

		// Round to nearest second when calculating the next time.
		{"Mon Jul 9 14:45:00.005 2012", 15 * time.Minute, "Mon Jul 9 15:00 2012"},

		// Round to nearest second for both.
		{"Mon Jul 9 14:45:00.005 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
	}

	for _, c := range tests {
		actual := Every(c.delay).Next(getTime(c.time))
		expected := getTime(c.expected)
		if actual != expected {
			t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.delay, expected, actual)
		}
	}
}
|
||||
353
plugin/cron/cron.go
Normal file
353
plugin/cron/cron.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cron keeps track of any number of entries, invoking the associated func as
// specified by the schedule. It may be started, stopped, and the entries may
// be inspected while running.
type Cron struct {
	entries   []*Entry          // known entries; owned by the scheduler goroutine while running
	chain     Chain             // wrappers applied to every submitted job
	stop      chan struct{}     // signals the scheduler to shut down
	add       chan *Entry       // hands new entries to a running scheduler
	remove    chan EntryID      // hands removal requests to a running scheduler
	snapshot  chan chan []Entry // requests an entries snapshot from the scheduler
	running   bool              // guarded by runningMu
	logger    Logger
	runningMu sync.Mutex // protects running and, while stopped, entries
	location  *time.Location
	parser    ScheduleParser
	nextID    EntryID        // next entry ID to hand out; incremented under runningMu
	jobWaiter sync.WaitGroup // tracks in-flight jobs
}
|
||||
|
||||
// ScheduleParser is an interface for schedule spec parsers that return a Schedule.
type ScheduleParser interface {
	Parse(spec string) (Schedule, error)
}

// Job is an interface for submitted cron jobs.
type Job interface {
	Run()
}

// Schedule describes a job's duty cycle.
type Schedule interface {
	// Next returns the next activation time, later than the given time.
	// Next is invoked initially, and then each time the job is run.
	Next(time.Time) time.Time
}

// EntryID identifies an entry within a Cron instance.
// IDs start at 1, so the zero value marks an invalid/missing entry.
type EntryID int

// Entry consists of a schedule and the func to execute on that schedule.
type Entry struct {
	// ID is the cron-assigned ID of this entry, which may be used to look up a
	// snapshot or remove it.
	ID EntryID

	// Schedule on which this job should be run.
	Schedule Schedule

	// Next time the job will run, or the zero time if Cron has not been
	// started or this entry's schedule is unsatisfiable.
	Next time.Time

	// Prev is the last time this job was run, or the zero time if never.
	Prev time.Time

	// WrappedJob is the thing to run when the Schedule is activated.
	WrappedJob Job

	// Job is the thing that was submitted to cron.
	// It is kept around so that user code that needs to get at the job later,
	// e.g. via Entries() can do so.
	Job Job
}
|
||||
|
||||
// Valid returns true if this is not the zero entry.
func (e Entry) Valid() bool { return e.ID != 0 }

// byTime is a wrapper for sorting the entry array by time
// (with zero time at the end). It implements sort.Interface.
type byTime []*Entry

func (s byTime) Len() int      { return len(s) }
func (s byTime) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s byTime) Less(i, j int) bool {
	// Two zero times should return false.
	// Otherwise, zero is "greater" than any other time.
	// (To sort it at the end of the list.)
	if s[i].Next.IsZero() {
		return false
	}
	if s[j].Next.IsZero() {
		return true
	}
	return s[i].Next.Before(s[j].Next)
}
|
||||
|
||||
// New returns a new Cron job runner, modified by the given options.
//
// Available Settings
//
//	Time Zone
//	  Description: The time zone in which schedules are interpreted
//	  Default:     time.Local
//
//	Parser
//	  Description: Parser converts cron spec strings into cron.Schedules.
//	  Default:     Accepts this spec: https://en.wikipedia.org/wiki/Cron
//
//	Chain
//	  Description: Wrap submitted jobs to customize behavior.
//	  Default:     A chain that recovers panics and logs them to stderr.
//
// See "cron.With*" to modify the default behavior.
func New(opts ...Option) *Cron {
	c := &Cron{
		entries:   nil,
		chain:     NewChain(),
		add:       make(chan *Entry),
		stop:      make(chan struct{}),
		snapshot:  make(chan chan []Entry),
		remove:    make(chan EntryID),
		running:   false,
		runningMu: sync.Mutex{},
		logger:    DefaultLogger,
		location:  time.Local,
		parser:    standardParser,
	}
	// Options mutate the defaults above in order.
	for _, opt := range opts {
		opt(c)
	}
	return c
}
|
||||
|
||||
// FuncJob is a wrapper that turns a func() into a cron.Job.
type FuncJob func()

// Run invokes the wrapped function, satisfying the Job interface.
func (f FuncJob) Run() { f() }

// AddFunc adds a func to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) {
	return c.AddJob(spec, FuncJob(cmd))
}
|
||||
|
||||
// AddJob adds a Job to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddJob(spec string, cmd Job) (EntryID, error) {
	schedule, err := c.parser.Parse(spec)
	if err != nil {
		return 0, err
	}
	return c.Schedule(schedule, cmd), nil
}
|
||||
|
||||
// Schedule adds a Job to the Cron to be run on the given schedule.
// The job is wrapped with the configured Chain.
func (c *Cron) Schedule(schedule Schedule, cmd Job) EntryID {
	c.runningMu.Lock()
	defer c.runningMu.Unlock()
	c.nextID++
	entry := &Entry{
		ID:         c.nextID,
		Schedule:   schedule,
		WrappedJob: c.chain.Then(cmd),
		Job:        cmd,
	}
	if !c.running {
		// Scheduler not running: safe to mutate entries directly.
		c.entries = append(c.entries, entry)
	} else {
		// Hand the entry to the scheduler goroutine to avoid a data race.
		c.add <- entry
	}
	return entry.ID
}
|
||||
|
||||
// Entries returns a snapshot of the cron entries.
func (c *Cron) Entries() []Entry {
	c.runningMu.Lock()
	defer c.runningMu.Unlock()
	if c.running {
		// Ask the scheduler goroutine for a copy rather than reading
		// entries directly, which it owns while running.
		replyChan := make(chan []Entry, 1)
		c.snapshot <- replyChan
		return <-replyChan
	}
	return c.entrySnapshot()
}
|
||||
|
||||
// Location gets the time zone location.
func (c *Cron) Location() *time.Location {
	return c.location
}

// Entry returns a snapshot of the given entry, or a zero Entry (for which
// Valid() reports false) if it couldn't be found.
func (c *Cron) Entry(id EntryID) Entry {
	for _, entry := range c.Entries() {
		if id == entry.ID {
			return entry
		}
	}
	return Entry{}
}
|
||||
|
||||
// Remove an entry from being run in the future.
func (c *Cron) Remove(id EntryID) {
	c.runningMu.Lock()
	defer c.runningMu.Unlock()
	if c.running {
		// Let the scheduler goroutine delete it to avoid racing on entries.
		c.remove <- id
	} else {
		c.removeEntry(id)
	}
}
|
||||
|
||||
// Start the cron scheduler in its own goroutine, or no-op if already started.
func (c *Cron) Start() {
	c.runningMu.Lock()
	defer c.runningMu.Unlock()
	if c.running {
		return
	}
	c.running = true
	go c.runScheduler()
}

// Run the cron scheduler, or no-op if already running.
// Unlike Start, this blocks the calling goroutine.
func (c *Cron) Run() {
	c.runningMu.Lock()
	if c.running {
		c.runningMu.Unlock()
		return
	}
	c.running = true
	// Release the lock before blocking in the scheduler so other methods
	// (Entries, Remove, ...) can proceed while it runs.
	c.runningMu.Unlock()
	c.runScheduler()
}
|
||||
|
||||
// runScheduler runs the scheduler.. this is private just due to the need to synchronize
// access to the 'running' state variable.
//
// It owns c.entries while running: all mutations arrive via the add/remove/
// snapshot/stop channels, so no mutex is needed inside this loop.
func (c *Cron) runScheduler() {
	c.logger.Info("start")

	// Figure out the next activation times for each entry.
	now := c.now()
	for _, entry := range c.entries {
		entry.Next = entry.Schedule.Next(now)
		c.logger.Info("schedule", "now", now, "entry", entry.ID, "next", entry.Next)
	}

	for {
		// Determine the next entry to run.
		// Sort so the soonest activation time comes first (byTime defined elsewhere).
		sort.Sort(byTime(c.entries))

		var timer *time.Timer
		if len(c.entries) == 0 || c.entries[0].Next.IsZero() {
			// If there are no entries yet, just sleep - it still handles new entries
			// and stop requests.
			timer = time.NewTimer(100000 * time.Hour)
		} else {
			timer = time.NewTimer(c.entries[0].Next.Sub(now))
		}

		for {
			select {
			case now = <-timer.C:
				now = now.In(c.location)
				c.logger.Info("wake", "now", now)

				// Run every entry whose next time was less than now
				// (the list is sorted, so stop at the first future/zero entry).
				for _, e := range c.entries {
					if e.Next.After(now) || e.Next.IsZero() {
						break
					}
					c.startJob(e.WrappedJob)
					e.Prev = e.Next
					e.Next = e.Schedule.Next(now)
					c.logger.Info("run", "now", now, "entry", e.ID, "next", e.Next)
				}

			case newEntry := <-c.add:
				// New entry while running: stop the timer so the outer loop
				// re-sorts and re-arms with this entry included.
				timer.Stop()
				now = c.now()
				newEntry.Next = newEntry.Schedule.Next(now)
				c.entries = append(c.entries, newEntry)
				c.logger.Info("added", "now", now, "entry", newEntry.ID, "next", newEntry.Next)

			case replyChan := <-c.snapshot:
				replyChan <- c.entrySnapshot()
				// A snapshot changes nothing; keep waiting on the same timer
				// instead of falling through to the re-sort below.
				continue

			case <-c.stop:
				timer.Stop()
				c.logger.Info("stop")
				return

			case id := <-c.remove:
				timer.Stop()
				now = c.now()
				c.removeEntry(id)
				c.logger.Info("removed", "entry", id)
			}

			// Re-enter the outer loop to re-sort entries and re-arm the timer.
			break
		}
	}
}
|
||||
|
||||
// startJob runs the given job in a new goroutine.
|
||||
func (c *Cron) startJob(j Job) {
|
||||
c.jobWaiter.Go(func() {
|
||||
j.Run()
|
||||
})
|
||||
}
|
||||
|
||||
// now returns current time in c location.
|
||||
func (c *Cron) now() time.Time {
|
||||
return time.Now().In(c.location)
|
||||
}
|
||||
|
||||
// Stop stops the cron scheduler if it is running; otherwise it does nothing.
|
||||
// A context is returned so the caller can wait for running jobs to complete.
|
||||
func (c *Cron) Stop() context.Context {
|
||||
c.runningMu.Lock()
|
||||
defer c.runningMu.Unlock()
|
||||
if c.running {
|
||||
c.stop <- struct{}{}
|
||||
c.running = false
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
c.jobWaiter.Wait()
|
||||
cancel()
|
||||
}()
|
||||
return ctx
|
||||
}
|
||||
|
||||
// entrySnapshot returns a copy of the current cron entry list.
|
||||
func (c *Cron) entrySnapshot() []Entry {
|
||||
var entries = make([]Entry, len(c.entries))
|
||||
for i, e := range c.entries {
|
||||
entries[i] = *e
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func (c *Cron) removeEntry(id EntryID) {
|
||||
var entries []*Entry
|
||||
for _, e := range c.entries {
|
||||
if e.ID != id {
|
||||
entries = append(entries, e)
|
||||
}
|
||||
}
|
||||
c.entries = entries
|
||||
}
|
||||
702
plugin/cron/cron_test.go
Normal file
702
plugin/cron/cron_test.go
Normal file
@@ -0,0 +1,702 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Many tests schedule a job for every second, and then wait at most a second
// for it to run. This amount is just slightly larger than 1 second to
// compensate for a few milliseconds of runtime; the 50ms of slack absorbs
// scheduler and goroutine startup jitter.
const OneSecond = 1*time.Second + 50*time.Millisecond
|
||||
|
||||
// syncWriter is a bytes.Buffer guarded by a mutex so it can be written by
// the cron logger goroutine while test assertions read it concurrently.
type syncWriter struct {
	wr bytes.Buffer
	m  sync.Mutex
}
|
||||
|
||||
func (sw *syncWriter) Write(data []byte) (n int, err error) {
|
||||
sw.m.Lock()
|
||||
n, err = sw.wr.Write(data)
|
||||
sw.m.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (sw *syncWriter) String() string {
|
||||
sw.m.Lock()
|
||||
defer sw.m.Unlock()
|
||||
return sw.wr.String()
|
||||
}
|
||||
|
||||
// newBufLogger returns an errors-only Logger that writes into sw, so tests
// can assert on what was logged.
func newBufLogger(sw *syncWriter) Logger {
	return PrintfLogger(log.New(sw, "", log.LstdFlags))
}
|
||||
|
||||
func TestFuncPanicRecovery(t *testing.T) {
|
||||
var buf syncWriter
|
||||
cron := New(WithParser(secondParser),
|
||||
WithChain(Recover(newBufLogger(&buf))))
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
cron.AddFunc("* * * * * ?", func() {
|
||||
panic("YOLO")
|
||||
})
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
if !strings.Contains(buf.String(), "YOLO") {
|
||||
t.Error("expected a panic to be logged, got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// DummyJob is a Job whose Run always panics; tests use it to exercise the
// Recover chain wrapper.
type DummyJob struct{}

func (DummyJob) Run() {
	panic("YOLO")
}
|
||||
|
||||
func TestJobPanicRecovery(t *testing.T) {
|
||||
var job DummyJob
|
||||
|
||||
var buf syncWriter
|
||||
cron := New(WithParser(secondParser),
|
||||
WithChain(Recover(newBufLogger(&buf))))
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
cron.AddJob("* * * * * ?", job)
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
if !strings.Contains(buf.String(), "YOLO") {
|
||||
t.Error("expected a panic to be logged, got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Start and stop cron with no entries.
|
||||
func TestNoEntries(t *testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Fatal("expected cron will be stopped immediately")
|
||||
case <-stop(cron):
|
||||
}
|
||||
}
|
||||
|
||||
// Start, stop, then add an entry. Verify entry doesn't run.
|
||||
func TestStopCausesJobsToNotRun(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
cron.Stop()
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
// No job ran!
|
||||
case <-wait(wg):
|
||||
t.Fatal("expected stopped cron does not run any job")
|
||||
}
|
||||
}
|
||||
|
||||
// Add a job, start cron, expect it runs.
|
||||
func TestAddBeforeRunning(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
// Give cron 2 seconds to run our job (which is always activated).
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Fatal("expected job runs")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Start cron, add a job, expect it runs.
|
||||
func TestAddWhileRunning(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Fatal("expected job runs")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test for #34. Adding a job after calling start results in multiple job invocations
|
||||
func TestAddWhileRunningWithDelay(t *testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
time.Sleep(5 * time.Second)
|
||||
var calls int64
|
||||
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
|
||||
|
||||
<-time.After(OneSecond)
|
||||
if atomic.LoadInt64(&calls) != 1 {
|
||||
t.Errorf("called %d times, expected 1\n", calls)
|
||||
}
|
||||
}
|
||||
|
||||
// Add a job, remove a job, start cron, expect nothing runs.
|
||||
func TestRemoveBeforeRunning(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
cron.Remove(id)
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
// Success, shouldn't run
|
||||
case <-wait(wg):
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Start cron, add a job, remove it, expect it doesn't run.
|
||||
func TestRemoveWhileRunning(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
cron.Remove(id)
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
case <-wait(wg):
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Test timing with Entries.
|
||||
func TestSnapshotEntries(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := New()
|
||||
cron.AddFunc("@every 2s", func() { wg.Done() })
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
// Cron should fire in 2 seconds. After 1 second, call Entries.
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
cron.Entries()
|
||||
}
|
||||
|
||||
// Even though Entries was called, the cron should fire at the 2 second mark.
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Error("expected job runs at 2 second mark")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the entries are correctly sorted.
|
||||
// Add a bunch of long-in-the-future entries, and an immediate entry, and ensure
|
||||
// that the immediate entry runs immediately.
|
||||
// Also: Test that multiple jobs run in the same instant.
|
||||
func TestMultipleEntries(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("0 0 0 1 1 ?", func() {})
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
id1, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
|
||||
id2, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
|
||||
cron.AddFunc("0 0 0 31 12 ?", func() {})
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
cron.Remove(id1)
|
||||
cron.Start()
|
||||
cron.Remove(id2)
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Error("expected job run in proper order")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test running the same job twice.
|
||||
func TestRunningJobTwice(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("0 0 0 1 1 ?", func() {})
|
||||
cron.AddFunc("0 0 0 31 12 ?", func() {})
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(2 * OneSecond):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunningMultipleSchedules(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("0 0 0 1 1 ?", func() {})
|
||||
cron.AddFunc("0 0 0 31 12 ?", func() {})
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
cron.Schedule(Every(time.Minute), FuncJob(func() {}))
|
||||
cron.Schedule(Every(time.Second), FuncJob(func() { wg.Done() }))
|
||||
cron.Schedule(Every(time.Hour), FuncJob(func() {}))
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(2 * OneSecond):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the cron is run in the local time zone (as opposed to UTC).
|
||||
func TestLocalTimezone(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
now := time.Now()
|
||||
// FIX: Issue #205
|
||||
// This calculation doesn't work in seconds 58 or 59.
|
||||
// Take the easy way out and sleep.
|
||||
if now.Second() >= 58 {
|
||||
time.Sleep(2 * time.Second)
|
||||
now = time.Now()
|
||||
}
|
||||
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
|
||||
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc(spec, func() { wg.Done() })
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond * 2):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the cron is run in the given time zone (as opposed to local).
|
||||
func TestNonLocalTimezone(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
loc, err := time.LoadLocation("Atlantic/Cape_Verde")
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load time zone Atlantic/Cape_Verde: %+v", err)
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
now := time.Now().In(loc)
|
||||
// FIX: Issue #205
|
||||
// This calculation doesn't work in seconds 58 or 59.
|
||||
// Take the easy way out and sleep.
|
||||
if now.Second() >= 58 {
|
||||
time.Sleep(2 * time.Second)
|
||||
now = time.Now().In(loc)
|
||||
}
|
||||
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
|
||||
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
|
||||
|
||||
cron := New(WithLocation(loc), WithParser(secondParser))
|
||||
cron.AddFunc(spec, func() { wg.Done() })
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond * 2):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that calling stop before start silently returns without
// blocking the stop channel.
func TestStopWithoutStart(t *testing.T) {
	cron := New()
	// Must not deadlock: Stop on a never-started cron is a no-op.
	cron.Stop()
}
|
||||
|
||||
// testJob is a named Job that signals its WaitGroup each time it runs,
// letting tests both wait for execution and identify entries by name.
type testJob struct {
	wg   *sync.WaitGroup
	name string
}

func (t testJob) Run() {
	t.wg.Done()
}
|
||||
|
||||
// Test that adding an invalid job spec returns an error
// (the spec parser rejects it before any entry is created).
func TestInvalidJobSpec(t *testing.T) {
	cron := New()
	_, err := cron.AddJob("this will not parse", nil)
	if err == nil {
		t.Errorf("expected an error with invalid spec, got nil")
	}
}
|
||||
|
||||
// Test blocking run method behaves as Start()
|
||||
func TestBlockingRun(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
var unblockChan = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
cron.Run()
|
||||
close(unblockChan)
|
||||
}()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Error("expected job fires")
|
||||
case <-unblockChan:
|
||||
t.Error("expected that Run() blocks")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that double-running is a no-op
|
||||
func TestStartNoop(t *testing.T) {
|
||||
var tickChan = make(chan struct{}, 2)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * ?", func() {
|
||||
tickChan <- struct{}{}
|
||||
})
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
// Wait for the first firing to ensure the runner is going
|
||||
<-tickChan
|
||||
|
||||
cron.Start()
|
||||
|
||||
<-tickChan
|
||||
|
||||
// Fail if this job fires again in a short period, indicating a double-run
|
||||
select {
|
||||
case <-time.After(time.Millisecond):
|
||||
case <-tickChan:
|
||||
t.Error("expected job fires exactly twice")
|
||||
}
|
||||
}
|
||||
|
||||
// Simple test using Runnables.
|
||||
func TestJob(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddJob("0 0 0 30 Feb ?", testJob{wg, "job0"})
|
||||
cron.AddJob("0 0 0 1 1 ?", testJob{wg, "job1"})
|
||||
job2, _ := cron.AddJob("* * * * * ?", testJob{wg, "job2"})
|
||||
cron.AddJob("1 0 0 1 1 ?", testJob{wg, "job3"})
|
||||
cron.Schedule(Every(5*time.Second+5*time.Nanosecond), testJob{wg, "job4"})
|
||||
job5 := cron.Schedule(Every(5*time.Minute), testJob{wg, "job5"})
|
||||
|
||||
// Test getting an Entry pre-Start.
|
||||
if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" {
|
||||
t.Error("wrong job retrieved:", actualName)
|
||||
}
|
||||
if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" {
|
||||
t.Error("wrong job retrieved:", actualName)
|
||||
}
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.FailNow()
|
||||
case <-wait(wg):
|
||||
}
|
||||
|
||||
// Ensure the entries are in the right order.
|
||||
expecteds := []string{"job2", "job4", "job5", "job1", "job3", "job0"}
|
||||
|
||||
var actuals []string
|
||||
for _, entry := range cron.Entries() {
|
||||
actuals = append(actuals, entry.Job.(testJob).name)
|
||||
}
|
||||
|
||||
for i, expected := range expecteds {
|
||||
if actuals[i] != expected {
|
||||
t.Fatalf("Jobs not in the right order. (expected) %s != %s (actual)", expecteds, actuals)
|
||||
}
|
||||
}
|
||||
|
||||
// Test getting Entries.
|
||||
if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" {
|
||||
t.Error("wrong job retrieved:", actualName)
|
||||
}
|
||||
if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" {
|
||||
t.Error("wrong job retrieved:", actualName)
|
||||
}
|
||||
}
|
||||
|
||||
// Issue #206
|
||||
// Ensure that the next run of a job after removing an entry is accurate.
|
||||
func TestScheduleAfterRemoval(t *testing.T) {
|
||||
var wg1 sync.WaitGroup
|
||||
var wg2 sync.WaitGroup
|
||||
wg1.Add(1)
|
||||
wg2.Add(1)
|
||||
|
||||
// The first time this job is run, set a timer and remove the other job
|
||||
// 750ms later. Correct behavior would be to still run the job again in
|
||||
// 250ms, but the bug would cause it to run instead 1s later.
|
||||
|
||||
var calls int
|
||||
var mu sync.Mutex
|
||||
|
||||
cron := newWithSeconds()
|
||||
hourJob := cron.Schedule(Every(time.Hour), FuncJob(func() {}))
|
||||
cron.Schedule(Every(time.Second), FuncJob(func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
switch calls {
|
||||
case 0:
|
||||
wg1.Done()
|
||||
calls++
|
||||
case 1:
|
||||
time.Sleep(750 * time.Millisecond)
|
||||
cron.Remove(hourJob)
|
||||
calls++
|
||||
case 2:
|
||||
calls++
|
||||
wg2.Done()
|
||||
case 3:
|
||||
panic("unexpected 3rd call")
|
||||
}
|
||||
}))
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
// the first run might be any length of time 0 - 1s, since the schedule
|
||||
// rounds to the second. wait for the first run to true up.
|
||||
wg1.Wait()
|
||||
|
||||
select {
|
||||
case <-time.After(2 * OneSecond):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(&wg2):
|
||||
}
|
||||
}
|
||||
|
||||
// ZeroSchedule is a Schedule whose Next always returns the zero time;
// the scheduler treats a zero Next as "never run".
type ZeroSchedule struct{}

func (*ZeroSchedule) Next(time.Time) time.Time {
	return time.Time{}
}
|
||||
|
||||
// Tests that job without time does not run
|
||||
func TestJobWithZeroTimeDoesNotRun(t *testing.T) {
|
||||
cron := newWithSeconds()
|
||||
var calls int64
|
||||
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
|
||||
cron.Schedule(new(ZeroSchedule), FuncJob(func() { t.Error("expected zero task will not run") }))
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
<-time.After(OneSecond)
|
||||
if atomic.LoadInt64(&calls) != 1 {
|
||||
t.Errorf("called %d times, expected 1\n", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopAndWait(t *testing.T) {
|
||||
t.Run("nothing running, returns immediately", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
ctx := cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context was not done immediately")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repeated calls to Stop", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
_ = cron.Stop()
|
||||
time.Sleep(time.Millisecond)
|
||||
ctx := cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context was not done immediately")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a couple fast jobs added, still returns immediately", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.Start()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
time.Sleep(time.Second)
|
||||
ctx := cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context was not done immediately")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a couple fast jobs and a slow job added, waits for slow job", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.Start()
|
||||
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
time.Sleep(time.Second)
|
||||
|
||||
ctx := cron.Stop()
|
||||
|
||||
// Verify that it is not done for at least 750ms
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("context was done too quickly immediately")
|
||||
case <-time.After(750 * time.Millisecond):
|
||||
// expected, because the job sleeping for 1 second is still running
|
||||
}
|
||||
|
||||
// Verify that it IS done in the next 500ms (giving 250ms buffer)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// expected
|
||||
case <-time.After(1500 * time.Millisecond):
|
||||
t.Error("context not done after job should have completed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repeated calls to stop, waiting for completion and after", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
|
||||
cron.Start()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
time.Sleep(time.Second)
|
||||
ctx := cron.Stop()
|
||||
ctx2 := cron.Stop()
|
||||
|
||||
// Verify that it is not done for at least 1500ms
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("context was done too quickly immediately")
|
||||
case <-ctx2.Done():
|
||||
t.Error("context2 was done too quickly immediately")
|
||||
case <-time.After(1500 * time.Millisecond):
|
||||
// expected, because the job sleeping for 2 seconds is still running
|
||||
}
|
||||
|
||||
// Verify that it IS done in the next 1s (giving 500ms buffer)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// expected
|
||||
case <-time.After(time.Second):
|
||||
t.Error("context not done after job should have completed")
|
||||
}
|
||||
|
||||
// Verify that ctx2 is also done.
|
||||
select {
|
||||
case <-ctx2.Done():
|
||||
// expected
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context2 not done even though context1 is")
|
||||
}
|
||||
|
||||
// Verify that a new context retrieved from stop is immediately done.
|
||||
ctx3 := cron.Stop()
|
||||
select {
|
||||
case <-ctx3.Done():
|
||||
// expected
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context not done even when cron Stop is completed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiThreadedStartAndStop checks that Run and Stop can be invoked from
// different goroutines without deadlocking (run with -race to catch races).
func TestMultiThreadedStartAndStop(t *testing.T) {
	cron := New()
	go cron.Run()
	time.Sleep(2 * time.Millisecond)
	cron.Stop()
}
|
||||
|
||||
func wait(wg *sync.WaitGroup) chan bool {
|
||||
ch := make(chan bool)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
ch <- true
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func stop(cron *Cron) chan bool {
|
||||
ch := make(chan bool)
|
||||
go func() {
|
||||
cron.Stop()
|
||||
ch <- true
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// newWithSeconds returns a Cron with the seconds field enabled and an empty
// chain (no wrappers), so jobs run unwrapped. secondParser is defined
// elsewhere in this test package.
func newWithSeconds() *Cron {
	return New(WithParser(secondParser), WithChain())
}
|
||||
86
plugin/cron/logger.go
Normal file
86
plugin/cron/logger.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultLogger is used by Cron if none is specified. It logs errors only,
// prefixed with "cron: ", to stdout.
var DefaultLogger = PrintfLogger(log.New(os.Stdout, "cron: ", log.LstdFlags))

// DiscardLogger can be used by callers to discard all log messages.
var DiscardLogger = PrintfLogger(log.New(io.Discard, "", 0))
|
||||
|
||||
// Logger is the interface used in this package for logging, so that any backend
// can be plugged in. It is a subset of the github.com/go-logr/logr interface.
// keysAndValues are alternating key/value pairs, logfmt-style.
type Logger interface {
	// Info logs routine messages about cron's operation.
	Info(msg string, keysAndValues ...interface{})
	// Error logs an error condition.
	Error(err error, msg string, keysAndValues ...interface{})
}
|
||||
|
||||
// PrintfLogger wraps a Printf-based logger (such as the standard library "log")
|
||||
// into an implementation of the Logger interface which logs errors only.
|
||||
func PrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger {
|
||||
return printfLogger{l, false}
|
||||
}
|
||||
|
||||
// VerbosePrintfLogger wraps a Printf-based logger (such as the standard library
|
||||
// "log") into an implementation of the Logger interface which logs everything.
|
||||
func VerbosePrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger {
|
||||
return printfLogger{l, true}
|
||||
}
|
||||
|
||||
// printfLogger adapts a Printf-style logger to the Logger interface.
type printfLogger struct {
	logger  interface{ Printf(string, ...interface{}) }
	logInfo bool // when false, Info calls are dropped and only Error logs
}
|
||||
|
||||
func (pl printfLogger) Info(msg string, keysAndValues ...interface{}) {
|
||||
if pl.logInfo {
|
||||
keysAndValues = formatTimes(keysAndValues)
|
||||
pl.logger.Printf(
|
||||
formatString(len(keysAndValues)),
|
||||
append([]interface{}{msg}, keysAndValues...)...)
|
||||
}
|
||||
}
|
||||
|
||||
func (pl printfLogger) Error(err error, msg string, keysAndValues ...interface{}) {
|
||||
keysAndValues = formatTimes(keysAndValues)
|
||||
pl.logger.Printf(
|
||||
formatString(len(keysAndValues)+2),
|
||||
append([]interface{}{msg, "error", err}, keysAndValues...)...)
|
||||
}
|
||||
|
||||
// formatString returns a logfmt-like format string for the number of
// key/values: a "%s" for the message, then one "%v=%v" per pair, all
// comma-separated. An odd count yields a trailing ", " just as before.
func formatString(numKeysAndValues int) string {
	pairs := make([]string, numKeysAndValues/2)
	for i := range pairs {
		pairs[i] = "%v=%v"
	}
	if numKeysAndValues > 0 {
		return "%s, " + strings.Join(pairs, ", ")
	}
	return "%s"
}
|
||||
|
||||
// formatTimes formats any time.Time values as RFC3339.
|
||||
func formatTimes(keysAndValues []interface{}) []interface{} {
|
||||
var formattedArgs []interface{}
|
||||
for _, arg := range keysAndValues {
|
||||
if t, ok := arg.(time.Time); ok {
|
||||
arg = t.Format(time.RFC3339)
|
||||
}
|
||||
formattedArgs = append(formattedArgs, arg)
|
||||
}
|
||||
return formattedArgs
|
||||
}
|
||||
45
plugin/cron/option.go
Normal file
45
plugin/cron/option.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Option represents a modification to the default behavior of a Cron,
// typically passed to New at construction time.
type Option func(*Cron)
|
||||
|
||||
// WithLocation overrides the timezone of the cron instance.
|
||||
func WithLocation(loc *time.Location) Option {
|
||||
return func(c *Cron) {
|
||||
c.location = loc
|
||||
}
|
||||
}
|
||||
|
||||
// WithSeconds overrides the parser used for interpreting job schedules to
|
||||
// include a seconds field as the first one.
|
||||
func WithSeconds() Option {
|
||||
return WithParser(NewParser(
|
||||
Second | Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
))
|
||||
}
|
||||
|
||||
// WithParser overrides the parser used for interpreting job schedules.
|
||||
func WithParser(p ScheduleParser) Option {
|
||||
return func(c *Cron) {
|
||||
c.parser = p
|
||||
}
|
||||
}
|
||||
|
||||
// WithChain specifies Job wrappers to apply to all jobs added to this cron.
|
||||
// Refer to the Chain* functions in this package for provided wrappers.
|
||||
func WithChain(wrappers ...JobWrapper) Option {
|
||||
return func(c *Cron) {
|
||||
c.chain = NewChain(wrappers...)
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger uses the provided logger.
|
||||
func WithLogger(logger Logger) Option {
|
||||
return func(c *Cron) {
|
||||
c.logger = logger
|
||||
}
|
||||
}
|
||||
43
plugin/cron/option_test.go
Normal file
43
plugin/cron/option_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestWithLocation verifies that WithLocation sets the cron's location field.
func TestWithLocation(t *testing.T) {
	c := New(WithLocation(time.UTC))
	if c.location != time.UTC {
		t.Errorf("expected UTC, got %v", c.location)
	}
}
|
||||
|
||||
// TestWithParser verifies that WithParser installs the provided parser.
func TestWithParser(t *testing.T) {
	var parser = NewParser(Dow)
	c := New(WithParser(parser))
	if c.parser != parser {
		t.Error("expected provided parser")
	}
}
|
||||
|
||||
// TestWithVerboseLogger verifies that WithLogger installs the provided
// logger and that a verbose logger records scheduler activity.
func TestWithVerboseLogger(t *testing.T) {
	var buf syncWriter
	var logger = log.New(&buf, "", log.LstdFlags)
	c := New(WithLogger(VerbosePrintfLogger(logger)))
	if c.logger.(printfLogger).logger != logger {
		t.Error("expected provided logger")
	}

	// Run briefly so the verbose logger emits "schedule"/"run" lines.
	c.AddFunc("@every 1s", func() {})
	c.Start()
	time.Sleep(OneSecond)
	c.Stop()
	out := buf.String()
	if !strings.Contains(out, "schedule,") ||
		!strings.Contains(out, "run,") {
		t.Error("expected to see some actions, got:", out)
	}
}
|
||||
437
plugin/cron/parser.go
Normal file
437
plugin/cron/parser.go
Normal file
@@ -0,0 +1,437 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Configuration options for creating a parser. Most options specify which
// fields should be included, while others enable features. If a field is not
// included the parser will assume a default value. These options do not change
// the order fields are parse in.
//
// ParseOption is a bit set: combine values with bitwise OR.
type ParseOption int
|
||||
|
||||
// Each constant is a distinct bit (1 << iota) so options can be OR-ed together.
const (
	Second         ParseOption = 1 << iota // Seconds field, default 0
	SecondOptional                         // Optional seconds field, default 0
	Minute                                 // Minutes field, default 0
	Hour                                   // Hours field, default 0
	Dom                                    // Day of month field, default *
	Month                                  // Month field, default *
	Dow                                    // Day of week field, default *
	DowOptional                            // Optional day of week field, default *
	Descriptor                             // Allow descriptors such as @monthly, @weekly, etc.
)
|
||||
|
||||
// places lists the spec fields in the order they appear in a spec string.
var places = []ParseOption{
	Second,
	Minute,
	Hour,
	Dom,
	Month,
	Dow,
}

// defaults holds, index-aligned with places, the value assumed for each
// field when it is not included in the parser's options.
var defaults = []string{
	"0",
	"0",
	"0",
	"*",
	"*",
	"*",
}
|
||||
|
||||
// A custom Parser that can be configured.
type Parser struct {
	options ParseOption // bit set of fields/features enabled for this parser
}
|
||||
|
||||
// NewParser creates a Parser with custom options.
|
||||
//
|
||||
// It panics if more than one Optional is given, since it would be impossible to
|
||||
// correctly infer which optional is provided or missing in general.
|
||||
//
|
||||
// Examples
|
||||
//
|
||||
// // Standard parser without descriptors
|
||||
// specParser := NewParser(Minute | Hour | Dom | Month | Dow)
|
||||
// sched, err := specParser.Parse("0 0 15 */3 *")
|
||||
//
|
||||
// // Same as above, just excludes time fields
|
||||
// specParser := NewParser(Dom | Month | Dow)
|
||||
// sched, err := specParser.Parse("15 */3 *")
|
||||
//
|
||||
// // Same as above, just makes Dow optional
|
||||
// specParser := NewParser(Dom | Month | DowOptional)
|
||||
// sched, err := specParser.Parse("15 */3")
|
||||
func NewParser(options ParseOption) Parser {
|
||||
optionals := 0
|
||||
if options&DowOptional > 0 {
|
||||
optionals++
|
||||
}
|
||||
if options&SecondOptional > 0 {
|
||||
optionals++
|
||||
}
|
||||
if optionals > 1 {
|
||||
panic("multiple optionals may not be configured")
|
||||
}
|
||||
return Parser{options}
|
||||
}
|
||||
|
||||
// Parse returns a new crontab schedule representing the given spec.
|
||||
// It returns a descriptive error if the spec is not valid.
|
||||
// It accepts crontab specs and features configured by NewParser.
|
||||
func (p Parser) Parse(spec string) (Schedule, error) {
|
||||
if len(spec) == 0 {
|
||||
return nil, errors.New("empty spec string")
|
||||
}
|
||||
|
||||
// Extract timezone if present
|
||||
var loc = time.Local
|
||||
if strings.HasPrefix(spec, "TZ=") || strings.HasPrefix(spec, "CRON_TZ=") {
|
||||
var err error
|
||||
i := strings.Index(spec, " ")
|
||||
eq := strings.Index(spec, "=")
|
||||
if loc, err = time.LoadLocation(spec[eq+1 : i]); err != nil {
|
||||
return nil, errors.Wrap(err, "provided bad location")
|
||||
}
|
||||
spec = strings.TrimSpace(spec[i:])
|
||||
}
|
||||
|
||||
// Handle named schedules (descriptors), if configured
|
||||
if strings.HasPrefix(spec, "@") {
|
||||
if p.options&Descriptor == 0 {
|
||||
return nil, errors.New("descriptors not enabled")
|
||||
}
|
||||
return parseDescriptor(spec, loc)
|
||||
}
|
||||
|
||||
// Split on whitespace.
|
||||
fields := strings.Fields(spec)
|
||||
|
||||
// Validate & fill in any omitted or optional fields
|
||||
var err error
|
||||
fields, err = normalizeFields(fields, p.options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
field := func(field string, r bounds) uint64 {
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
var bits uint64
|
||||
bits, err = getField(field, r)
|
||||
return bits
|
||||
}
|
||||
|
||||
var (
|
||||
second = field(fields[0], seconds)
|
||||
minute = field(fields[1], minutes)
|
||||
hour = field(fields[2], hours)
|
||||
dayofmonth = field(fields[3], dom)
|
||||
month = field(fields[4], months)
|
||||
dayofweek = field(fields[5], dow)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SpecSchedule{
|
||||
Second: second,
|
||||
Minute: minute,
|
||||
Hour: hour,
|
||||
Dom: dayofmonth,
|
||||
Month: month,
|
||||
Dow: dayofweek,
|
||||
Location: loc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeFields takes a subset set of the time fields and returns the full set
|
||||
// with defaults (zeroes) populated for unset fields.
|
||||
//
|
||||
// As part of performing this function, it also validates that the provided
|
||||
// fields are compatible with the configured options.
|
||||
func normalizeFields(fields []string, options ParseOption) ([]string, error) {
|
||||
// Validate optionals & add their field to options
|
||||
optionals := 0
|
||||
if options&SecondOptional > 0 {
|
||||
options |= Second
|
||||
optionals++
|
||||
}
|
||||
if options&DowOptional > 0 {
|
||||
options |= Dow
|
||||
optionals++
|
||||
}
|
||||
if optionals > 1 {
|
||||
return nil, errors.New("multiple optionals may not be configured")
|
||||
}
|
||||
|
||||
// Figure out how many fields we need
|
||||
max := 0
|
||||
for _, place := range places {
|
||||
if options&place > 0 {
|
||||
max++
|
||||
}
|
||||
}
|
||||
min := max - optionals
|
||||
|
||||
// Validate number of fields
|
||||
if count := len(fields); count < min || count > max {
|
||||
if min == max {
|
||||
return nil, errors.New("incorrect number of fields")
|
||||
}
|
||||
return nil, errors.New("incorrect number of fields, expected " + strconv.Itoa(min) + "-" + strconv.Itoa(max))
|
||||
}
|
||||
|
||||
// Populate the optional field if not provided
|
||||
if min < max && len(fields) == min {
|
||||
switch {
|
||||
case options&DowOptional > 0:
|
||||
fields = append(fields, defaults[5]) // TODO: improve access to default
|
||||
case options&SecondOptional > 0:
|
||||
fields = append([]string{defaults[0]}, fields...)
|
||||
default:
|
||||
return nil, errors.New("unexpected optional field")
|
||||
}
|
||||
}
|
||||
|
||||
// Populate all fields not part of options with their defaults
|
||||
n := 0
|
||||
expandedFields := make([]string, len(places))
|
||||
copy(expandedFields, defaults)
|
||||
for i, place := range places {
|
||||
if options&place > 0 {
|
||||
expandedFields[i] = fields[n]
|
||||
n++
|
||||
}
|
||||
}
|
||||
return expandedFields, nil
|
||||
}
|
||||
|
||||
// standardParser parses the traditional five-field crontab format, plus
// descriptors; it backs ParseStandard.
var standardParser = NewParser(
	Minute | Hour | Dom | Month | Dow | Descriptor,
)
|
||||
|
||||
// ParseStandard returns a new crontab schedule representing the given
// standardSpec (https://en.wikipedia.org/wiki/Cron). It requires 5 entries
// representing: minute, hour, day of month, month and day of week, in that
// order. It returns a descriptive error if the spec is not valid.
//
// It accepts
//   - Standard crontab specs, e.g. "* * * * ?"
//   - Descriptors, e.g. "@midnight", "@every 1h30m"
func ParseStandard(standardSpec string) (Schedule, error) {
	return standardParser.Parse(standardSpec)
}
|
||||
|
||||
// getField returns an Int with the bits set representing all of the times that
|
||||
// the field represents or error parsing field value. A "field" is a comma-separated
|
||||
// list of "ranges".
|
||||
func getField(field string, r bounds) (uint64, error) {
|
||||
var bits uint64
|
||||
ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' })
|
||||
for _, expr := range ranges {
|
||||
bit, err := getRange(expr, r)
|
||||
if err != nil {
|
||||
return bits, err
|
||||
}
|
||||
bits |= bit
|
||||
}
|
||||
return bits, nil
|
||||
}
|
||||
|
||||
// getRange returns the bits indicated by the given expression:
//
//	number | number "-" number [ "/" number ]
//
// or error parsing range.
func getRange(expr string, r bounds) (uint64, error) {
	var (
		start, end, step uint
		rangeAndStep     = strings.Split(expr, "/")
		lowAndHigh       = strings.Split(rangeAndStep[0], "-")
		singleDigit      = len(lowAndHigh) == 1
		err              error
	)

	// extra carries the starBit marker when the range is "*" (or "?"); it is
	// cleared again below if a step > 1 restricts the field after all.
	var extra uint64
	if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" {
		start = r.min
		end = r.max
		extra = starBit
	} else {
		start, err = parseIntOrName(lowAndHigh[0], r.names)
		if err != nil {
			return 0, err
		}
		switch len(lowAndHigh) {
		case 1:
			// A single value: the range is just that value.
			end = start
		case 2:
			end, err = parseIntOrName(lowAndHigh[1], r.names)
			if err != nil {
				return 0, err
			}
		default:
			return 0, errors.New("too many hyphens: " + expr)
		}
	}

	switch len(rangeAndStep) {
	case 1:
		// No step given: every value in [start, end].
		step = 1
	case 2:
		step, err = mustParseInt(rangeAndStep[1])
		if err != nil {
			return 0, err
		}

		// Special handling: "N/step" means "N-max/step".
		if singleDigit {
			end = r.max
		}
		// A step > 1 means the field is genuinely restricted, so it must not
		// carry the "unrestricted" star marker.
		if step > 1 {
			extra = 0
		}
	default:
		return 0, errors.New("too many slashes: " + expr)
	}

	// Validate the resolved range; the checks run after parsing so the error
	// messages can reference the original expression.
	if start < r.min {
		return 0, errors.New("beginning of range below minimum: " + expr)
	}
	if end > r.max {
		return 0, errors.New("end of range above maximum: " + expr)
	}
	if start > end {
		return 0, errors.New("beginning of range after end: " + expr)
	}
	if step == 0 {
		return 0, errors.New("step cannot be zero: " + expr)
	}

	return getBits(start, end, step) | extra, nil
}
|
||||
|
||||
// parseIntOrName returns the (possibly-named) integer contained in expr.
|
||||
func parseIntOrName(expr string, names map[string]uint) (uint, error) {
|
||||
if names != nil {
|
||||
if namedInt, ok := names[strings.ToLower(expr)]; ok {
|
||||
return namedInt, nil
|
||||
}
|
||||
}
|
||||
return mustParseInt(expr)
|
||||
}
|
||||
|
||||
// mustParseInt parses the given expression as an int or returns an error.
|
||||
func mustParseInt(expr string) (uint, error) {
|
||||
num, err := strconv.Atoi(expr)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to parse number")
|
||||
}
|
||||
if num < 0 {
|
||||
return 0, errors.New("number must be positive")
|
||||
}
|
||||
|
||||
return uint(num), nil
|
||||
}
|
||||
|
||||
// getBits sets all bits in the range [min, max], modulo the given step size.
func getBits(min, max, step uint) uint64 {
	// Step of one: the result is a contiguous run of ones, built with shifts.
	if step == 1 {
		return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min)
	}

	// Otherwise set every step-th bit individually.
	var out uint64
	for bit := min; bit <= max; bit += step {
		out |= 1 << bit
	}
	return out
}
|
||||
|
||||
// all returns all bits within the given bounds, plus the starBit marker that
// records the field as unrestricted (see dayMatches).
func all(r bounds) uint64 {
	return getBits(r.min, r.max, 1) | starBit
}
|
||||
|
||||
// parseDescriptor returns a predefined schedule for the expression, or error if none matches.
|
||||
func parseDescriptor(descriptor string, loc *time.Location) (Schedule, error) {
|
||||
switch descriptor {
|
||||
case "@yearly", "@annually":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: 1 << dom.min,
|
||||
Month: 1 << months.min,
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}, nil
|
||||
|
||||
case "@monthly":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: 1 << dom.min,
|
||||
Month: all(months),
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}, nil
|
||||
|
||||
case "@weekly":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: all(dom),
|
||||
Month: all(months),
|
||||
Dow: 1 << dow.min,
|
||||
Location: loc,
|
||||
}, nil
|
||||
|
||||
case "@daily", "@midnight":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: all(dom),
|
||||
Month: all(months),
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}, nil
|
||||
|
||||
case "@hourly":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: all(hours),
|
||||
Dom: all(dom),
|
||||
Month: all(months),
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}, nil
|
||||
default:
|
||||
// Continue to check @every prefix below
|
||||
}
|
||||
|
||||
const every = "@every "
|
||||
if strings.HasPrefix(descriptor, every) {
|
||||
duration, err := time.ParseDuration(descriptor[len(every):])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse duration")
|
||||
}
|
||||
return Every(duration), nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unrecognized descriptor: " + descriptor)
|
||||
}
|
||||
384
plugin/cron/parser_test.go
Normal file
384
plugin/cron/parser_test.go
Normal file
@@ -0,0 +1,384 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// secondParser parses six-field specs (with a seconds field) where the
// day-of-week field is optional.
var secondParser = NewParser(Second | Minute | Hour | Dom | Month | DowOptional | Descriptor)
|
||||
|
||||
// TestRange exercises getRange over single values, ranges, steps, stars, and
// every malformed-expression error path.
func TestRange(t *testing.T) {
	zero := uint64(0)
	ranges := []struct {
		expr     string
		min, max uint
		expected uint64
		err      string
	}{
		{"5", 0, 7, 1 << 5, ""},
		{"0", 0, 7, 1 << 0, ""},
		{"7", 0, 7, 1 << 7, ""},

		{"5-5", 0, 7, 1 << 5, ""},
		{"5-6", 0, 7, 1<<5 | 1<<6, ""},
		{"5-7", 0, 7, 1<<5 | 1<<6 | 1<<7, ""},

		{"5-6/2", 0, 7, 1 << 5, ""},
		{"5-7/2", 0, 7, 1<<5 | 1<<7, ""},
		{"5-7/1", 0, 7, 1<<5 | 1<<6 | 1<<7, ""},

		{"*", 1, 3, 1<<1 | 1<<2 | 1<<3 | starBit, ""},
		{"*/2", 1, 3, 1<<1 | 1<<3, ""},

		{"5--5", 0, 0, zero, "too many hyphens"},
		{"jan-x", 0, 0, zero, `failed to parse number: strconv.Atoi: parsing "jan": invalid syntax`},
		{"2-x", 1, 5, zero, `failed to parse number: strconv.Atoi: parsing "x": invalid syntax`},
		{"*/-12", 0, 0, zero, "number must be positive"},
		{"*//2", 0, 0, zero, "too many slashes"},
		{"1", 3, 5, zero, "below minimum"},
		{"6", 3, 5, zero, "above maximum"},
		{"5-3", 3, 5, zero, "beginning of range after end: 5-3"},
		{"*/0", 0, 0, zero, "step cannot be zero: */0"},
	}

	for _, c := range ranges {
		actual, err := getRange(c.expr, bounds{c.min, c.max, nil})
		if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) {
			t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
		}
		if len(c.err) == 0 && err != nil {
			t.Errorf("%s => unexpected error %v", c.expr, err)
		}
		if actual != c.expected {
			t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual)
		}
	}
}
|
||||
|
||||
// TestField checks that getField ORs together comma-separated ranges.
func TestField(t *testing.T) {
	fields := []struct {
		expr     string
		min, max uint
		expected uint64
	}{
		{"5", 1, 7, 1 << 5},
		{"5,6", 1, 7, 1<<5 | 1<<6},
		{"5,6,7", 1, 7, 1<<5 | 1<<6 | 1<<7},
		{"1,5-7/2,3", 1, 7, 1<<1 | 1<<5 | 1<<7 | 1<<3},
	}

	for _, c := range fields {
		actual, _ := getField(c.expr, bounds{c.min, c.max, nil})
		if actual != c.expected {
			t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual)
		}
	}
}
|
||||
|
||||
// TestAll verifies all() produces a full run of ones for each field's bounds.
func TestAll(t *testing.T) {
	allBits := []struct {
		r        bounds
		expected uint64
	}{
		{minutes, 0xfffffffffffffff}, // 0-59: 60 ones
		{hours, 0xffffff},            // 0-23: 24 ones
		{dom, 0xfffffffe},            // 1-31: 31 ones, 1 zero
		{months, 0x1ffe},             // 1-12: 12 ones, 1 zero
		{dow, 0x7f},                  // 0-6: 7 ones
	}

	for _, c := range allBits {
		actual := all(c.r) // all() adds the starBit, so compensate for that..
		if c.expected|starBit != actual {
			t.Errorf("%d-%d/%d => expected %b, got %b",
				c.r.min, c.r.max, 1, c.expected|starBit, actual)
		}
	}
}
|
||||
|
||||
// TestBits checks getBits for single bits and stepped ranges.
func TestBits(t *testing.T) {
	bits := []struct {
		min, max, step uint
		expected       uint64
	}{
		{0, 0, 1, 0x1},
		{1, 1, 1, 0x2},
		{1, 5, 2, 0x2a}, // 101010
		{1, 4, 2, 0xa},  // 1010
	}

	for _, c := range bits {
		actual := getBits(c.min, c.max, c.step)
		if c.expected != actual {
			t.Errorf("%d-%d/%d => expected %b, got %b",
				c.min, c.max, c.step, c.expected, actual)
		}
	}
}
|
||||
|
||||
// TestParseScheduleErrors asserts Parse rejects bad specs with the expected
// error text and a nil schedule.
func TestParseScheduleErrors(t *testing.T) {
	var tests = []struct{ expr, err string }{
		{"* 5 j * * *", `failed to parse number: strconv.Atoi: parsing "j": invalid syntax`},
		{"@every Xm", "failed to parse duration"},
		{"@unrecognized", "unrecognized descriptor"},
		{"* * * *", "incorrect number of fields, expected 5-6"},
		{"", "empty spec string"},
	}
	for _, c := range tests {
		actual, err := secondParser.Parse(c.expr)
		if err == nil || !strings.Contains(err.Error(), c.err) {
			t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
		}
		if actual != nil {
			t.Errorf("expected nil schedule on error, got %v", actual)
		}
	}
}
|
||||
|
||||
// TestParseSchedule covers valid specs: field forms, TZ/CRON_TZ prefixes,
// descriptors, and @every intervals.
func TestParseSchedule(t *testing.T) {
	tokyo, _ := time.LoadLocation("Asia/Tokyo")
	entries := []struct {
		parser   Parser
		expr     string
		expected Schedule
	}{
		{secondParser, "0 5 * * * *", every5min(time.Local)},
		{standardParser, "5 * * * *", every5min(time.Local)},
		{secondParser, "CRON_TZ=UTC 0 5 * * * *", every5min(time.UTC)},
		{standardParser, "CRON_TZ=UTC 5 * * * *", every5min(time.UTC)},
		{secondParser, "CRON_TZ=Asia/Tokyo 0 5 * * * *", every5min(tokyo)},
		{secondParser, "@every 5m", ConstantDelaySchedule{5 * time.Minute}},
		{secondParser, "@midnight", midnight(time.Local)},
		{secondParser, "TZ=UTC @midnight", midnight(time.UTC)},
		{secondParser, "TZ=Asia/Tokyo @midnight", midnight(tokyo)},
		{secondParser, "@yearly", annual(time.Local)},
		{secondParser, "@annually", annual(time.Local)},
		{
			parser: secondParser,
			expr:   "* 5 * * * *",
			expected: &SpecSchedule{
				Second:   all(seconds),
				Minute:   1 << 5,
				Hour:     all(hours),
				Dom:      all(dom),
				Month:    all(months),
				Dow:      all(dow),
				Location: time.Local,
			},
		},
	}

	for _, c := range entries {
		actual, err := c.parser.Parse(c.expr)
		if err != nil {
			t.Errorf("%s => unexpected error %v", c.expr, err)
		}
		if !reflect.DeepEqual(actual, c.expected) {
			t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
		}
	}
}
|
||||
|
||||
// TestOptionalSecondSchedule verifies a SecondOptional parser accepts both
// five- and six-field specs.
func TestOptionalSecondSchedule(t *testing.T) {
	parser := NewParser(SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor)
	entries := []struct {
		expr     string
		expected Schedule
	}{
		{"0 5 * * * *", every5min(time.Local)},
		{"5 5 * * * *", every5min5s(time.Local)},
		{"5 * * * *", every5min(time.Local)},
	}

	for _, c := range entries {
		actual, err := parser.Parse(c.expr)
		if err != nil {
			t.Errorf("%s => unexpected error %v", c.expr, err)
		}
		if !reflect.DeepEqual(actual, c.expected) {
			t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
		}
	}
}
|
||||
|
||||
func TestNormalizeFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
options ParseOption
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
"AllFields_NoOptional",
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
Second | Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
},
|
||||
{
|
||||
"AllFields_SecondOptional_Provided",
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
},
|
||||
{
|
||||
"AllFields_SecondOptional_NotProvided",
|
||||
[]string{"5", "*", "*", "*", "*"},
|
||||
SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
},
|
||||
{
|
||||
"SubsetFields_NoOptional",
|
||||
[]string{"5", "15", "*"},
|
||||
Hour | Dom | Month,
|
||||
[]string{"0", "0", "5", "15", "*", "*"},
|
||||
},
|
||||
{
|
||||
"SubsetFields_DowOptional_Provided",
|
||||
[]string{"5", "15", "*", "4"},
|
||||
Hour | Dom | Month | DowOptional,
|
||||
[]string{"0", "0", "5", "15", "*", "4"},
|
||||
},
|
||||
{
|
||||
"SubsetFields_DowOptional_NotProvided",
|
||||
[]string{"5", "15", "*"},
|
||||
Hour | Dom | Month | DowOptional,
|
||||
[]string{"0", "0", "5", "15", "*", "*"},
|
||||
},
|
||||
{
|
||||
"SubsetFields_SecondOptional_NotProvided",
|
||||
[]string{"5", "15", "*"},
|
||||
SecondOptional | Hour | Dom | Month,
|
||||
[]string{"0", "0", "5", "15", "*", "*"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(*testing.T) {
|
||||
actual, err := normalizeFields(test.input, test.options)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, test.expected) {
|
||||
t.Errorf("expected %v, got %v", test.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeFields_Errors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
options ParseOption
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"TwoOptionals",
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
SecondOptional | Minute | Hour | Dom | Month | DowOptional,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"TooManyFields",
|
||||
[]string{"0", "5", "*", "*"},
|
||||
SecondOptional | Minute | Hour,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"NoFields",
|
||||
[]string{},
|
||||
SecondOptional | Minute | Hour,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"TooFewFields",
|
||||
[]string{"*"},
|
||||
SecondOptional | Minute | Hour,
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(*testing.T) {
|
||||
actual, err := normalizeFields(test.input, test.options)
|
||||
if err == nil {
|
||||
t.Errorf("expected an error, got none. results: %v", actual)
|
||||
}
|
||||
if !strings.Contains(err.Error(), test.err) {
|
||||
t.Errorf("expected error %q, got %q", test.err, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStandardSpecSchedule exercises ParseStandard for good and bad
// five-field specs and @every descriptors.
func TestStandardSpecSchedule(t *testing.T) {
	entries := []struct {
		expr     string
		expected Schedule
		err      string
	}{
		{
			expr:     "5 * * * *",
			expected: &SpecSchedule{1 << seconds.min, 1 << 5, all(hours), all(dom), all(months), all(dow), time.Local},
		},
		{
			expr:     "@every 5m",
			expected: ConstantDelaySchedule{time.Duration(5) * time.Minute},
		},
		{
			expr: "5 j * * *",
			err:  `failed to parse number: strconv.Atoi: parsing "j": invalid syntax`,
		},
		{
			expr: "* * * *",
			err:  "incorrect number of fields",
		},
	}

	for _, c := range entries {
		actual, err := ParseStandard(c.expr)
		if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) {
			t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
		}
		if len(c.err) == 0 && err != nil {
			t.Errorf("%s => unexpected error %v", c.expr, err)
		}
		if !reflect.DeepEqual(actual, c.expected) {
			t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
		}
	}
}
|
||||
|
||||
func TestNoDescriptorParser(t *testing.T) {
|
||||
parser := NewParser(Minute | Hour)
|
||||
_, err := parser.Parse("@every 1m")
|
||||
if err == nil {
|
||||
t.Error("expected an error, got none")
|
||||
}
|
||||
}
|
||||
|
||||
// every5min builds the schedule for second 0 of minute 5 of every hour.
func every5min(loc *time.Location) *SpecSchedule {
	return &SpecSchedule{1 << 0, 1 << 5, all(hours), all(dom), all(months), all(dow), loc}
}
|
||||
|
||||
// every5min5s builds the schedule for second 5 of minute 5 of every hour.
func every5min5s(loc *time.Location) *SpecSchedule {
	return &SpecSchedule{1 << 5, 1 << 5, all(hours), all(dom), all(months), all(dow), loc}
}
|
||||
|
||||
// midnight builds the schedule for 00:00:00 every day (second/minute/hour
// fields are bit 0, i.e. value 1<<0).
func midnight(loc *time.Location) *SpecSchedule {
	return &SpecSchedule{1, 1, 1, all(dom), all(months), all(dow), loc}
}
|
||||
|
||||
// annual builds the schedule for midnight of January 1st.
func annual(loc *time.Location) *SpecSchedule {
	return &SpecSchedule{
		Second:   1 << seconds.min,
		Minute:   1 << minutes.min,
		Hour:     1 << hours.min,
		Dom:      1 << dom.min,
		Month:    1 << months.min,
		Dow:      all(dow),
		Location: loc,
	}
}
|
||||
188
plugin/cron/spec.go
Normal file
188
plugin/cron/spec.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package cron
|
||||
|
||||
import "time"
|
||||
|
||||
// SpecSchedule specifies a duty cycle (to the second granularity), based on a
// traditional crontab specification. It is computed initially and stored as bit sets.
type SpecSchedule struct {
	// Each field is a bit set of permitted values (bit N set means value N
	// matches); the top bit (starBit) marks a field given as "*" or "?".
	Second, Minute, Hour, Dom, Month, Dow uint64

	// Override location for this schedule.
	Location *time.Location
}
|
||||
|
||||
// bounds provides a range of acceptable values (plus a map of name to value).
type bounds struct {
	min, max uint
	// names maps lowercase symbolic values (e.g. "jan", "sun") to numbers;
	// nil for fields without named values.
	names map[string]uint
}
|
||||
|
||||
// The bounds for each field.
var (
	seconds = bounds{0, 59, nil}
	minutes = bounds{0, 59, nil}
	hours   = bounds{0, 23, nil}
	dom     = bounds{1, 31, nil}
	months  = bounds{1, 12, map[string]uint{
		"jan": 1,
		"feb": 2,
		"mar": 3,
		"apr": 4,
		"may": 5,
		"jun": 6,
		"jul": 7,
		"aug": 8,
		"sep": 9,
		"oct": 10,
		"nov": 11,
		"dec": 12,
	}}
	dow = bounds{0, 6, map[string]uint{
		"sun": 0,
		"mon": 1,
		"tue": 2,
		"wed": 3,
		"thu": 4,
		"fri": 5,
		"sat": 6,
	}}
)
|
||||
|
||||
const (
	// Set the top bit if a star was included in the expression.
	// dayMatches uses it to tell restricted from unrestricted Dom/Dow fields.
	starBit = 1 << 63
)
|
||||
|
||||
// Next returns the next time this schedule is activated, greater than the given
// time. If no time can be found to satisfy the schedule, return the zero time.
func (s *SpecSchedule) Next(t time.Time) time.Time {
	// General approach
	//
	// For Month, Day, Hour, Minute, Second:
	// Check if the time value matches. If yes, continue to the next field.
	// If the field doesn't match the schedule, then increment the field until it matches.
	// While incrementing the field, a wrap-around brings it back to the beginning
	// of the field list (since it is necessary to re-verify previous field
	// values)

	// Convert the given time into the schedule's timezone, if one is specified.
	// Save the original timezone so we can convert back after we find a time.
	// Note that schedules without a time zone specified (time.Local) are treated
	// as local to the time provided.
	origLocation := t.Location()
	loc := s.Location
	if loc == time.Local {
		loc = t.Location()
	}
	if s.Location != time.Local {
		t = t.In(s.Location)
	}

	// Start at the earliest possible time (the upcoming second).
	t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond)

	// This flag indicates whether a field has been incremented. Once a field
	// is bumped, all smaller fields are reset to their minimum (the current
	// time's smaller components are no longer relevant).
	added := false

	// If no time is found within five years, return zero.
	yearLimit := t.Year() + 5

WRAP:
	if t.Year() > yearLimit {
		return time.Time{}
	}

	// Find the first applicable month.
	// If it's this month, then do nothing.
	for 1<<uint(t.Month())&s.Month == 0 {
		// If we have to add a month, reset the other parts to 0.
		if !added {
			added = true
			// Otherwise, set the date at the beginning (since the current time is irrelevant).
			t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
		}
		t = t.AddDate(0, 1, 0)

		// Wrapped around.
		if t.Month() == time.January {
			goto WRAP
		}
	}

	// Now get a day in that month.
	//
	// NOTE: This causes issues for daylight savings regimes where midnight does
	// not exist. For example: Sao Paulo has DST that transforms midnight on
	// 11/3 into 1am. Handle that by noticing when the Hour ends up != 0.
	for !dayMatches(s, t) {
		if !added {
			added = true
			t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
		}
		t = t.AddDate(0, 0, 1)
		// Notice if the hour is no longer midnight due to DST.
		// Add an hour if it's 23, subtract an hour if it's 1.
		if t.Hour() != 0 {
			if t.Hour() > 12 {
				t = t.Add(time.Duration(24-t.Hour()) * time.Hour)
			} else {
				t = t.Add(time.Duration(-t.Hour()) * time.Hour)
			}
		}

		if t.Day() == 1 {
			goto WRAP
		}
	}

	// Find a matching hour within the day.
	for 1<<uint(t.Hour())&s.Hour == 0 {
		if !added {
			added = true
			t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 0, loc)
		}
		t = t.Add(1 * time.Hour)

		if t.Hour() == 0 {
			goto WRAP
		}
	}

	// Find a matching minute within the hour.
	for 1<<uint(t.Minute())&s.Minute == 0 {
		if !added {
			added = true
			t = t.Truncate(time.Minute)
		}
		t = t.Add(1 * time.Minute)

		if t.Minute() == 0 {
			goto WRAP
		}
	}

	// Find a matching second within the minute.
	for 1<<uint(t.Second())&s.Second == 0 {
		if !added {
			added = true
			t = t.Truncate(time.Second)
		}
		t = t.Add(1 * time.Second)

		if t.Second() == 0 {
			goto WRAP
		}
	}

	return t.In(origLocation)
}
|
||||
|
||||
// dayMatches returns true if the schedule's day-of-week and day-of-month
|
||||
// restrictions are satisfied by the given time.
|
||||
func dayMatches(s *SpecSchedule, t time.Time) bool {
|
||||
var (
|
||||
domMatch = 1<<uint(t.Day())&s.Dom > 0
|
||||
dowMatch = 1<<uint(t.Weekday())&s.Dow > 0
|
||||
)
|
||||
if s.Dom&starBit > 0 || s.Dow&starBit > 0 {
|
||||
return domMatch && dowMatch
|
||||
}
|
||||
return domMatch || dowMatch
|
||||
}
|
||||
301
plugin/cron/spec_test.go
Normal file
301
plugin/cron/spec_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestActivation checks, for each (time, spec) pair, whether the schedule
// fires at exactly that time (by asking Next for the second just before it).
func TestActivation(t *testing.T) {
	tests := []struct {
		time, spec string
		expected   bool
	}{
		// Every fifteen minutes.
		{"Mon Jul 9 15:00 2012", "0/15 * * * *", true},
		{"Mon Jul 9 15:45 2012", "0/15 * * * *", true},
		{"Mon Jul 9 15:40 2012", "0/15 * * * *", false},

		// Every fifteen minutes, starting at 5 minutes.
		{"Mon Jul 9 15:05 2012", "5/15 * * * *", true},
		{"Mon Jul 9 15:20 2012", "5/15 * * * *", true},
		{"Mon Jul 9 15:50 2012", "5/15 * * * *", true},

		// Named months
		{"Sun Jul 15 15:00 2012", "0/15 * * Jul *", true},
		{"Sun Jul 15 15:00 2012", "0/15 * * Jun *", false},

		// Everything set.
		{"Sun Jul 15 08:30 2012", "30 08 ? Jul Sun", true},
		{"Sun Jul 15 08:30 2012", "30 08 15 Jul ?", true},
		{"Mon Jul 16 08:30 2012", "30 08 ? Jul Sun", false},
		{"Mon Jul 16 08:30 2012", "30 08 15 Jul ?", false},

		// Predefined schedules
		{"Mon Jul 9 15:00 2012", "@hourly", true},
		{"Mon Jul 9 15:04 2012", "@hourly", false},
		{"Mon Jul 9 15:00 2012", "@daily", false},
		{"Mon Jul 9 00:00 2012", "@daily", true},
		{"Mon Jul 9 00:00 2012", "@weekly", false},
		{"Sun Jul 8 00:00 2012", "@weekly", true},
		{"Sun Jul 8 01:00 2012", "@weekly", false},
		{"Sun Jul 8 00:00 2012", "@monthly", false},
		{"Sun Jul 1 00:00 2012", "@monthly", true},

		// Test interaction of DOW and DOM.
		// If both are restricted, then only one needs to match.
		{"Sun Jul 15 00:00 2012", "* * 1,15 * Sun", true},
		{"Fri Jun 15 00:00 2012", "* * 1,15 * Sun", true},
		{"Wed Aug 1 00:00 2012", "* * 1,15 * Sun", true},
		{"Sun Jul 15 00:00 2012", "* * */10 * Sun", true}, // verifies #70

		// However, if one has a star, then both need to match.
		{"Sun Jul 15 00:00 2012", "* * * * Mon", false},
		{"Mon Jul 9 00:00 2012", "* * 1,15 * *", false},
		{"Sun Jul 15 00:00 2012", "* * 1,15 * *", true},
		{"Sun Jul 15 00:00 2012", "* * */2 * Sun", true},
	}

	for _, test := range tests {
		sched, err := ParseStandard(test.spec)
		if err != nil {
			t.Error(err)
			continue
		}
		// Ask for the next activation starting one second before the probe
		// time; the schedule "activates" at the probe time iff Next lands
		// exactly on it.
		actual := sched.Next(getTime(test.time).Add(-1 * time.Second))
		expected := getTime(test.time)
		if test.expected && expected != actual || !test.expected && expected == actual {
			t.Errorf("Fail evaluating %s on %s: (expected) %s != %s (actual)",
				test.spec, test.time, expected, actual)
		}
	}
}
|
||||
|
||||
// TestNext feeds 6-field (seconds-resolution) specs through secondParser and
// verifies the next activation after a given instant, covering field wrap
// (minute -> hour -> day -> month -> year), leap years, DST transitions in
// both directions (via TZ=/CRON_TZ= prefixes), and unsatisfiable specs
// (which must yield the zero time).
func TestNext(t *testing.T) {
	runs := []struct {
		time, spec string
		expected   string
	}{
		// Simple cases
		{"Mon Jul 9 14:45 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
		{"Mon Jul 9 14:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
		{"Mon Jul 9 14:59:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},

		// Wrap around hours
		{"Mon Jul 9 15:45 2012", "0 20-35/15 * * * *", "Mon Jul 9 16:20 2012"},

		// Wrap around days
		{"Mon Jul 9 23:46 2012", "0 */15 * * * *", "Tue Jul 10 00:00 2012"},
		{"Mon Jul 9 23:45 2012", "0 20-35/15 * * * *", "Tue Jul 10 00:20 2012"},
		{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * * * *", "Tue Jul 10 00:20:15 2012"},
		{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 * * *", "Tue Jul 10 01:20:15 2012"},
		{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 10-12 * * *", "Tue Jul 10 10:20:15 2012"},

		{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 */2 * *", "Thu Jul 11 01:20:15 2012"},
		{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 * *", "Wed Jul 10 00:20:15 2012"},
		{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 Jul *", "Wed Jul 10 00:20:15 2012"},

		// Wrap around months
		{"Mon Jul 9 23:35 2012", "0 0 0 9 Apr-Oct ?", "Thu Aug 9 00:00 2012"},
		{"Mon Jul 9 23:35 2012", "0 0 0 */5 Apr,Aug,Oct Mon", "Tue Aug 1 00:00 2012"},
		{"Mon Jul 9 23:35 2012", "0 0 0 */5 Oct Mon", "Mon Oct 1 00:00 2012"},

		// Wrap around years
		{"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon", "Mon Feb 4 00:00 2013"},
		{"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon/2", "Fri Feb 1 00:00 2013"},

		// Wrap around minute, hour, day, month, and year
		{"Mon Dec 31 23:59:45 2012", "0 * * * * *", "Tue Jan 1 00:00:00 2013"},

		// Leap year
		{"Mon Jul 9 23:35 2012", "0 0 0 29 Feb ?", "Mon Feb 29 00:00 2016"},

		// Daylight savings time 2am EST (-5) -> 3am EDT (-4)
		{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 30 2 11 Mar ?", "2013-03-11T02:30:00-0400"},

		// hourly job
		{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"},
		{"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"},
		{"2012-03-11T03:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"},
		{"2012-03-11T04:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"},

		// hourly job using CRON_TZ
		{"2012-03-11T00:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"},
		{"2012-03-11T01:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"},
		{"2012-03-11T03:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"},
		{"2012-03-11T04:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"},

		// 1am nightly job
		{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-11T01:00:00-0500"},
		{"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-12T01:00:00-0400"},

		// 2am nightly job (skipped)
		{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-03-12T02:00:00-0400"},

		// Daylight savings time 2am EDT (-4) => 1am EST (-5)
		{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 30 2 04 Nov ?", "2012-11-04T02:30:00-0500"},
		{"2012-11-04T01:45:00-0400", "TZ=America/New_York 0 30 1 04 Nov ?", "2012-11-04T01:30:00-0500"},

		// hourly job
		{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0400"},
		{"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0500"},
		{"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T02:00:00-0500"},

		// 1am nightly job (runs twice)
		{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0400"},
		{"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0500"},
		{"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-11-05T01:00:00-0500"},

		// 2am nightly job
		{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 2 * * ?", "2012-11-04T02:00:00-0500"},
		{"2012-11-04T02:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-11-05T02:00:00-0500"},

		// 3am nightly job
		{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 3 * * ?", "2012-11-04T03:00:00-0500"},
		{"2012-11-04T03:00:00-0500", "TZ=America/New_York 0 0 3 * * ?", "2012-11-05T03:00:00-0500"},

		// hourly job (TZ prefix on the reference time, not the spec)
		{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0400"},
		{"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0500"},
		{"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 * * * ?", "2012-11-04T02:00:00-0500"},

		// 1am nightly job (runs twice)
		{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0400"},
		{"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0500"},
		{"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 1 * * ?", "2012-11-05T01:00:00-0500"},

		// 2am nightly job
		{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 2 * * ?", "2012-11-04T02:00:00-0500"},
		{"TZ=America/New_York 2012-11-04T02:00:00-0500", "0 0 2 * * ?", "2012-11-05T02:00:00-0500"},

		// 3am nightly job
		{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 * * ?", "2012-11-04T03:00:00-0500"},
		{"TZ=America/New_York 2012-11-04T03:00:00-0500", "0 0 3 * * ?", "2012-11-05T03:00:00-0500"},

		// Unsatisfiable (expected == zero time from getTime(""))
		{"Mon Jul 9 23:35 2012", "0 0 0 30 Feb ?", ""},
		{"Mon Jul 9 23:35 2012", "0 0 0 31 Apr ?", ""},

		// Monthly job
		{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 3 * ?", "2012-12-03T03:00:00-0500"},

		// Test the scenario of DST resulting in midnight not being a valid time.
		// https://github.com/robfig/cron/issues/157
		{"2018-10-17T05:00:00-0400", "TZ=America/Sao_Paulo 0 0 9 10 * ?", "2018-11-10T06:00:00-0500"},
		{"2018-02-14T05:00:00-0500", "TZ=America/Sao_Paulo 0 0 9 22 * ?", "2018-02-22T07:00:00-0500"},
	}

	for _, c := range runs {
		sched, err := secondParser.Parse(c.spec)
		if err != nil {
			t.Error(err)
			continue
		}
		actual := sched.Next(getTime(c.time))
		expected := getTime(c.expected)
		// Equal (not ==) so times in different locations compare by instant.
		if !actual.Equal(expected) {
			t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual)
		}
	}
}
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
invalidSpecs := []string{
|
||||
"xyz",
|
||||
"60 0 * * *",
|
||||
"0 60 * * *",
|
||||
"0 0 * * XYZ",
|
||||
}
|
||||
for _, spec := range invalidSpecs {
|
||||
_, err := ParseStandard(spec)
|
||||
if err == nil {
|
||||
t.Error("expected an error parsing: ", spec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTime(value string) time.Time {
|
||||
if value == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
var location = time.Local
|
||||
if strings.HasPrefix(value, "TZ=") {
|
||||
parts := strings.Fields(value)
|
||||
loc, err := time.LoadLocation(parts[0][len("TZ="):])
|
||||
if err != nil {
|
||||
panic("could not parse location:" + err.Error())
|
||||
}
|
||||
location = loc
|
||||
value = parts[1]
|
||||
}
|
||||
|
||||
var layouts = []string{
|
||||
"Mon Jan 2 15:04 2006",
|
||||
"Mon Jan 2 15:04:05 2006",
|
||||
}
|
||||
for _, layout := range layouts {
|
||||
if t, err := time.ParseInLocation(layout, value, location); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
if t, err := time.ParseInLocation("2006-01-02T15:04:05-0700", value, location); err == nil {
|
||||
return t
|
||||
}
|
||||
panic("could not parse time value " + value)
|
||||
}
|
||||
|
||||
func TestNextWithTz(t *testing.T) {
|
||||
runs := []struct {
|
||||
time, spec string
|
||||
expected string
|
||||
}{
|
||||
// Failing tests
|
||||
{"2016-01-03T13:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"},
|
||||
{"2016-01-03T04:09:03+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"},
|
||||
|
||||
// Passing tests
|
||||
{"2016-01-03T14:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"},
|
||||
{"2016-01-03T14:00:00+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"},
|
||||
}
|
||||
for _, c := range runs {
|
||||
sched, err := ParseStandard(c.spec)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
actual := sched.Next(getTimeTZ(c.time))
|
||||
expected := getTimeTZ(c.expected)
|
||||
if !actual.Equal(expected) {
|
||||
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTimeTZ(value string) time.Time {
|
||||
if value == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
t, err := time.Parse("Mon Jan 2 15:04 2006", value)
|
||||
if err != nil {
|
||||
t, err = time.Parse("Mon Jan 2 15:04:05 2006", value)
|
||||
if err != nil {
|
||||
t, err = time.Parse("2006-01-02T15:04:05-0700", value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// https://github.com/robfig/cron/issues/144
|
||||
func TestSlash0NoHang(t *testing.T) {
|
||||
schedule := "TZ=America/New_York 15/0 * * * *"
|
||||
_, err := ParseStandard(schedule)
|
||||
if err == nil {
|
||||
t.Error("expected an error on 0 increment")
|
||||
}
|
||||
}
|
||||
507
plugin/email/README.md
Normal file
507
plugin/email/README.md
Normal file
@@ -0,0 +1,507 @@
|
||||
# Email Plugin
|
||||
|
||||
SMTP email sending functionality for self-hosted Memos instances.
|
||||
|
||||
## Overview
|
||||
|
||||
This plugin provides a simple, reliable email sending interface following industry-standard SMTP protocols. It's designed for self-hosted environments where instance administrators configure their own email service, similar to platforms like GitHub, GitLab, and Discourse.
|
||||
|
||||
## Features
|
||||
|
||||
- Standard SMTP protocol support
|
||||
- TLS/STARTTLS and SSL/TLS encryption
|
||||
- HTML and plain text emails
|
||||
- Multiple recipients (To, Cc, Bcc)
|
||||
- Synchronous and asynchronous sending
|
||||
- Detailed error reporting with context
|
||||
- Works with all major email providers
|
||||
- Reply-To header support
|
||||
- RFC 5322 compliant message formatting
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Configure SMTP Settings
|
||||
|
||||
```go
|
||||
import "github.com/usememos/memos/plugin/email"
|
||||
|
||||
config := &email.Config{
|
||||
SMTPHost: "smtp.gmail.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "your-email@gmail.com",
|
||||
SMTPPassword: "your-app-password",
|
||||
FromEmail: "noreply@yourdomain.com",
|
||||
FromName: "Memos",
|
||||
UseTLS: true,
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Create and Send Email
|
||||
|
||||
```go
|
||||
message := &email.Message{
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Welcome to Memos!",
|
||||
Body: "Thanks for signing up.",
|
||||
IsHTML: false,
|
||||
}
|
||||
|
||||
// Synchronous send (waits for result)
|
||||
err := email.Send(config, message)
|
||||
if err != nil {
|
||||
log.Printf("Failed to send email: %v", err)
|
||||
}
|
||||
|
||||
// Asynchronous send (returns immediately)
|
||||
email.SendAsync(config, message)
|
||||
```
|
||||
|
||||
## Provider Configuration
|
||||
|
||||
### Gmail
|
||||
|
||||
Requires an [App Password](https://support.google.com/accounts/answer/185833) (2FA must be enabled):
|
||||
|
||||
```go
|
||||
config := &email.Config{
|
||||
SMTPHost: "smtp.gmail.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "your-email@gmail.com",
|
||||
SMTPPassword: "your-16-char-app-password",
|
||||
FromEmail: "your-email@gmail.com",
|
||||
FromName: "Memos",
|
||||
UseTLS: true,
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative (SSL):**
|
||||
```go
|
||||
config := &email.Config{
|
||||
SMTPHost: "smtp.gmail.com",
|
||||
SMTPPort: 465,
|
||||
SMTPUsername: "your-email@gmail.com",
|
||||
SMTPPassword: "your-16-char-app-password",
|
||||
FromEmail: "your-email@gmail.com",
|
||||
FromName: "Memos",
|
||||
UseSSL: true,
|
||||
}
|
||||
```
|
||||
|
||||
### SendGrid
|
||||
|
||||
```go
|
||||
config := &email.Config{
|
||||
SMTPHost: "smtp.sendgrid.net",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "apikey",
|
||||
SMTPPassword: "your-sendgrid-api-key",
|
||||
FromEmail: "noreply@yourdomain.com",
|
||||
FromName: "Memos",
|
||||
UseTLS: true,
|
||||
}
|
||||
```
|
||||
|
||||
### AWS SES
|
||||
|
||||
```go
|
||||
config := &email.Config{
|
||||
SMTPHost: "email-smtp.us-east-1.amazonaws.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "your-smtp-username",
|
||||
SMTPPassword: "your-smtp-password",
|
||||
FromEmail: "verified@yourdomain.com",
|
||||
FromName: "Memos",
|
||||
UseTLS: true,
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** Replace `us-east-1` with your AWS region. The sender address (or its domain) must be verified in SES before you can send from it.
|
||||
|
||||
### Mailgun
|
||||
|
||||
```go
|
||||
config := &email.Config{
|
||||
SMTPHost: "smtp.mailgun.org",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "postmaster@yourdomain.com",
|
||||
SMTPPassword: "your-mailgun-smtp-password",
|
||||
FromEmail: "noreply@yourdomain.com",
|
||||
FromName: "Memos",
|
||||
UseTLS: true,
|
||||
}
|
||||
```
|
||||
|
||||
### Self-Hosted SMTP (Postfix, Exim, etc.)
|
||||
|
||||
```go
|
||||
config := &email.Config{
|
||||
SMTPHost: "mail.yourdomain.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "username",
|
||||
SMTPPassword: "password",
|
||||
FromEmail: "noreply@yourdomain.com",
|
||||
FromName: "Memos",
|
||||
UseTLS: true,
|
||||
}
|
||||
```
|
||||
|
||||
## HTML Emails
|
||||
|
||||
```go
|
||||
message := &email.Message{
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Welcome to Memos!",
|
||||
Body: `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
</head>
|
||||
<body style="font-family: Arial, sans-serif;">
|
||||
<h1 style="color: #333;">Welcome to Memos!</h1>
|
||||
<p>We're excited to have you on board.</p>
|
||||
<a href="https://yourdomain.com" style="background-color: #4CAF50; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">Get Started</a>
|
||||
</body>
|
||||
</html>
|
||||
`,
|
||||
IsHTML: true,
|
||||
}
|
||||
|
||||
email.Send(config, message)
|
||||
```
|
||||
|
||||
## Multiple Recipients
|
||||
|
||||
```go
|
||||
message := &email.Message{
|
||||
To: []string{"user1@example.com", "user2@example.com"},
|
||||
Cc: []string{"manager@example.com"},
|
||||
Bcc: []string{"admin@example.com"},
|
||||
Subject: "Team Update",
|
||||
Body: "Important team announcement...",
|
||||
ReplyTo: "support@yourdomain.com",
|
||||
}
|
||||
|
||||
email.Send(config, message)
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
### Run Tests
|
||||
|
||||
```bash
|
||||
# All tests
|
||||
go test ./plugin/email/... -v
|
||||
|
||||
# With coverage
|
||||
go test ./plugin/email/... -v -cover
|
||||
|
||||
# With race detector
|
||||
go test ./plugin/email/... -race
|
||||
```
|
||||
|
||||
### Manual Testing
|
||||
|
||||
Create a simple test program:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"github.com/usememos/memos/plugin/email"
|
||||
)
|
||||
|
||||
func main() {
|
||||
config := &email.Config{
|
||||
SMTPHost: "smtp.gmail.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "your-email@gmail.com",
|
||||
SMTPPassword: "your-app-password",
|
||||
FromEmail: "your-email@gmail.com",
|
||||
FromName: "Test",
|
||||
UseTLS: true,
|
||||
}
|
||||
|
||||
message := &email.Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Test Email",
|
||||
Body: "This is a test email from Memos email plugin.",
|
||||
}
|
||||
|
||||
if err := email.Send(config, message); err != nil {
|
||||
log.Fatalf("Failed to send email: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Email sent successfully!")
|
||||
}
|
||||
```
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### 1. Use TLS/SSL Encryption
|
||||
|
||||
Always enable encryption in production:
|
||||
|
||||
```go
|
||||
// STARTTLS (port 587) - Recommended
|
||||
config.UseTLS = true
|
||||
|
||||
// SSL/TLS (port 465)
|
||||
config.UseSSL = true
|
||||
```
|
||||
|
||||
### 2. Secure Credential Storage
|
||||
|
||||
Never hardcode credentials. Use environment variables:
|
||||
|
||||
```go
|
||||
import "os"
|
||||
|
||||
config := &email.Config{
|
||||
SMTPHost: os.Getenv("SMTP_HOST"),
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: os.Getenv("SMTP_USERNAME"),
|
||||
SMTPPassword: os.Getenv("SMTP_PASSWORD"),
|
||||
FromEmail: os.Getenv("SMTP_FROM_EMAIL"),
|
||||
UseTLS: true,
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Use App-Specific Passwords
|
||||
|
||||
For Gmail and similar services, use app passwords instead of your main account password.
|
||||
|
||||
### 4. Validate and Sanitize Input
|
||||
|
||||
Always validate email addresses and sanitize content:
|
||||
|
||||
```go
|
||||
// Validate before sending
|
||||
if err := message.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Implement Rate Limiting
|
||||
|
||||
Prevent abuse by limiting email sending:
|
||||
|
||||
```go
|
||||
// Example using golang.org/x/time/rate
|
||||
limiter := rate.NewLimiter(rate.Every(time.Second), 10) // 10 emails per second
|
||||
|
||||
if !limiter.Allow() {
|
||||
return errors.New("rate limit exceeded")
|
||||
}
|
||||
```
|
||||
|
||||
### 6. Monitor and Log
|
||||
|
||||
Log email sending activity for security monitoring:
|
||||
|
||||
```go
|
||||
if err := email.Send(config, message); err != nil {
|
||||
slog.Error("Email send failed",
|
||||
slog.String("recipient", message.To[0]),
|
||||
slog.Any("error", err))
|
||||
}
|
||||
```
|
||||
|
||||
## Common Ports
|
||||
|
||||
| Port | Protocol | Security | Use Case |
|
||||
|------|----------|----------|----------|
|
||||
| **587** | SMTP + STARTTLS | Encrypted | **Recommended** for most providers |
|
||||
| **465** | SMTP over SSL/TLS | Encrypted | Alternative secure option |
|
||||
| **25** | SMTP | Unencrypted | Legacy, often blocked by ISPs |
|
||||
| **2525** | SMTP + STARTTLS | Encrypted | Alternative when 587 is blocked |
|
||||
|
||||
**Port 587 (STARTTLS)** is the recommended standard for modern SMTP:
|
||||
```go
|
||||
config := &email.Config{
|
||||
SMTPPort: 587,
|
||||
UseTLS: true,
|
||||
}
|
||||
```
|
||||
|
||||
**Port 465 (SSL/TLS)** is the alternative:
|
||||
```go
|
||||
config := &email.Config{
|
||||
SMTPPort: 465,
|
||||
UseSSL: true,
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The package provides detailed, contextual errors:
|
||||
|
||||
```go
|
||||
err := email.Send(config, message)
|
||||
if err != nil {
|
||||
// Error messages include context:
|
||||
switch {
|
||||
case strings.Contains(err.Error(), "invalid email configuration"):
|
||||
// Configuration error (missing host, invalid port, etc.)
|
||||
log.Printf("Configuration error: %v", err)
|
||||
|
||||
case strings.Contains(err.Error(), "invalid email message"):
|
||||
// Message validation error (missing recipients, subject, body)
|
||||
log.Printf("Message error: %v", err)
|
||||
|
||||
case strings.Contains(err.Error(), "authentication failed"):
|
||||
// SMTP authentication failed (wrong credentials)
|
||||
log.Printf("Auth error: %v", err)
|
||||
|
||||
case strings.Contains(err.Error(), "failed to connect"):
|
||||
// Network/connection error
|
||||
log.Printf("Connection error: %v", err)
|
||||
|
||||
case strings.Contains(err.Error(), "recipient rejected"):
|
||||
// SMTP server rejected recipient
|
||||
log.Printf("Recipient error: %v", err)
|
||||
|
||||
default:
|
||||
log.Printf("Unknown error: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Common Error Messages
|
||||
|
||||
```
|
||||
❌ "invalid email configuration: SMTP host is required"
|
||||
→ Fix: Set config.SMTPHost
|
||||
|
||||
❌ "invalid email configuration: SMTP port must be between 1 and 65535"
|
||||
→ Fix: Set valid config.SMTPPort (usually 587 or 465)
|
||||
|
||||
❌ "invalid email configuration: from email is required"
|
||||
→ Fix: Set config.FromEmail
|
||||
|
||||
❌ "invalid email message: at least one recipient is required"
|
||||
→ Fix: Set message.To with at least one email address
|
||||
|
||||
❌ "invalid email message: subject is required"
|
||||
→ Fix: Set message.Subject
|
||||
|
||||
❌ "invalid email message: body is required"
|
||||
→ Fix: Set message.Body
|
||||
|
||||
❌ "SMTP authentication failed"
|
||||
→ Fix: Check credentials (username/password)
|
||||
|
||||
❌ "failed to connect to SMTP server"
|
||||
→ Fix: Verify host/port, check firewall, ensure TLS/SSL settings match server
|
||||
```
|
||||
|
||||
### Async Error Handling
|
||||
|
||||
For async sending, errors are logged automatically:
|
||||
|
||||
```go
|
||||
email.SendAsync(config, message)
|
||||
// Errors logged as:
|
||||
// [WARN] Failed to send email asynchronously recipients=user@example.com error=...
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Required
|
||||
|
||||
- **Go 1.25+**
|
||||
- Standard library: `net/smtp`, `crypto/tls`
|
||||
- `github.com/pkg/errors` - Error wrapping with context
|
||||
|
||||
### No External SMTP Libraries
|
||||
|
||||
This plugin uses Go's standard `net/smtp` library for maximum compatibility and minimal dependencies.
|
||||
|
||||
## API Reference
|
||||
|
||||
### Types
|
||||
|
||||
#### `Config`
|
||||
```go
|
||||
type Config struct {
|
||||
SMTPHost string // SMTP server hostname
|
||||
SMTPPort int // SMTP server port
|
||||
SMTPUsername string // SMTP auth username
|
||||
SMTPPassword string // SMTP auth password
|
||||
FromEmail string // From email address
|
||||
FromName string // From display name (optional)
|
||||
UseTLS bool // Enable STARTTLS (port 587)
|
||||
UseSSL bool // Enable SSL/TLS (port 465)
|
||||
}
|
||||
```
|
||||
|
||||
#### `Message`
|
||||
```go
|
||||
type Message struct {
|
||||
To []string // Recipients
|
||||
Cc []string // CC recipients (optional)
|
||||
Bcc []string // BCC recipients (optional)
|
||||
Subject string // Email subject
|
||||
Body string // Email body (plain text or HTML)
|
||||
IsHTML bool // true for HTML, false for plain text
|
||||
ReplyTo string // Reply-To address (optional)
|
||||
}
|
||||
```
|
||||
|
||||
### Functions
|
||||
|
||||
#### `Send(config *Config, message *Message) error`
|
||||
Sends an email synchronously. Blocks until email is sent or error occurs.
|
||||
|
||||
#### `SendAsync(config *Config, message *Message)`
|
||||
Sends an email asynchronously in a goroutine. Returns immediately. Errors are logged.
|
||||
|
||||
#### `NewClient(config *Config) *Client`
|
||||
Creates a new SMTP client for advanced usage.
|
||||
|
||||
#### `Client.Send(message *Message) error`
|
||||
Sends email using the client's configuration.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
plugin/email/
|
||||
├── config.go # SMTP configuration types
|
||||
├── message.go # Email message types and formatting
|
||||
├── client.go # SMTP client implementation
|
||||
├── email.go # High-level Send/SendAsync API
|
||||
├── doc.go # Package documentation
|
||||
└── *_test.go # Unit tests
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Part of the Memos project. See main repository for license details.
|
||||
|
||||
## Contributing
|
||||
|
||||
This plugin follows the Memos contribution guidelines. Please ensure:
|
||||
|
||||
1. All code is tested (TDD approach)
|
||||
2. Tests pass: `go test ./plugin/email/... -v`
|
||||
3. Code is formatted: `go fmt ./plugin/email/...`
|
||||
4. No linting errors: `golangci-lint run ./plugin/email/...`
|
||||
|
||||
## Support
|
||||
|
||||
For issues and questions:
|
||||
|
||||
- Memos GitHub Issues: https://github.com/usememos/memos/issues
|
||||
- Memos Documentation: https://usememos.com/docs
|
||||
|
||||
## Roadmap
|
||||
|
||||
Future enhancements may include:
|
||||
|
||||
- Email template system
|
||||
- Attachment support
|
||||
- Inline image embedding
|
||||
- Email queuing system
|
||||
- Delivery status tracking
|
||||
- Bounce handling
|
||||
143
plugin/email/client.go
Normal file
143
plugin/email/client.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/smtp"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Client represents an SMTP email client.
type Client struct {
	// config holds the SMTP connection and sender settings. It may be nil
	// at construction time; Send validates it via validateConfig before use.
	config *Config
}
|
||||
|
||||
// NewClient creates a new email client with the given configuration.
|
||||
func NewClient(config *Config) *Client {
|
||||
return &Client{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// validateConfig validates the client configuration.
|
||||
func (c *Client) validateConfig() error {
|
||||
if c.config == nil {
|
||||
return errors.New("email configuration is required")
|
||||
}
|
||||
return c.config.Validate()
|
||||
}
|
||||
|
||||
// createAuth creates an SMTP auth mechanism if credentials are provided.
|
||||
func (c *Client) createAuth() smtp.Auth {
|
||||
if c.config.SMTPUsername == "" && c.config.SMTPPassword == "" {
|
||||
return nil
|
||||
}
|
||||
return smtp.PlainAuth("", c.config.SMTPUsername, c.config.SMTPPassword, c.config.SMTPHost)
|
||||
}
|
||||
|
||||
// createTLSConfig creates a TLS configuration for secure connections.
|
||||
func (c *Client) createTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
ServerName: c.config.SMTPHost,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
// Send sends an email message via SMTP.
|
||||
func (c *Client) Send(message *Message) error {
|
||||
// Validate configuration
|
||||
if err := c.validateConfig(); err != nil {
|
||||
return errors.Wrap(err, "invalid email configuration")
|
||||
}
|
||||
|
||||
// Validate message
|
||||
if message == nil {
|
||||
return errors.New("message is required")
|
||||
}
|
||||
if err := message.Validate(); err != nil {
|
||||
return errors.Wrap(err, "invalid email message")
|
||||
}
|
||||
|
||||
// Format the message
|
||||
body := message.Format(c.config.FromEmail, c.config.FromName)
|
||||
|
||||
// Get all recipients
|
||||
recipients := message.GetAllRecipients()
|
||||
|
||||
// Create auth
|
||||
auth := c.createAuth()
|
||||
|
||||
// Send based on encryption type
|
||||
if c.config.UseSSL {
|
||||
return c.sendWithSSL(auth, recipients, body)
|
||||
}
|
||||
return c.sendWithTLS(auth, recipients, body)
|
||||
}
|
||||
|
||||
// sendWithTLS sends email using STARTTLS (port 587).
|
||||
func (c *Client) sendWithTLS(auth smtp.Auth, recipients []string, body string) error {
|
||||
serverAddr := c.config.GetServerAddress()
|
||||
|
||||
if c.config.UseTLS {
|
||||
// Use STARTTLS
|
||||
return smtp.SendMail(serverAddr, auth, c.config.FromEmail, recipients, []byte(body))
|
||||
}
|
||||
|
||||
// Send without encryption (not recommended)
|
||||
return smtp.SendMail(serverAddr, auth, c.config.FromEmail, recipients, []byte(body))
|
||||
}
|
||||
|
||||
// sendWithSSL sends email using SSL/TLS (port 465).
//
// Unlike STARTTLS, the TLS handshake happens before any SMTP traffic
// ("implicit TLS"), so the session is encrypted from the first byte. The
// SMTP conversation is then driven manually: AUTH, MAIL FROM, one RCPT TO
// per recipient, and finally DATA with the pre-formatted body.
func (c *Client) sendWithSSL(auth smtp.Auth, recipients []string, body string) error {
	serverAddr := c.config.GetServerAddress()

	// Create TLS connection (handshake precedes the SMTP greeting).
	tlsConfig := c.createTLSConfig()
	conn, err := tls.Dial("tcp", serverAddr, tlsConfig)
	if err != nil {
		return errors.Wrapf(err, "failed to connect to SMTP server with SSL: %s", serverAddr)
	}
	defer conn.Close()

	// Create SMTP client over the already-encrypted connection.
	client, err := smtp.NewClient(conn, c.config.SMTPHost)
	if err != nil {
		return errors.Wrap(err, "failed to create SMTP client")
	}
	// NOTE(review): Quit also closes the underlying connection, making the
	// deferred conn.Close above a harmless no-op afterwards.
	defer client.Quit()

	// Authenticate (skipped when createAuth returned nil, i.e. no
	// credentials were configured).
	if auth != nil {
		if err := client.Auth(auth); err != nil {
			return errors.Wrap(err, "SMTP authentication failed")
		}
	}

	// Set sender (envelope MAIL FROM).
	if err := client.Mail(c.config.FromEmail); err != nil {
		return errors.Wrap(err, "failed to set sender")
	}

	// Set recipients: one RCPT TO per address returned by
	// Message.GetAllRecipients; the first rejection aborts the send.
	for _, recipient := range recipients {
		if err := client.Rcpt(recipient); err != nil {
			return errors.Wrapf(err, "failed to set recipient: %s", recipient)
		}
	}

	// Send message body via the DATA command.
	writer, err := client.Data()
	if err != nil {
		return errors.Wrap(err, "failed to send DATA command")
	}

	if _, err := writer.Write([]byte(body)); err != nil {
		return errors.Wrap(err, "failed to write message body")
	}

	// Closing the DATA writer completes the transaction; the server's
	// final accept/reject response surfaces as this error.
	if err := writer.Close(); err != nil {
		return errors.Wrap(err, "failed to close message writer")
	}

	return nil
}
|
||||
121
plugin/email/client_test.go
Normal file
121
plugin/email/client_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
config := &Config{
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "user@example.com",
|
||||
SMTPPassword: "password",
|
||||
FromEmail: "noreply@example.com",
|
||||
FromName: "Test App",
|
||||
UseTLS: true,
|
||||
}
|
||||
|
||||
client := NewClient(config)
|
||||
|
||||
assert.NotNil(t, client)
|
||||
assert.Equal(t, config, client.config)
|
||||
}
|
||||
|
||||
// TestClientValidateConfig covers the three validateConfig outcomes: a
// complete config passes, a nil config errors, and a config missing
// required fields is rejected by Config.Validate.
func TestClientValidateConfig(t *testing.T) {
	tests := []struct {
		name    string
		config  *Config
		wantErr bool
	}{
		{
			name: "valid config",
			config: &Config{
				SMTPHost:  "smtp.example.com",
				SMTPPort:  587,
				FromEmail: "test@example.com",
			},
			wantErr: false,
		},
		{
			// nil config must be caught before Config.Validate is reached.
			name:    "nil config",
			config:  nil,
			wantErr: true,
		},
		{
			// missing SMTPHost should fail field-level validation.
			name: "invalid config",
			config: &Config{
				SMTPHost: "",
				SMTPPort: 587,
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client := NewClient(tt.config)
			err := client.validateConfig()
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
|
||||
|
||||
func TestClientSendValidation(t *testing.T) {
|
||||
config := &Config{
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
FromEmail: "test@example.com",
|
||||
}
|
||||
client := NewClient(config)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message *Message
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid message",
|
||||
message: &Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Test",
|
||||
Body: "Test body",
|
||||
},
|
||||
wantErr: false, // Will fail on actual send, but passes validation
|
||||
},
|
||||
{
|
||||
name: "nil message",
|
||||
message: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid message",
|
||||
message: &Message{
|
||||
To: []string{},
|
||||
Subject: "Test",
|
||||
Body: "Test",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := client.Send(tt.message)
|
||||
// We expect validation errors for invalid messages
|
||||
// For valid messages, we'll get connection errors (which is expected in tests)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
// Should fail validation before attempting connection
|
||||
assert.NotContains(t, err.Error(), "dial")
|
||||
}
|
||||
// Note: We don't assert NoError for valid messages because
|
||||
// we don't have a real SMTP server in tests
|
||||
})
|
||||
}
|
||||
}
|
||||
47
plugin/email/config.go
Normal file
47
plugin/email/config.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Config represents the SMTP configuration for email sending.
|
||||
// These settings should be provided by the self-hosted instance administrator.
|
||||
type Config struct {
|
||||
// SMTPHost is the SMTP server hostname (e.g., "smtp.gmail.com")
|
||||
SMTPHost string
|
||||
// SMTPPort is the SMTP server port (common: 587 for TLS, 465 for SSL, 25 for unencrypted)
|
||||
SMTPPort int
|
||||
// SMTPUsername is the SMTP authentication username (usually the email address)
|
||||
SMTPUsername string
|
||||
// SMTPPassword is the SMTP authentication password or app-specific password
|
||||
SMTPPassword string
|
||||
// FromEmail is the email address that will appear in the "From" field
|
||||
FromEmail string
|
||||
// FromName is the display name that will appear in the "From" field
|
||||
FromName string
|
||||
// UseTLS enables STARTTLS encryption (recommended for port 587)
|
||||
UseTLS bool
|
||||
// UseSSL enables SSL/TLS encryption (for port 465)
|
||||
UseSSL bool
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid.
|
||||
func (c *Config) Validate() error {
|
||||
if c.SMTPHost == "" {
|
||||
return errors.New("SMTP host is required")
|
||||
}
|
||||
if c.SMTPPort <= 0 || c.SMTPPort > 65535 {
|
||||
return errors.New("SMTP port must be between 1 and 65535")
|
||||
}
|
||||
if c.FromEmail == "" {
|
||||
return errors.New("from email is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServerAddress returns the SMTP server address in the format "host:port".
|
||||
func (c *Config) GetServerAddress() string {
|
||||
return fmt.Sprintf("%s:%d", c.SMTPHost, c.SMTPPort)
|
||||
}
|
||||
80
plugin/email/config_test.go
Normal file
80
plugin/email/config_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: &Config{
|
||||
SMTPHost: "smtp.gmail.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "user@example.com",
|
||||
SMTPPassword: "password",
|
||||
FromEmail: "noreply@example.com",
|
||||
FromName: "Memos",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
config: &Config{
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "user@example.com",
|
||||
SMTPPassword: "password",
|
||||
FromEmail: "noreply@example.com",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
config: &Config{
|
||||
SMTPHost: "smtp.gmail.com",
|
||||
SMTPPort: 0,
|
||||
SMTPUsername: "user@example.com",
|
||||
SMTPPassword: "password",
|
||||
FromEmail: "noreply@example.com",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing from email",
|
||||
config: &Config{
|
||||
SMTPHost: "smtp.gmail.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUsername: "user@example.com",
|
||||
SMTPPassword: "password",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigGetServerAddress(t *testing.T) {
|
||||
config := &Config{
|
||||
SMTPHost: "smtp.gmail.com",
|
||||
SMTPPort: 587,
|
||||
}
|
||||
|
||||
expected := "smtp.gmail.com:587"
|
||||
assert.Equal(t, expected, config.GetServerAddress())
|
||||
}
|
||||
98
plugin/email/doc.go
Normal file
98
plugin/email/doc.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// Package email provides SMTP email sending functionality for self-hosted Memos instances.
|
||||
//
|
||||
// This package is designed for self-hosted environments where instance administrators
|
||||
// configure their own SMTP servers. It follows industry-standard patterns used by
|
||||
// platforms like GitHub, GitLab, and Discourse.
|
||||
//
|
||||
// # Configuration
|
||||
//
|
||||
// The package requires SMTP server configuration provided by the instance administrator:
|
||||
//
|
||||
// config := &email.Config{
|
||||
// SMTPHost: "smtp.gmail.com",
|
||||
// SMTPPort: 587,
|
||||
// SMTPUsername: "your-email@gmail.com",
|
||||
// SMTPPassword: "your-app-password",
|
||||
// FromEmail: "noreply@yourdomain.com",
|
||||
// FromName: "Memos Notifications",
|
||||
// UseTLS: true,
|
||||
// }
|
||||
//
|
||||
// # Common SMTP Settings
|
||||
//
|
||||
// Gmail (requires App Password):
|
||||
// - Host: smtp.gmail.com
|
||||
// - Port: 587 (TLS) or 465 (SSL)
|
||||
// - Username: your-email@gmail.com
|
||||
// - UseTLS: true (for port 587) or UseSSL: true (for port 465)
|
||||
//
|
||||
// SendGrid:
|
||||
// - Host: smtp.sendgrid.net
|
||||
// - Port: 587
|
||||
// - Username: apikey
|
||||
// - Password: your-sendgrid-api-key
|
||||
// - UseTLS: true
|
||||
//
|
||||
// AWS SES:
|
||||
// - Host: email-smtp.[region].amazonaws.com
|
||||
// - Port: 587
|
||||
// - Username: your-smtp-username
|
||||
// - Password: your-smtp-password
|
||||
// - UseTLS: true
|
||||
//
|
||||
// Mailgun:
|
||||
// - Host: smtp.mailgun.org
|
||||
// - Port: 587
|
||||
// - Username: your-mailgun-smtp-username
|
||||
// - Password: your-mailgun-smtp-password
|
||||
// - UseTLS: true
|
||||
//
|
||||
// # Sending Email
|
||||
//
|
||||
// Synchronous (waits for completion):
|
||||
//
|
||||
// message := &email.Message{
|
||||
// To: []string{"user@example.com"},
|
||||
// Subject: "Welcome to Memos",
|
||||
// Body: "Thank you for joining!",
|
||||
// IsHTML: false,
|
||||
// }
|
||||
//
|
||||
// err := email.Send(config, message)
|
||||
// if err != nil {
|
||||
// // Handle error
|
||||
// }
|
||||
//
|
||||
// Asynchronous (returns immediately):
|
||||
//
|
||||
// email.SendAsync(config, message)
|
||||
// // Errors are logged but not returned
|
||||
//
|
||||
// # HTML Email
|
||||
//
|
||||
// message := &email.Message{
|
||||
// To: []string{"user@example.com"},
|
||||
// Subject: "Welcome!",
|
||||
// Body: "<html><body><h1>Welcome to Memos!</h1></body></html>",
|
||||
// IsHTML: true,
|
||||
// }
|
||||
//
|
||||
// # Security Considerations
|
||||
//
|
||||
// - Always use TLS (port 587) or SSL (port 465) for production
|
||||
// - Store SMTP credentials securely (environment variables or secrets management)
|
||||
// - Use app-specific passwords for services like Gmail
|
||||
// - Validate and sanitize email content to prevent injection attacks
|
||||
// - Rate limit email sending to prevent abuse
|
||||
//
|
||||
// # Error Handling
|
||||
//
|
||||
// The package returns descriptive errors for common issues:
|
||||
// - Configuration validation errors (missing host, invalid port, etc.)
|
||||
// - Message validation errors (missing recipients, subject, or body)
|
||||
// - Connection errors (cannot reach SMTP server)
|
||||
// - Authentication errors (invalid credentials)
|
||||
// - SMTP protocol errors (recipient rejected, etc.)
|
||||
//
|
||||
// All errors are wrapped with context using github.com/pkg/errors for better debugging.
|
||||
package email
|
||||
43
plugin/email/email.go
Normal file
43
plugin/email/email.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Send sends an email synchronously.
|
||||
// Returns an error if the email fails to send.
|
||||
func Send(config *Config, message *Message) error {
|
||||
if config == nil {
|
||||
return errors.New("email configuration is required")
|
||||
}
|
||||
if message == nil {
|
||||
return errors.New("email message is required")
|
||||
}
|
||||
|
||||
client := NewClient(config)
|
||||
return client.Send(message)
|
||||
}
|
||||
|
||||
// SendAsync sends an email asynchronously.
|
||||
// It spawns a new goroutine to handle the sending and does not wait for the response.
|
||||
// Any errors are logged but not returned.
|
||||
func SendAsync(config *Config, message *Message) {
|
||||
go func() {
|
||||
if err := Send(config, message); err != nil {
|
||||
// Since we're in a goroutine, we can only log the error
|
||||
recipients := ""
|
||||
if message != nil && len(message.To) > 0 {
|
||||
recipients = message.To[0]
|
||||
if len(message.To) > 1 {
|
||||
recipients += " and others"
|
||||
}
|
||||
}
|
||||
|
||||
slog.Warn("Failed to send email asynchronously",
|
||||
slog.String("recipients", recipients),
|
||||
slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
127
plugin/email/email_test.go
Normal file
127
plugin/email/email_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
config := &Config{
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
FromEmail: "test@example.com",
|
||||
}
|
||||
|
||||
message := &Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Test",
|
||||
Body: "Test body",
|
||||
}
|
||||
|
||||
// This will fail to connect (no real server), but should validate inputs
|
||||
err := Send(config, message)
|
||||
// We expect an error because there's no real SMTP server
|
||||
// But it should be a connection error, not a validation error
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "dial")
|
||||
}
|
||||
|
||||
func TestSendValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
message *Message
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
message: &Message{To: []string{"test@example.com"}, Subject: "Test", Body: "Test"},
|
||||
wantErr: true,
|
||||
errMsg: "configuration is required",
|
||||
},
|
||||
{
|
||||
name: "nil message",
|
||||
config: &Config{SMTPHost: "smtp.example.com", SMTPPort: 587, FromEmail: "from@example.com"},
|
||||
message: nil,
|
||||
wantErr: true,
|
||||
errMsg: "message is required",
|
||||
},
|
||||
{
|
||||
name: "invalid config",
|
||||
config: &Config{
|
||||
SMTPHost: "",
|
||||
SMTPPort: 587,
|
||||
},
|
||||
message: &Message{To: []string{"test@example.com"}, Subject: "Test", Body: "Test"},
|
||||
wantErr: true,
|
||||
errMsg: "invalid email configuration",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := Send(tt.config, tt.message)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendAsync(t *testing.T) {
|
||||
config := &Config{
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
FromEmail: "test@example.com",
|
||||
}
|
||||
|
||||
message := &Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Test Async",
|
||||
Body: "Test async body",
|
||||
}
|
||||
|
||||
// SendAsync should not block
|
||||
start := time.Now()
|
||||
SendAsync(config, message)
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should return almost immediately (< 100ms)
|
||||
assert.Less(t, duration, 100*time.Millisecond)
|
||||
|
||||
// Give goroutine time to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSendAsyncConcurrent(t *testing.T) {
|
||||
config := &Config{
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
FromEmail: "test@example.com",
|
||||
}
|
||||
|
||||
g := errgroup.Group{}
|
||||
count := 5
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
g.Go(func() error {
|
||||
message := &Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Concurrent Test",
|
||||
Body: "Test body",
|
||||
}
|
||||
SendAsync(config, message)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
t.Fatalf("SendAsync calls failed: %v", err)
|
||||
}
|
||||
}
|
||||
91
plugin/email/message.go
Normal file
91
plugin/email/message.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Message represents an email message to be sent.
type Message struct {
	To      []string // Required: recipient email addresses
	Cc      []string // Optional: carbon copy recipients
	Bcc     []string // Optional: blind carbon copy recipients
	Subject string   // Required: email subject
	Body    string   // Required: email body content
	IsHTML  bool     // Whether the body is HTML (default: false for plain text)
	ReplyTo string   // Optional: reply-to address
}

// Validate checks that the message has all required fields.
func (m *Message) Validate() error {
	if len(m.To) == 0 {
		return errors.New("at least one recipient is required")
	}
	if m.Subject == "" {
		return errors.New("subject is required")
	}
	if m.Body == "" {
		return errors.New("body is required")
	}
	return nil
}

// sanitizeHeader strips CR and LF characters from a header value so that
// attacker-influenced input (e.g. a subject or address) cannot inject
// additional message headers (RFC 5322 header injection).
func sanitizeHeader(value string) string {
	value = strings.ReplaceAll(value, "\r", "")
	return strings.ReplaceAll(value, "\n", "")
}

// Format creates an RFC 5322 formatted email message.
//
// All header values are passed through sanitizeHeader to prevent CRLF
// header injection. Bcc recipients are intentionally never written into
// the headers; they are delivered via the SMTP envelope (GetAllRecipients).
func (m *Message) Format(fromEmail, fromName string) string {
	var sb strings.Builder

	// From header
	if fromName != "" {
		sb.WriteString(fmt.Sprintf("From: %s <%s>\r\n", sanitizeHeader(fromName), sanitizeHeader(fromEmail)))
	} else {
		sb.WriteString(fmt.Sprintf("From: %s\r\n", sanitizeHeader(fromEmail)))
	}

	// To header
	sb.WriteString(fmt.Sprintf("To: %s\r\n", sanitizeHeader(strings.Join(m.To, ", "))))

	// Cc header (optional)
	if len(m.Cc) > 0 {
		sb.WriteString(fmt.Sprintf("Cc: %s\r\n", sanitizeHeader(strings.Join(m.Cc, ", "))))
	}

	// Reply-To header (optional)
	if m.ReplyTo != "" {
		sb.WriteString(fmt.Sprintf("Reply-To: %s\r\n", sanitizeHeader(m.ReplyTo)))
	}

	// Subject header
	sb.WriteString(fmt.Sprintf("Subject: %s\r\n", sanitizeHeader(m.Subject)))

	// Date header (RFC 5322 format)
	sb.WriteString(fmt.Sprintf("Date: %s\r\n", time.Now().Format(time.RFC1123Z)))

	// MIME headers
	sb.WriteString("MIME-Version: 1.0\r\n")

	// Content-Type header
	if m.IsHTML {
		sb.WriteString("Content-Type: text/html; charset=utf-8\r\n")
	} else {
		sb.WriteString("Content-Type: text/plain; charset=utf-8\r\n")
	}

	// Empty line separating headers from body
	sb.WriteString("\r\n")

	// Body
	sb.WriteString(m.Body)

	return sb.String()
}

// GetAllRecipients returns all recipients (To, Cc, Bcc) as a single slice.
// The result is a fresh slice; mutating it does not affect the message.
func (m *Message) GetAllRecipients() []string {
	recipients := make([]string, 0, len(m.To)+len(m.Cc)+len(m.Bcc))
	recipients = append(recipients, m.To...)
	recipients = append(recipients, m.Cc...)
	recipients = append(recipients, m.Bcc...)
	return recipients
}
|
||||
181
plugin/email/message_test.go
Normal file
181
plugin/email/message_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg Message
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid message",
|
||||
msg: Message{
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test Subject",
|
||||
Body: "Test Body",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no recipients",
|
||||
msg: Message{
|
||||
To: []string{},
|
||||
Subject: "Test Subject",
|
||||
Body: "Test Body",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no subject",
|
||||
msg: Message{
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "",
|
||||
Body: "Test Body",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no body",
|
||||
msg: Message{
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test Subject",
|
||||
Body: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple recipients",
|
||||
msg: Message{
|
||||
To: []string{"user1@example.com", "user2@example.com"},
|
||||
Cc: []string{"cc@example.com"},
|
||||
Subject: "Test Subject",
|
||||
Body: "Test Body",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.msg.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageFormatPlainText(t *testing.T) {
|
||||
msg := Message{
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test Subject",
|
||||
Body: "Test Body",
|
||||
IsHTML: false,
|
||||
}
|
||||
|
||||
formatted := msg.Format("sender@example.com", "Sender Name")
|
||||
|
||||
// Check required headers
|
||||
if !strings.Contains(formatted, "From: Sender Name <sender@example.com>") {
|
||||
t.Error("Missing or incorrect From header")
|
||||
}
|
||||
if !strings.Contains(formatted, "To: user@example.com") {
|
||||
t.Error("Missing or incorrect To header")
|
||||
}
|
||||
if !strings.Contains(formatted, "Subject: Test Subject") {
|
||||
t.Error("Missing or incorrect Subject header")
|
||||
}
|
||||
if !strings.Contains(formatted, "Content-Type: text/plain; charset=utf-8") {
|
||||
t.Error("Missing or incorrect Content-Type header for plain text")
|
||||
}
|
||||
if !strings.Contains(formatted, "Test Body") {
|
||||
t.Error("Missing message body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageFormatHTML(t *testing.T) {
|
||||
msg := Message{
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test Subject",
|
||||
Body: "<html><body>Test Body</body></html>",
|
||||
IsHTML: true,
|
||||
}
|
||||
|
||||
formatted := msg.Format("sender@example.com", "Sender Name")
|
||||
|
||||
// Check HTML content-type
|
||||
if !strings.Contains(formatted, "Content-Type: text/html; charset=utf-8") {
|
||||
t.Error("Missing or incorrect Content-Type header for HTML")
|
||||
}
|
||||
if !strings.Contains(formatted, "<html><body>Test Body</body></html>") {
|
||||
t.Error("Missing HTML body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageFormatMultipleRecipients(t *testing.T) {
|
||||
msg := Message{
|
||||
To: []string{"user1@example.com", "user2@example.com"},
|
||||
Cc: []string{"cc1@example.com", "cc2@example.com"},
|
||||
Bcc: []string{"bcc@example.com"},
|
||||
Subject: "Test Subject",
|
||||
Body: "Test Body",
|
||||
ReplyTo: "reply@example.com",
|
||||
}
|
||||
|
||||
formatted := msg.Format("sender@example.com", "Sender Name")
|
||||
|
||||
// Check To header formatting
|
||||
if !strings.Contains(formatted, "To: user1@example.com, user2@example.com") {
|
||||
t.Error("Missing or incorrect To header with multiple recipients")
|
||||
}
|
||||
// Check Cc header formatting
|
||||
if !strings.Contains(formatted, "Cc: cc1@example.com, cc2@example.com") {
|
||||
t.Error("Missing or incorrect Cc header")
|
||||
}
|
||||
// Bcc should NOT appear in the formatted message
|
||||
if strings.Contains(formatted, "Bcc:") {
|
||||
t.Error("Bcc header should not appear in formatted message")
|
||||
}
|
||||
// Check Reply-To header
|
||||
if !strings.Contains(formatted, "Reply-To: reply@example.com") {
|
||||
t.Error("Missing or incorrect Reply-To header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllRecipients(t *testing.T) {
|
||||
msg := Message{
|
||||
To: []string{"user1@example.com", "user2@example.com"},
|
||||
Cc: []string{"cc@example.com"},
|
||||
Bcc: []string{"bcc@example.com"},
|
||||
}
|
||||
|
||||
recipients := msg.GetAllRecipients()
|
||||
|
||||
// Should have all 4 recipients
|
||||
if len(recipients) != 4 {
|
||||
t.Errorf("GetAllRecipients() returned %d recipients, want 4", len(recipients))
|
||||
}
|
||||
|
||||
// Check all recipients are present
|
||||
expectedRecipients := map[string]bool{
|
||||
"user1@example.com": true,
|
||||
"user2@example.com": true,
|
||||
"cc@example.com": true,
|
||||
"bcc@example.com": true,
|
||||
}
|
||||
|
||||
for _, recipient := range recipients {
|
||||
if !expectedRecipients[recipient] {
|
||||
t.Errorf("Unexpected recipient: %s", recipient)
|
||||
}
|
||||
delete(expectedRecipients, recipient)
|
||||
}
|
||||
|
||||
if len(expectedRecipients) > 0 {
|
||||
t.Error("Not all expected recipients were returned")
|
||||
}
|
||||
}
|
||||
50
plugin/filter/MAINTENANCE.md
Normal file
50
plugin/filter/MAINTENANCE.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Maintaining the Memo Filter Engine
|
||||
|
||||
The engine is memo-specific; any future field or behavior changes must stay
|
||||
consistent with the memo schema and store implementations. Use this guide when
|
||||
extending or debugging the package.
|
||||
|
||||
## Adding a New Memo Field
|
||||
|
||||
1. **Update the schema**
|
||||
- Add the field entry in `schema.go`.
|
||||
- Define the backing column (`Column`), JSON path (if applicable), type, and
|
||||
allowed operators.
|
||||
- Include the CEL variable in `EnvOptions`.
|
||||
2. **Adjust parser or renderer (if needed)**
|
||||
- For non-scalar fields (JSON booleans, lists), add handling in
|
||||
`parser.go` or extend the renderer helpers.
|
||||
- Keep validation in the parser (e.g., reject unsupported operators).
|
||||
3. **Write a golden test**
|
||||
- Extend the dialect-specific memo filter tests under
|
||||
`store/db/{sqlite,mysql,postgres}/memo_filter_test.go` with a case that
|
||||
exercises the new field.
|
||||
4. **Run `go test ./...`** to ensure the SQL output matches expectations across
|
||||
all dialects.
|
||||
|
||||
## Supporting Dialect Nuances
|
||||
|
||||
- Centralize differences inside `render.go`. If a new dialect-specific behavior
|
||||
emerges (e.g., JSON operators), add the logic there rather than leaking it
|
||||
into store code.
|
||||
- Use the renderer helpers (`jsonExtractExpr`, `jsonArrayExpr`, etc.) rather than
|
||||
sprinkling ad-hoc SQL strings.
|
||||
- When placeholders change, adjust `addArg` so that argument numbering stays in
|
||||
sync with store queries.
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
- **Parser errors** – Most originate in `buildCondition` or schema validation.
|
||||
Enable logging around `parser.go` when diagnosing unknown identifier/operator
|
||||
messages.
|
||||
- **Renderer output** – Temporary printf/log statements in `renderCondition` help
|
||||
identify which IR node produced unexpected SQL.
|
||||
- **Store integration** – Ensure drivers call `filter.DefaultEngine()` exactly once
|
||||
per process; the singleton caches the parsed CEL environment.
|
||||
|
||||
## Testing Checklist
|
||||
|
||||
- `go test ./store/...` ensures all dialect tests consume the engine correctly.
|
||||
- Add targeted unit tests whenever new IR nodes or renderer paths are introduced.
|
||||
- When changing boolean or JSON handling, verify all three dialect test suites
|
||||
(SQLite, MySQL, Postgres) to avoid regression.
|
||||
63
plugin/filter/README.md
Normal file
63
plugin/filter/README.md
Normal file
@@ -0,0 +1,63 @@
|
||||
# Memo Filter Engine
|
||||
|
||||
This package houses the memo-only filter engine that turns CEL expressions into
|
||||
SQL fragments. The engine follows a three-phase pipeline inspired by systems
|
||||
such as Calcite or Prisma:
|
||||
|
||||
1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against
|
||||
the memo-specific environment declared in `schema.go`. Only fields that
|
||||
exist in the schema can surface in the filter.
|
||||
2. **Normalization** – the raw CEL AST is converted into an intermediate
|
||||
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
|
||||
conditions (logical operators, comparisons, list membership, etc.). This
|
||||
step enforces schema rules (e.g. operator compatibility, type checks).
|
||||
3. **Rendering** – the renderer in `render.go` walks the IR and produces a SQL
|
||||
fragment plus placeholder arguments tailored to a target dialect
|
||||
(`sqlite`, `mysql`, or `postgres`). Dialect differences such as JSON access,
|
||||
boolean semantics, placeholders, and `LIKE` vs `ILIKE` are encapsulated in
|
||||
renderer helpers.
|
||||
|
||||
The entry point is `filter.DefaultEngine()` from `engine.go`. It lazily constructs
|
||||
an `Engine` configured with the memo schema and exposes:
|
||||
|
||||
```go
|
||||
engine, _ := filter.DefaultEngine()
|
||||
stmt, _ := engine.CompileToStatement(ctx, `has_task_list && visibility == "PUBLIC"`, filter.RenderOptions{
|
||||
Dialect: filter.DialectPostgres,
|
||||
})
|
||||
// stmt.SQL -> "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.visibility = $1)"
|
||||
// stmt.Args -> ["PUBLIC"]
|
||||
```
|
||||
|
||||
## Core Files
|
||||
|
||||
| File | Responsibility |
|
||||
| ------------- | ------------------------------------------------------------------------------- |
|
||||
| `schema.go` | Declares memo fields, their types, backing columns, CEL environment options |
|
||||
| `ir.go` | IR node definitions used across the pipeline |
|
||||
| `parser.go` | Converts CEL `Expr` into IR while applying schema validation |
|
||||
| `render.go` | Translates IR into SQL, handling dialect-specific behavior |
|
||||
| `engine.go` | Glue between the phases; exposes `Compile`, `CompileToStatement`, and `DefaultEngine` |
|
||||
| `helpers.go` | Convenience helpers for store integration (appending conditions) |
|
||||
|
||||
## SQL Generation Notes
|
||||
|
||||
- **Placeholders** — `?` is used for SQLite/MySQL, `$n` for Postgres. The renderer
|
||||
tracks offsets to compose queries with pre-existing arguments.
|
||||
- **JSON Fields** — Memo metadata lives in `memo.payload`. The renderer handles
|
||||
`JSON_EXTRACT`/`json_extract`/`->`/`->>` variations and boolean coercion.
|
||||
- **Tag Operations** — `tag in [...]` and `"tag" in tags` become JSON array
|
||||
predicates. SQLite uses `LIKE` patterns, MySQL uses `JSON_CONTAINS`, and
|
||||
Postgres uses `@>`.
|
||||
- **Boolean Flags** — Fields such as `has_task_list` render as `IS TRUE` equality
|
||||
checks, or comparisons against `CAST('true' AS JSON)` depending on the dialect.
|
||||
|
||||
## Typical Integration
|
||||
|
||||
1. Fetch the engine with `filter.DefaultEngine()`.
|
||||
2. Call `CompileToStatement` using the appropriate dialect enum.
|
||||
3. Append the emitted SQL fragment/args to the existing `WHERE` clause.
|
||||
4. Execute the resulting query through the store driver.
|
||||
|
||||
The `helpers.AppendConditions` helper encapsulates steps 2–3 when a driver needs
|
||||
to process an array of filters.
|
||||
191
plugin/filter/engine.go
Normal file
191
plugin/filter/engine.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Engine parses CEL filters into a dialect-agnostic condition tree.
type Engine struct {
	schema Schema   // memo/attachment field schema driving validation
	env    *cel.Env // CEL environment built from schema.EnvOptions
}

// NewEngine builds a new Engine for the provided schema.
// The schema's EnvOptions declare which CEL variables (fields) are legal,
// so filters referencing unknown fields fail at compile time.
func NewEngine(schema Schema) (*Engine, error) {
	env, err := cel.NewEnv(schema.EnvOptions...)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create CEL environment")
	}
	return &Engine{
		schema: schema,
		env:    env,
	}, nil
}
|
||||
|
||||
// Program stores a compiled filter condition.
// It pairs the dialect-agnostic IR with the schema it was validated
// against so it can later be rendered for any SQL dialect.
type Program struct {
	schema    Schema
	condition Condition
}

// ConditionTree exposes the underlying condition tree.
func (p *Program) ConditionTree() Condition {
	return p.condition
}
|
||||
|
||||
// Compile parses the filter string into an executable program.
// The context parameter is currently unused but kept for API stability.
func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
	if strings.TrimSpace(filter) == "" {
		return nil, errors.New("filter expression is empty")
	}

	// Rewrite legacy bare-numeric logical operands (e.g. `a && 1`)
	// into boolean comparisons so old saved filters still type-check.
	filter = normalizeLegacyFilter(filter)

	ast, issues := e.env.Compile(filter)
	if issues != nil && issues.Err() != nil {
		return nil, errors.Wrap(issues.Err(), "failed to compile filter")
	}
	parsed, err := cel.AstToParsedExpr(ast)
	if err != nil {
		return nil, errors.Wrap(err, "failed to convert AST")
	}

	// Normalize the CEL AST into the dialect-agnostic IR; schema rules
	// (known fields, allowed operators) are enforced here.
	cond, err := buildCondition(parsed.GetExpr(), e.schema)
	if err != nil {
		return nil, err
	}

	return &Program{
		schema:    e.schema,
		condition: cond,
	}, nil
}
|
||||
|
||||
// CompileToStatement compiles and renders the filter in a single step.
func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) {
	program, err := e.Compile(ctx, filter)
	if err != nil {
		return Statement{}, err
	}
	return program.Render(opts)
}

// RenderOptions configure SQL rendering.
type RenderOptions struct {
	// Dialect selects placeholder style and JSON/boolean handling.
	Dialect DialectName
	// PlaceholderOffset shifts numbered placeholders (e.g. Postgres $n)
	// when the fragment is appended after pre-existing query arguments.
	PlaceholderOffset int
	// DisableNullChecks skips the NULL-guard wrapping in rendered SQL.
	DisableNullChecks bool
}

// Statement contains the rendered SQL fragment and its args.
type Statement struct {
	SQL  string
	Args []any
}

// Render converts the program into a dialect-specific SQL fragment.
func (p *Program) Render(opts RenderOptions) (Statement, error) {
	renderer := newRenderer(p.schema, opts)
	return renderer.Render(p.condition)
}
|
||||
|
||||
var (
	// defaultOnce guards lazy construction of the memo engine singleton.
	defaultOnce sync.Once
	defaultInst *Engine
	defaultErr  error
	// defaultAttachmentOnce guards the attachment engine singleton.
	defaultAttachmentOnce sync.Once
	defaultAttachmentInst *Engine
	defaultAttachmentErr  error
)

// DefaultEngine returns the process-wide memo filter engine.
// Construction runs once; both the engine and any construction error are
// cached and returned on every subsequent call.
func DefaultEngine() (*Engine, error) {
	defaultOnce.Do(func() {
		defaultInst, defaultErr = NewEngine(NewSchema())
	})
	return defaultInst, defaultErr
}

// DefaultAttachmentEngine returns the process-wide attachment filter engine.
func DefaultAttachmentEngine() (*Engine, error) {
	defaultAttachmentOnce.Do(func() {
		defaultAttachmentInst, defaultAttachmentErr = NewEngine(NewAttachmentSchema())
	})
	return defaultAttachmentInst, defaultAttachmentErr
}
|
||||
|
||||
func normalizeLegacyFilter(expr string) string {
|
||||
expr = rewriteNumericLogicalOperand(expr, "&&")
|
||||
expr = rewriteNumericLogicalOperand(expr, "||")
|
||||
return expr
|
||||
}
|
||||
|
||||
// rewriteNumericLogicalOperand scans expr for occurrences of the logical
// operator op ("&&" or "||") that are immediately followed (after optional
// spaces/tabs and an optional sign) by a bare integer literal, and rewrites
// the literal as an explicit boolean test, e.g. `a && 1` -> `a && (1 != 0)`.
// Content inside single- or double-quoted strings, including backslash
// escapes, is copied through untouched.
func rewriteNumericLogicalOperand(expr, op string) string {
	var builder strings.Builder
	n := len(expr)
	i := 0
	var inQuote rune

	for i < n {
		ch := expr[i]

		// Inside a quoted string: copy bytes verbatim, honoring escapes.
		if inQuote != 0 {
			builder.WriteByte(ch)
			if ch == '\\' && i+1 < n {
				builder.WriteByte(expr[i+1])
				i += 2
				continue
			}
			if ch == byte(inQuote) {
				inQuote = 0
			}
			i++
			continue
		}

		// Quote openers start verbatim copying.
		if ch == '\'' || ch == '"' {
			inQuote = rune(ch)
			builder.WriteByte(ch)
			i++
			continue
		}

		if strings.HasPrefix(expr[i:], op) {
			builder.WriteString(op)
			i += len(op)

			// Preserve whitespace following the operator.
			wsStart := i
			for i < n && (expr[i] == ' ' || expr[i] == '\t') {
				i++
			}
			builder.WriteString(expr[wsStart:i])

			// Optionally consume a sign, then a run of digits.
			signStart := i
			if i < n && (expr[i] == '+' || expr[i] == '-') {
				i++
			}
			digitStart := i
			for i < n && expr[i] >= '0' && expr[i] <= '9' {
				i++
			}
			// Only rewrite when at least one digit was consumed. Previously
			// a lone sign (e.g. `a && +b`) was wrapped as `(+ != 0)`,
			// corrupting the expression; now it is emitted unchanged.
			if i > digitStart {
				numLiteral := expr[signStart:i]
				builder.WriteString(fmt.Sprintf("(%s != 0)", numLiteral))
			} else {
				builder.WriteString(expr[signStart:i])
			}
			continue
		}

		builder.WriteByte(ch)
		i++
	}

	return builder.String()
}
|
||||
25
plugin/filter/helpers.go
Normal file
25
plugin/filter/helpers.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// AppendConditions compiles the provided filters and appends the resulting SQL fragments and args.
|
||||
func AppendConditions(ctx context.Context, engine *Engine, filters []string, dialect DialectName, where *[]string, args *[]any) error {
|
||||
for _, filterStr := range filters {
|
||||
stmt, err := engine.CompileToStatement(ctx, filterStr, RenderOptions{
|
||||
Dialect: dialect,
|
||||
PlaceholderOffset: len(*args),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stmt.SQL == "" {
|
||||
continue
|
||||
}
|
||||
*where = append(*where, fmt.Sprintf("(%s)", stmt.SQL))
|
||||
*args = append(*args, stmt.Args...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
159
plugin/filter/ir.go
Normal file
159
plugin/filter/ir.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package filter
|
||||
|
||||
// Condition represents a boolean expression derived from the CEL filter.
// Implementations form the intermediate representation (IR) produced by the
// parser and consumed by the SQL renderer.
type Condition interface {
	isCondition()
}

// LogicalOperator enumerates the supported logical operators.
type LogicalOperator string

const (
	LogicalAnd LogicalOperator = "AND"
	LogicalOr  LogicalOperator = "OR"
)

// LogicalCondition composes two conditions with a logical operator.
type LogicalCondition struct {
	Operator LogicalOperator
	Left     Condition
	Right    Condition
}

func (*LogicalCondition) isCondition() {}

// NotCondition negates a child condition.
type NotCondition struct {
	Expr Condition
}

func (*NotCondition) isCondition() {}

// FieldPredicateCondition asserts that a field evaluates to true.
// Only boolean-typed schema fields may be used this way.
type FieldPredicateCondition struct {
	Field string
}

func (*FieldPredicateCondition) isCondition() {}

// ComparisonOperator lists supported comparison operators.
// Values are the SQL spellings, emitted directly by the renderer.
type ComparisonOperator string

const (
	CompareEq  ComparisonOperator = "="
	CompareNeq ComparisonOperator = "!="
	CompareLt  ComparisonOperator = "<"
	CompareLte ComparisonOperator = "<="
	CompareGt  ComparisonOperator = ">"
	CompareGte ComparisonOperator = ">="
)

// ComparisonCondition represents a binary comparison.
type ComparisonCondition struct {
	Left     ValueExpr
	Operator ComparisonOperator
	Right    ValueExpr
}

func (*ComparisonCondition) isCondition() {}

// InCondition represents an IN predicate with literal list values.
type InCondition struct {
	Left   ValueExpr
	Values []ValueExpr
}

func (*InCondition) isCondition() {}

// ElementInCondition represents the CEL syntax `"value" in field`.
type ElementInCondition struct {
	Element ValueExpr
	Field   string
}

func (*ElementInCondition) isCondition() {}

// ContainsCondition models the <field>.contains(<value>) call.
type ContainsCondition struct {
	Field string
	Value string
}

func (*ContainsCondition) isCondition() {}

// ConstantCondition captures a literal boolean outcome.
// The renderer turns false into an unsatisfiable clause and drops true.
type ConstantCondition struct {
	Value bool
}

func (*ConstantCondition) isCondition() {}

// ValueExpr models arithmetic or scalar expressions whose result feeds a comparison.
type ValueExpr interface {
	isValueExpr()
}

// FieldRef references a named schema field.
type FieldRef struct {
	Name string
}

func (*FieldRef) isValueExpr() {}

// LiteralValue holds a literal scalar.
type LiteralValue struct {
	Value interface{}
}

func (*LiteralValue) isValueExpr() {}

// FunctionValue captures simple function calls like size(tags).
type FunctionValue struct {
	Name string
	Args []ValueExpr
}

func (*FunctionValue) isValueExpr() {}

// ListComprehensionCondition represents CEL macros like exists(), all(), filter().
// Only exists() is currently produced by the parser.
type ListComprehensionCondition struct {
	Kind      ComprehensionKind
	Field     string        // The list field to iterate over (e.g., "tags")
	IterVar   string        // The iteration variable name (e.g., "t")
	Predicate PredicateExpr // The predicate to evaluate for each element
}

func (*ListComprehensionCondition) isCondition() {}

// ComprehensionKind enumerates the types of list comprehensions.
type ComprehensionKind string

const (
	ComprehensionExists ComprehensionKind = "exists"
)

// PredicateExpr represents predicates used in comprehensions.
type PredicateExpr interface {
	isPredicateExpr()
}

// StartsWithPredicate represents t.startsWith("prefix").
type StartsWithPredicate struct {
	Prefix string
}

func (*StartsWithPredicate) isPredicateExpr() {}

// EndsWithPredicate represents t.endsWith("suffix").
type EndsWithPredicate struct {
	Suffix string
}

func (*EndsWithPredicate) isPredicateExpr() {}

// ContainsPredicate represents t.contains("substring").
type ContainsPredicate struct {
	Substring string
}

func (*ContainsPredicate) isPredicateExpr() {}
|
||||
586
plugin/filter/parser.go
Normal file
586
plugin/filter/parser.go
Normal file
@@ -0,0 +1,586 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// buildCondition lowers a CEL expression into the package's Condition IR.
// At the top level only call expressions, constants coercible to bool,
// boolean identifiers declared in the schema, and comprehension macros
// (exists()) are accepted.
func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
	switch v := expr.ExprKind.(type) {
	case *exprv1.Expr_CallExpr:
		return buildCallCondition(v.CallExpr, schema)
	case *exprv1.Expr_ConstExpr:
		val, err := getConstValue(expr)
		if err != nil {
			return nil, err
		}
		switch v := val.(type) {
		case bool:
			return &ConstantCondition{Value: v}, nil
		case int64:
			// Legacy numeric filters: any non-zero number is truthy.
			return &ConstantCondition{Value: v != 0}, nil
		case float64:
			return &ConstantCondition{Value: v != 0}, nil
		default:
			return nil, errors.New("filter must evaluate to a boolean value")
		}
	case *exprv1.Expr_IdentExpr:
		// A bare identifier is only valid if the schema declares it boolean.
		name := v.IdentExpr.GetName()
		field, ok := schema.Field(name)
		if !ok {
			return nil, errors.Errorf("unknown identifier %q", name)
		}
		if field.Type != FieldTypeBool {
			return nil, errors.Errorf("identifier %q is not boolean", name)
		}
		return &FieldPredicateCondition{Field: name}, nil
	case *exprv1.Expr_ComprehensionExpr:
		// exists()/all() macros expand to comprehension nodes in the AST.
		return buildComprehensionCondition(v.ComprehensionExpr, schema)
	default:
		return nil, errors.New("unsupported top-level expression")
	}
}
|
||||
|
||||
func buildCallCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
switch call.Function {
|
||||
case "_&&_":
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("logical AND expects two arguments")
|
||||
}
|
||||
left, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := buildCondition(call.Args[1], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LogicalCondition{
|
||||
Operator: LogicalAnd,
|
||||
Left: left,
|
||||
Right: right,
|
||||
}, nil
|
||||
case "_||_":
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("logical OR expects two arguments")
|
||||
}
|
||||
left, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := buildCondition(call.Args[1], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LogicalCondition{
|
||||
Operator: LogicalOr,
|
||||
Left: left,
|
||||
Right: right,
|
||||
}, nil
|
||||
case "!_":
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("logical NOT expects one argument")
|
||||
}
|
||||
child, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &NotCondition{Expr: child}, nil
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
return buildComparisonCondition(call, schema)
|
||||
case "@in":
|
||||
return buildInCondition(call, schema)
|
||||
case "contains":
|
||||
return buildContainsCondition(call, schema)
|
||||
default:
|
||||
val, ok, err := evaluateBool(call)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return &ConstantCondition{Value: val}, nil
|
||||
}
|
||||
return nil, errors.Errorf("unsupported call expression %q", call.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// buildComparisonCondition lowers a binary comparison call (==, !=, <, <=,
// >, >=) into a ComparisonCondition. When the left operand is a schema
// field, aliases are resolved and the field's allow-list of comparison
// operators (if any) is enforced.
func buildComparisonCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
	if len(call.Args) != 2 {
		return nil, errors.New("comparison expects two arguments")
	}
	op, err := toComparisonOperator(call.Function)
	if err != nil {
		return nil, err
	}

	left, err := buildValueExpr(call.Args[0], schema)
	if err != nil {
		return nil, err
	}
	right, err := buildValueExpr(call.Args[1], schema)
	if err != nil {
		return nil, err
	}

	// If the left side is a field, validate allowed operators.
	if field, ok := left.(*FieldRef); ok {
		def, exists := schema.Field(field.Name)
		if !exists {
			return nil, errors.Errorf("unknown identifier %q", field.Name)
		}
		if def.Kind == FieldKindVirtualAlias {
			// Aliases delegate their operator rules to the target field.
			def, exists = schema.ResolveAlias(field.Name)
			if !exists {
				return nil, errors.Errorf("invalid alias %q", field.Name)
			}
		}
		if def.AllowedComparisonOps != nil {
			// A nil map means "no restriction"; a non-nil map is an allow-list.
			if _, allowed := def.AllowedComparisonOps[op]; !allowed {
				return nil, errors.Errorf("operator %s not allowed for field %q", op, field.Name)
			}
		}
	}

	return &ComparisonCondition{
		Left:     left,
		Operator: op,
		Right:    right,
	}, nil
}
|
||||
|
||||
// buildInCondition lowers the CEL `@in` operator. Two shapes are supported:
//
//	field in [lit, lit, ...]  -> InCondition (membership in a literal list)
//	"value" in field          -> ElementInCondition (element of a list field)
//
// Anything else is rejected.
func buildInCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
	if len(call.Args) != 2 {
		return nil, errors.New("in operator expects two arguments")
	}

	// Handle identifier in list syntax.
	if identName, err := getIdentName(call.Args[0]); err == nil {
		// Validate the identifier; a virtual alias must resolve to a real field.
		if field, ok := schema.Field(identName); ok && field.Kind == FieldKindVirtualAlias {
			if _, aliasOk := schema.ResolveAlias(identName); !aliasOk {
				return nil, errors.Errorf("invalid alias %q", identName)
			}
		} else if !ok {
			return nil, errors.Errorf("unknown identifier %q", identName)
		}

		if listExpr := call.Args[1].GetListExpr(); listExpr != nil {
			values := make([]ValueExpr, 0, len(listExpr.Elements))
			for _, element := range listExpr.Elements {
				value, err := buildValueExpr(element, schema)
				if err != nil {
					return nil, err
				}
				values = append(values, value)
			}
			return &InCondition{
				Left:   &FieldRef{Name: identName},
				Values: values,
			}, nil
		}
	}

	// Handle "value in identifier" syntax.
	if identName, err := getIdentName(call.Args[1]); err == nil {
		if _, ok := schema.Field(identName); !ok {
			return nil, errors.Errorf("unknown identifier %q", identName)
		}
		element, err := buildValueExpr(call.Args[0], schema)
		if err != nil {
			return nil, err
		}
		return &ElementInCondition{
			Element: element,
			Field:   identName,
		}, nil
	}

	return nil, errors.New("invalid use of in operator")
}
|
||||
|
||||
// buildContainsCondition lowers <field>.contains("literal") into a
// ContainsCondition. The receiver must be a schema field that opts in via
// SupportsContains, and the argument must be a string literal.
func buildContainsCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
	if call.Target == nil {
		return nil, errors.New("contains requires a target")
	}
	targetName, err := getIdentName(call.Target)
	if err != nil {
		return nil, err
	}

	field, ok := schema.Field(targetName)
	if !ok {
		return nil, errors.Errorf("unknown identifier %q", targetName)
	}
	if !field.SupportsContains {
		return nil, errors.Errorf("identifier %q does not support contains()", targetName)
	}
	if len(call.Args) != 1 {
		return nil, errors.New("contains expects exactly one argument")
	}
	value, err := getConstValue(call.Args[0])
	if err != nil {
		return nil, errors.Wrap(err, "contains only supports literal arguments")
	}
	str, ok := value.(string)
	if !ok {
		return nil, errors.New("contains argument must be a string")
	}
	return &ContainsCondition{
		Field: targetName,
		Value: str,
	}, nil
}
|
||||
|
||||
// buildValueExpr lowers an expression that produces a scalar value for use
// inside a comparison. The fallback chain is ordered: schema field reference,
// literal constant, constant-foldable numeric, constant-foldable boolean,
// then the small set of supported calls (size(), now(), arithmetic).
func buildValueExpr(expr *exprv1.Expr, schema Schema) (ValueExpr, error) {
	if identName, err := getIdentName(expr); err == nil {
		if _, ok := schema.Field(identName); !ok {
			return nil, errors.Errorf("unknown identifier %q", identName)
		}
		return &FieldRef{Name: identName}, nil
	}

	if literal, err := getConstValue(expr); err == nil {
		return &LiteralValue{Value: literal}, nil
	}

	// Arithmetic over literals is folded to a single value at compile time.
	if value, ok, err := evaluateNumeric(expr); err != nil {
		return nil, err
	} else if ok {
		return &LiteralValue{Value: value}, nil
	}

	if boolVal, ok, err := evaluateBoolExpr(expr); err != nil {
		return nil, err
	} else if ok {
		return &LiteralValue{Value: boolVal}, nil
	}

	if call := expr.GetCallExpr(); call != nil {
		switch call.Function {
		case "size":
			// size(listField) — renderer maps this to a JSON array length.
			if len(call.Args) != 1 {
				return nil, errors.New("size() expects one argument")
			}
			arg, err := buildValueExpr(call.Args[0], schema)
			if err != nil {
				return nil, err
			}
			return &FunctionValue{
				Name: "size",
				Args: []ValueExpr{arg},
			}, nil
		case "now":
			// now() is frozen to the compile-time Unix timestamp.
			return &LiteralValue{Value: timeNowUnix()}, nil
		case "_+_", "_-_", "_*_":
			value, ok, err := evaluateNumeric(expr)
			if err != nil {
				return nil, err
			}
			if ok {
				return &LiteralValue{Value: value}, nil
			}
		default:
			// Fall through to error return below
		}
	}

	return nil, errors.New("unsupported value expression")
}
|
||||
|
||||
func toComparisonOperator(fn string) (ComparisonOperator, error) {
|
||||
switch fn {
|
||||
case "_==_":
|
||||
return CompareEq, nil
|
||||
case "_!=_":
|
||||
return CompareNeq, nil
|
||||
case "_<_":
|
||||
return CompareLt, nil
|
||||
case "_>_":
|
||||
return CompareGt, nil
|
||||
case "_<=_":
|
||||
return CompareLte, nil
|
||||
case "_>=_":
|
||||
return CompareGte, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported comparison operator %q", fn)
|
||||
}
|
||||
}
|
||||
|
||||
func getIdentName(expr *exprv1.Expr) (string, error) {
|
||||
if ident := expr.GetIdentExpr(); ident != nil {
|
||||
return ident.GetName(), nil
|
||||
}
|
||||
return "", errors.New("expression is not an identifier")
|
||||
}
|
||||
|
||||
func getConstValue(expr *exprv1.Expr) (interface{}, error) {
|
||||
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("expression is not a literal")
|
||||
}
|
||||
switch x := v.ConstExpr.ConstantKind.(type) {
|
||||
case *exprv1.Constant_StringValue:
|
||||
return v.ConstExpr.GetStringValue(), nil
|
||||
case *exprv1.Constant_Int64Value:
|
||||
return v.ConstExpr.GetInt64Value(), nil
|
||||
case *exprv1.Constant_Uint64Value:
|
||||
return int64(v.ConstExpr.GetUint64Value()), nil
|
||||
case *exprv1.Constant_DoubleValue:
|
||||
return v.ConstExpr.GetDoubleValue(), nil
|
||||
case *exprv1.Constant_BoolValue:
|
||||
return v.ConstExpr.GetBoolValue(), nil
|
||||
case *exprv1.Constant_NullValue:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported constant %T", x)
|
||||
}
|
||||
}
|
||||
|
||||
func evaluateBool(call *exprv1.Expr_Call) (bool, bool, error) {
|
||||
val, ok, err := evaluateBoolExpr(&exprv1.Expr{ExprKind: &exprv1.Expr_CallExpr{CallExpr: call}})
|
||||
return val, ok, err
|
||||
}
|
||||
|
||||
// evaluateBoolExpr attempts to constant-fold expr into a bool.
// It returns (value, true, nil) when folding succeeds, (false, false, nil)
// when the expression is not a foldable boolean, and a non-nil error only
// for structurally malformed NOT calls. Only literals and chains of `!` are
// folded.
func evaluateBoolExpr(expr *exprv1.Expr) (bool, bool, error) {
	if literal, err := getConstValue(expr); err == nil {
		if b, ok := literal.(bool); ok {
			return b, true, nil
		}
		// A literal of some other type is simply not a boolean constant.
		return false, false, nil
	}
	if call := expr.GetCallExpr(); call != nil && call.Function == "!_" {
		if len(call.Args) != 1 {
			return false, false, errors.New("NOT expects exactly one argument")
		}
		val, ok, err := evaluateBoolExpr(call.Args[0])
		if err != nil || !ok {
			return false, false, err
		}
		return !val, true, nil
	}
	return false, false, nil
}
|
||||
|
||||
// evaluateNumeric attempts to constant-fold expr into an int64.
// It returns (value, true, nil) on success, (0, false, nil) when the
// expression is not a foldable numeric, and a non-nil error for malformed
// arithmetic. Supported forms: integer/double literals, now(), and binary
// +, -, * over foldable operands.
func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
	if literal, err := getConstValue(expr); err == nil {
		switch v := literal.(type) {
		case int64:
			return v, true, nil
		case float64:
			// NOTE(review): double literals are truncated toward zero here —
			// confirm fractional timestamps/values are not expected to round.
			return int64(v), true, nil
		}
		return 0, false, nil
	}

	call := expr.GetCallExpr()
	if call == nil {
		return 0, false, nil
	}

	switch call.Function {
	case "now":
		// now() folds to the current Unix timestamp at compile time.
		return timeNowUnix(), true, nil
	case "_+_", "_-_", "_*_":
		if len(call.Args) != 2 {
			return 0, false, errors.New("arithmetic requires two arguments")
		}
		left, ok, err := evaluateNumeric(call.Args[0])
		if err != nil {
			return 0, false, err
		}
		if !ok {
			return 0, false, nil
		}
		right, ok, err := evaluateNumeric(call.Args[1])
		if err != nil {
			return 0, false, err
		}
		if !ok {
			return 0, false, nil
		}
		switch call.Function {
		case "_+_":
			return left + right, true, nil
		case "_-_":
			return left - right, true, nil
		case "_*_":
			return left * right, true, nil
		default:
			return 0, false, errors.Errorf("unsupported arithmetic operator %q", call.Function)
		}
	default:
		return 0, false, nil
	}
}
|
||||
|
||||
func timeNowUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
// buildComprehensionCondition handles CEL comprehension expressions (exists, all, etc.).
// It validates that the iterated range is a JSON-list schema field and then
// extracts the per-element predicate from the loop step.
func buildComprehensionCondition(comp *exprv1.Expr_Comprehension, schema Schema) (Condition, error) {
	// Determine the comprehension kind by examining the loop initialization and step
	kind, err := detectComprehensionKind(comp)
	if err != nil {
		return nil, err
	}

	// Get the field being iterated over
	iterRangeIdent := comp.IterRange.GetIdentExpr()
	if iterRangeIdent == nil {
		return nil, errors.New("comprehension range must be a field identifier")
	}
	fieldName := iterRangeIdent.GetName()

	// Validate the field
	field, ok := schema.Field(fieldName)
	if !ok {
		return nil, errors.Errorf("unknown field %q in comprehension", fieldName)
	}
	if field.Kind != FieldKindJSONList {
		return nil, errors.Errorf("field %q does not support comprehension (must be a list)", fieldName)
	}

	// Extract the predicate from the loop step
	predicate, err := extractPredicate(comp, schema)
	if err != nil {
		return nil, err
	}

	return &ListComprehensionCondition{
		Kind:      kind,
		Field:     fieldName,
		IterVar:   comp.IterVar,
		Predicate: predicate,
	}, nil
}
|
||||
|
||||
// detectComprehensionKind determines if this is an exists() macro.
// Only exists() is currently supported.
//
// The macro is identified structurally: CEL expands exists() to a loop whose
// accumulator starts at false and whose step ORs the predicate in, while
// all() starts at true and ANDs — the latter is rejected with a dedicated
// message.
func detectComprehensionKind(comp *exprv1.Expr_Comprehension) (ComprehensionKind, error) {
	// Check the accumulator initialization
	accuInit := comp.AccuInit.GetConstExpr()
	if accuInit == nil {
		return "", errors.New("comprehension accumulator must be initialized with a constant")
	}

	// exists() starts with false and uses OR (||) in loop step
	if !accuInit.GetBoolValue() {
		if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_||_" {
			return ComprehensionExists, nil
		}
	}

	// all() starts with true and uses AND (&&) - not supported
	if accuInit.GetBoolValue() {
		if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_&&_" {
			return "", errors.New("all() comprehension is not supported; use exists() instead")
		}
	}

	return "", errors.New("unsupported comprehension type; only exists() is supported")
}
|
||||
|
||||
// extractPredicate extracts the predicate expression from the comprehension loop step.
// Only startsWith, endsWith, and contains predicates on the iteration
// variable are supported.
func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, error) {
	// The loop step is: @result || predicate(t) for exists
	// or: @result && predicate(t) for all
	step := comp.LoopStep.GetCallExpr()
	if step == nil {
		return nil, errors.New("comprehension loop step must be a call expression")
	}

	if len(step.Args) != 2 {
		return nil, errors.New("comprehension loop step must have two arguments")
	}

	// The predicate is the second argument
	predicateExpr := step.Args[1]
	predicateCall := predicateExpr.GetCallExpr()
	if predicateCall == nil {
		return nil, errors.New("comprehension predicate must be a function call")
	}

	// Handle different predicate functions
	switch predicateCall.Function {
	case "startsWith":
		return buildStartsWithPredicate(predicateCall, comp.IterVar)
	case "endsWith":
		return buildEndsWithPredicate(predicateCall, comp.IterVar)
	case "contains":
		return buildContainsPredicate(predicateCall, comp.IterVar)
	default:
		return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
	}
}
|
||||
|
||||
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
|
||||
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
// Verify the target is the iteration variable
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("startsWith target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("startsWith expects exactly one argument")
|
||||
}
|
||||
|
||||
prefix, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "startsWith argument must be a constant string")
|
||||
}
|
||||
|
||||
prefixStr, ok := prefix.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("startsWith argument must be a string")
|
||||
}
|
||||
|
||||
return &StartsWithPredicate{Prefix: prefixStr}, nil
|
||||
}
|
||||
|
||||
// buildEndsWithPredicate extracts the pattern from t.endsWith("suffix").
|
||||
func buildEndsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("endsWith target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("endsWith expects exactly one argument")
|
||||
}
|
||||
|
||||
suffix, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "endsWith argument must be a constant string")
|
||||
}
|
||||
|
||||
suffixStr, ok := suffix.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("endsWith argument must be a string")
|
||||
}
|
||||
|
||||
return &EndsWithPredicate{Suffix: suffixStr}, nil
|
||||
}
|
||||
|
||||
// buildContainsPredicate extracts the pattern from t.contains("substring").
|
||||
func buildContainsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("contains target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("contains expects exactly one argument")
|
||||
}
|
||||
|
||||
substring, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "contains argument must be a constant string")
|
||||
}
|
||||
|
||||
substringStr, ok := substring.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("contains argument must be a string")
|
||||
}
|
||||
|
||||
return &ContainsPredicate{Substring: substringStr}, nil
|
||||
}
|
||||
748
plugin/filter/render.go
Normal file
748
plugin/filter/render.go
Normal file
@@ -0,0 +1,748 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// renderer walks a Condition tree and accumulates a SQL fragment plus its
// positional arguments for a single dialect. It is single-use per Render
// call state-wise: args and placeholderCounter grow as literals are bound.
type renderer struct {
	schema  Schema
	dialect DialectName
	// placeholderOffset is the number of placeholders the surrounding query
	// already uses; placeholderCounter counts the ones emitted here.
	placeholderOffset  int
	placeholderCounter int
	// args collects bound values in placeholder order.
	args []any
}
|
||||
|
||||
// renderResult is the outcome of rendering one Condition subtree.
// trivial marks an always-true subtree (no SQL needed); unsatisfiable marks
// an always-false subtree. When both flags are clear, sql holds the fragment.
type renderResult struct {
	sql           string
	trivial       bool
	unsatisfiable bool
}
|
||||
|
||||
func newRenderer(schema Schema, opts RenderOptions) *renderer {
|
||||
return &renderer{
|
||||
schema: schema,
|
||||
dialect: opts.Dialect,
|
||||
placeholderOffset: opts.PlaceholderOffset,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) Render(cond Condition) (Statement, error) {
|
||||
result, err := r.renderCondition(cond)
|
||||
if err != nil {
|
||||
return Statement{}, err
|
||||
}
|
||||
args := r.args
|
||||
if args == nil {
|
||||
args = []any{}
|
||||
}
|
||||
|
||||
switch {
|
||||
case result.unsatisfiable:
|
||||
return Statement{
|
||||
SQL: "1 = 0",
|
||||
Args: args,
|
||||
}, nil
|
||||
case result.trivial:
|
||||
return Statement{
|
||||
SQL: "",
|
||||
Args: args,
|
||||
}, nil
|
||||
default:
|
||||
return Statement{
|
||||
SQL: result.sql,
|
||||
Args: args,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
|
||||
switch c := cond.(type) {
|
||||
case *LogicalCondition:
|
||||
return r.renderLogicalCondition(c)
|
||||
case *NotCondition:
|
||||
return r.renderNotCondition(c)
|
||||
case *FieldPredicateCondition:
|
||||
return r.renderFieldPredicate(c)
|
||||
case *ComparisonCondition:
|
||||
return r.renderComparison(c)
|
||||
case *InCondition:
|
||||
return r.renderInCondition(c)
|
||||
case *ElementInCondition:
|
||||
return r.renderElementInCondition(c)
|
||||
case *ContainsCondition:
|
||||
return r.renderContainsCondition(c)
|
||||
case *ListComprehensionCondition:
|
||||
return r.renderListComprehension(c)
|
||||
case *ConstantCondition:
|
||||
if c.Value {
|
||||
return renderResult{trivial: true}, nil
|
||||
}
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported condition type %T", c)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) {
|
||||
left, err := r.renderCondition(cond.Left)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
right, err := r.renderCondition(cond.Right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
switch cond.Operator {
|
||||
case LogicalAnd:
|
||||
return combineAnd(left, right), nil
|
||||
case LogicalOr:
|
||||
return combineOr(left, right), nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) {
|
||||
child, err := r.renderCondition(cond.Expr)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
if child.trivial {
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
|
||||
}
|
||||
if child.unsatisfiable {
|
||||
return renderResult{trivial: true}, nil
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("NOT (%s)", child.sql),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
|
||||
switch field.Kind {
|
||||
case FieldKindBoolColumn:
|
||||
column := qualifyColumn(r.dialect, field.Column)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s IS TRUE", column),
|
||||
}, nil
|
||||
case FieldKindJSONBool:
|
||||
sql, err := r.jsonBoolPredicate(field)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) {
|
||||
switch left := cond.Left.(type) {
|
||||
case *FieldRef:
|
||||
field, ok := r.schema.Field(left.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", left.Name)
|
||||
}
|
||||
switch field.Kind {
|
||||
case FieldKindBoolColumn:
|
||||
return r.renderBoolColumnComparison(field, cond.Operator, cond.Right)
|
||||
case FieldKindJSONBool:
|
||||
return r.renderJSONBoolComparison(field, cond.Operator, cond.Right)
|
||||
case FieldKindScalar:
|
||||
return r.renderScalarComparison(field, cond.Operator, cond.Right)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name)
|
||||
}
|
||||
case *FunctionValue:
|
||||
return r.renderFunctionComparison(left, cond.Operator, cond.Right)
|
||||
default:
|
||||
return renderResult{}, errors.New("comparison must start with a field reference or supported function")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
if fn.Name != "size" {
|
||||
return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name)
|
||||
}
|
||||
if len(fn.Args) != 1 {
|
||||
return renderResult{}, errors.New("size() expects one argument")
|
||||
}
|
||||
fieldArg, ok := fn.Args[0].(*FieldRef)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("size() argument must be a field")
|
||||
}
|
||||
|
||||
field, ok := r.schema.Field(fieldArg.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", fieldArg.Name)
|
||||
}
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name)
|
||||
}
|
||||
|
||||
value, err := expectNumericLiteral(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
expr := jsonArrayLengthExpr(r.dialect, field)
|
||||
placeholder := r.addArg(value)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", expr, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
lit, err := expectLiteral(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
columnExpr := field.columnExpr(r.dialect)
|
||||
if lit == nil {
|
||||
switch op {
|
||||
case CompareEq:
|
||||
return renderResult{sql: fmt.Sprintf("%s IS NULL", columnExpr)}, nil
|
||||
case CompareNeq:
|
||||
return renderResult{sql: fmt.Sprintf("%s IS NOT NULL", columnExpr)}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("operator %s not supported for null comparison", op)
|
||||
}
|
||||
}
|
||||
|
||||
placeholder := ""
|
||||
switch field.Type {
|
||||
case FieldTypeString:
|
||||
value, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("field %q expects string value", field.Name)
|
||||
}
|
||||
placeholder = r.addArg(value)
|
||||
case FieldTypeInt, FieldTypeTimestamp:
|
||||
num, err := toInt64(lit)
|
||||
if err != nil {
|
||||
return renderResult{}, errors.Wrapf(err, "field %q expects integer value", field.Name)
|
||||
}
|
||||
placeholder = r.addArg(num)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name)
|
||||
}
|
||||
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", columnExpr, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
value, err := expectBool(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
placeholder := r.addBoolArg(value)
|
||||
column := qualifyColumn(r.dialect, field.Column)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", column, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// renderJSONBoolComparison renders `jsonBoolField <op> boolLiteral` for a
// boolean stored inside a JSON payload column. Each dialect compares JSON
// booleans differently, and only ==/!= are supported on SQLite.
func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
	value, err := expectBool(right)
	if err != nil {
		return renderResult{}, err
	}

	jsonExpr := jsonExtractExpr(r.dialect, field)
	switch r.dialect {
	case DialectSQLite:
		switch op {
		case CompareEq:
			// has_task_list is special-cased to compare against numeric 0/1
			// (its stored representation here) rather than a boolean.
			if field.Name == "has_task_list" {
				target := "0"
				if value {
					target = "1"
				}
				return renderResult{sql: fmt.Sprintf("%s = %s", jsonExpr, target)}, nil
			}
			if value {
				return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
			}
			// NOT(... IS TRUE) also matches NULL/absent values, unlike a
			// plain `IS FALSE` check — "== false" treats missing as false.
			return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
		case CompareNeq:
			if field.Name == "has_task_list" {
				target := "0"
				if value {
					target = "1"
				}
				return renderResult{sql: fmt.Sprintf("%s != %s", jsonExpr, target)}, nil
			}
			// != value is the negation of == value.
			if value {
				return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
			}
			return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
		default:
			return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op)
		}
	case DialectMySQL:
		boolStr := "false"
		if value {
			boolStr = "true"
		}
		// MySQL compares JSON values natively once both sides are JSON-typed.
		return renderResult{
			sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", jsonExpr, sqlOperator(op), boolStr),
		}, nil
	case DialectPostgres:
		placeholder := r.addArg(value)
		// The ->> accessor yields text; cast to boolean before comparing to
		// the bound parameter.
		return renderResult{
			sql: fmt.Sprintf("(%s)::boolean %s %s", jsonExpr, sqlOperator(op), placeholder),
		}, nil
	default:
		return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
	}
}
|
||||
|
||||
func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) {
|
||||
fieldRef, ok := cond.Left.(*FieldRef)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("IN operator requires a field on the left-hand side")
|
||||
}
|
||||
|
||||
if fieldRef.Name == "tag" {
|
||||
return r.renderTagInList(cond.Values)
|
||||
}
|
||||
|
||||
field, ok := r.schema.Field(fieldRef.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", fieldRef.Name)
|
||||
}
|
||||
|
||||
if field.Kind != FieldKindScalar {
|
||||
return renderResult{}, errors.Errorf("field %q does not support IN()", fieldRef.Name)
|
||||
}
|
||||
|
||||
return r.renderScalarInCondition(field, cond.Values)
|
||||
}
|
||||
|
||||
func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) {
|
||||
field, ok := r.schema.ResolveAlias("tag")
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tag attribute is not configured")
|
||||
}
|
||||
|
||||
conditions := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
lit, err := expectLiteral(v)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tags must be compared with string literals")
|
||||
}
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
// Support hierarchical tags: match exact tag OR tags with this prefix (e.g., "book" matches "book" and "book/something")
|
||||
exactMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
|
||||
prefixMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
|
||||
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
conditions = append(conditions, expr)
|
||||
case DialectMySQL:
|
||||
// Support hierarchical tags: match exact tag OR tags with this prefix
|
||||
exactMatch := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
prefixMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
|
||||
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
conditions = append(conditions, expr)
|
||||
case DialectPostgres:
|
||||
// Support hierarchical tags: match exact tag OR tags with this prefix
|
||||
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
|
||||
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
conditions = append(conditions, expr)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 1 {
|
||||
return renderResult{sql: conditions[0]}, nil
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field)
|
||||
}
|
||||
|
||||
lit, err := expectLiteral(cond.Element)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tags membership requires string literal")
|
||||
}
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
|
||||
return renderResult{sql: sql}, nil
|
||||
case DialectMySQL:
|
||||
sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
return renderResult{sql: sql}, nil
|
||||
case DialectPostgres:
|
||||
sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) {
|
||||
placeholders := make([]string, 0, len(values))
|
||||
|
||||
for _, v := range values {
|
||||
lit, err := expectLiteral(v)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
switch field.Type {
|
||||
case FieldTypeString:
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("field %q expects string values", field.Name)
|
||||
}
|
||||
placeholders = append(placeholders, r.addArg(str))
|
||||
case FieldTypeInt:
|
||||
num, err := toInt64(lit)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
placeholders = append(placeholders, r.addArg(num))
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
column := field.columnExpr(r.dialect)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
column := field.columnExpr(r.dialect)
|
||||
arg := fmt.Sprintf("%%%s%%", cond.Value)
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
// Use custom Unicode-aware case folding function for case-insensitive comparison.
|
||||
// This overcomes SQLite's ASCII-only LOWER() limitation.
|
||||
sql := fmt.Sprintf("memos_unicode_lower(%s) LIKE memos_unicode_lower(%s)", column, r.addArg(arg))
|
||||
return renderResult{sql: sql}, nil
|
||||
case DialectPostgres:
|
||||
sql := fmt.Sprintf("%s ILIKE %s", column, r.addArg(arg))
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
sql := fmt.Sprintf("%s LIKE %s", column, r.addArg(arg))
|
||||
return renderResult{sql: sql}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("field %q is not a JSON list", cond.Field)
|
||||
}
|
||||
|
||||
// Render based on predicate type
|
||||
switch pred := cond.Predicate.(type) {
|
||||
case *StartsWithPredicate:
|
||||
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
|
||||
case *EndsWithPredicate:
|
||||
return r.renderTagEndsWith(field, pred.Suffix, cond.Kind)
|
||||
case *ContainsPredicate:
|
||||
return r.renderTagContains(field, pred.Substring, cond.Kind)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported predicate type %T in comprehension", pred)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
|
||||
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
// Match exact tag or tags with this prefix (hierarchical support)
|
||||
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, prefix))
|
||||
prefixMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s%%`, prefix))
|
||||
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
|
||||
|
||||
case DialectPostgres:
|
||||
// Use PostgreSQL's powerful JSON operators
|
||||
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, prefix)))
|
||||
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(fmt.Sprintf(`%%"%s%%`, prefix)))
|
||||
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
|
||||
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagEndsWith generates SQL for tags.exists(t, t.endsWith("suffix")).
|
||||
func (r *renderer) renderTagEndsWith(field Field, suffix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
pattern := fmt.Sprintf(`%%%s"%%`, suffix)
|
||||
|
||||
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
|
||||
}
|
||||
|
||||
// renderTagContains generates SQL for tags.exists(t, t.contains("substring")).
|
||||
func (r *renderer) renderTagContains(field Field, substring string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
pattern := fmt.Sprintf(`%%%s%%`, substring)
|
||||
|
||||
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
|
||||
}
|
||||
|
||||
// buildJSONArrayLike builds a LIKE expression for matching within a JSON array.
|
||||
// Returns the LIKE clause without NULL/empty checks.
|
||||
func (r *renderer) buildJSONArrayLike(arrayExpr, pattern string) string {
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("%s LIKE %s", arrayExpr, r.addArg(pattern))
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(pattern))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapWithNullCheck wraps a condition with NULL and empty array checks.
|
||||
// This ensures we don't match against NULL or empty JSON arrays.
|
||||
func (r *renderer) wrapWithNullCheck(arrayExpr, condition string) string {
|
||||
var nullCheck string
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND %s != '[]'", arrayExpr, arrayExpr)
|
||||
case DialectMySQL:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND JSON_LENGTH(%s) > 0", arrayExpr, arrayExpr)
|
||||
case DialectPostgres:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND jsonb_array_length(%s) > 0", arrayExpr, arrayExpr)
|
||||
default:
|
||||
return condition
|
||||
}
|
||||
return fmt.Sprintf("(%s AND %s)", condition, nullCheck)
|
||||
}
|
||||
|
||||
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
|
||||
expr := jsonExtractExpr(r.dialect, field)
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
return fmt.Sprintf("%s IS TRUE", expr), nil
|
||||
case DialectMySQL:
|
||||
return fmt.Sprintf("COALESCE(%s, CAST('false' AS JSON)) = CAST('true' AS JSON)", expr), nil
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func combineAnd(left, right renderResult) renderResult {
|
||||
if left.unsatisfiable || right.unsatisfiable {
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}
|
||||
}
|
||||
if left.trivial {
|
||||
return right
|
||||
}
|
||||
if right.trivial {
|
||||
return left
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s AND %s)", left.sql, right.sql),
|
||||
}
|
||||
}
|
||||
|
||||
func combineOr(left, right renderResult) renderResult {
|
||||
if left.trivial || right.trivial {
|
||||
return renderResult{trivial: true}
|
||||
}
|
||||
if left.unsatisfiable {
|
||||
return right
|
||||
}
|
||||
if right.unsatisfiable {
|
||||
return left
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s OR %s)", left.sql, right.sql),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) addArg(value any) string {
|
||||
r.placeholderCounter++
|
||||
r.args = append(r.args, value)
|
||||
if r.dialect == DialectPostgres {
|
||||
return fmt.Sprintf("$%d", r.placeholderOffset+r.placeholderCounter)
|
||||
}
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (r *renderer) addBoolArg(value bool) string {
|
||||
var v any
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
if value {
|
||||
v = 1
|
||||
} else {
|
||||
v = 0
|
||||
}
|
||||
default:
|
||||
v = value
|
||||
}
|
||||
return r.addArg(v)
|
||||
}
|
||||
|
||||
func expectLiteral(expr ValueExpr) (any, error) {
|
||||
lit, ok := expr.(*LiteralValue)
|
||||
if !ok {
|
||||
return nil, errors.New("expression must be a literal")
|
||||
}
|
||||
return lit.Value, nil
|
||||
}
|
||||
|
||||
func expectBool(expr ValueExpr) (bool, error) {
|
||||
lit, err := expectLiteral(expr)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
value, ok := lit.(bool)
|
||||
if !ok {
|
||||
return false, errors.New("boolean literal required")
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func expectNumericLiteral(expr ValueExpr) (int64, error) {
|
||||
lit, err := expectLiteral(expr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return toInt64(lit)
|
||||
}
|
||||
|
||||
func toInt64(value any) (int64, error) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return int64(v), nil
|
||||
case int32:
|
||||
return int64(v), nil
|
||||
case int64:
|
||||
return v, nil
|
||||
case uint32:
|
||||
return int64(v), nil
|
||||
case uint64:
|
||||
return int64(v), nil
|
||||
case float32:
|
||||
return int64(v), nil
|
||||
case float64:
|
||||
return int64(v), nil
|
||||
default:
|
||||
return 0, errors.Errorf("cannot convert %T to int64", value)
|
||||
}
|
||||
}
|
||||
|
||||
// sqlOperator maps a comparison operator to its SQL token. Operator values
// are stored as their SQL spelling, so this is the identity conversion.
func sqlOperator(op ComparisonOperator) string {
	return string(op)
}
|
||||
|
||||
func qualifyColumn(d DialectName, col Column) string {
|
||||
switch d {
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("%s.%s", col.Table, col.Name)
|
||||
default:
|
||||
return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func jsonPath(field Field) string {
|
||||
return "$." + strings.Join(field.JSONPath, ".")
|
||||
}
|
||||
|
||||
func jsonExtractExpr(d DialectName, field Field) string {
|
||||
column := qualifyColumn(d, field.Column)
|
||||
switch d {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
|
||||
case DialectPostgres:
|
||||
return buildPostgresJSONAccessor(column, field.JSONPath, true)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func jsonArrayExpr(d DialectName, field Field) string {
|
||||
column := qualifyColumn(d, field.Column)
|
||||
switch d {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
|
||||
case DialectPostgres:
|
||||
return buildPostgresJSONAccessor(column, field.JSONPath, false)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func jsonArrayLengthExpr(d DialectName, field Field) string {
|
||||
arrayExpr := jsonArrayExpr(d, field)
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
|
||||
case DialectMySQL:
|
||||
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s, '[]'::jsonb))", arrayExpr)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// buildPostgresJSONAccessor chains Postgres JSON accessors over base for each
// path element, using -> for intermediate steps and — when terminalText is
// set — ->> for the final step so the result is text instead of jsonb.
// An empty path returns base unchanged.
func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string {
	expr := base
	last := len(path) - 1
	for i, part := range path {
		op := "->"
		if terminalText && i == last {
			op = "->>"
		}
		expr = fmt.Sprintf("%s%s'%s'", expr, op, part)
	}
	return expr
}
|
||||
319
plugin/filter/schema.go
Normal file
319
plugin/filter/schema.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
)
|
||||
|
||||
// DialectName enumerates supported SQL dialects.
type DialectName string

// The three database backends the filter renderer can target.
const (
	DialectSQLite   DialectName = "sqlite"
	DialectMySQL    DialectName = "mysql"
	DialectPostgres DialectName = "postgres"
)
|
||||
|
||||
// FieldType represents the logical type of a field, used to validate filter
// literals before binding them.
type FieldType string

const (
	FieldTypeString FieldType = "string"
	FieldTypeInt    FieldType = "int"
	FieldTypeBool   FieldType = "bool"
	// FieldTypeTimestamp is compared as an integer epoch value; per-dialect
	// column expressions may convert the stored column first.
	FieldTypeTimestamp FieldType = "timestamp"
)
|
||||
|
||||
// FieldKind describes how a field is stored.
type FieldKind string

const (
	// FieldKindScalar is a plain table column.
	FieldKindScalar FieldKind = "scalar"
	// FieldKindBoolColumn is a boolean table column.
	FieldKindBoolColumn FieldKind = "bool_column"
	// FieldKindJSONBool is a boolean stored inside a JSON payload column.
	FieldKindJSONBool FieldKind = "json_bool"
	// FieldKindJSONList is a JSON array stored inside a payload column.
	FieldKindJSONList FieldKind = "json_list"
	// FieldKindVirtualAlias resolves to another field via Field.AliasFor.
	FieldKindVirtualAlias FieldKind = "virtual_alias"
)
|
||||
|
||||
// Column identifies the backing table column.
type Column struct {
	Table string // table name, e.g. "memo"
	Name  string // column name within Table
}
|
||||
|
||||
// Field captures the schema metadata for an exposed CEL identifier.
type Field struct {
	// Name is the identifier exposed in CEL filter expressions.
	Name string
	// Kind describes how the field is stored (scalar column, JSON bool, ...).
	Kind FieldKind
	// Type is the logical value type used to validate literals.
	Type FieldType
	// Column is the backing table column.
	Column Column
	// JSONPath locates the value inside the JSON column for JSON-backed kinds.
	JSONPath []string
	// AliasFor names the target field when Kind is FieldKindVirtualAlias.
	AliasFor string
	// SupportsContains marks fields intended to allow substring matching.
	// NOTE(review): the renderer does not appear to consult this flag —
	// confirm where it is enforced.
	SupportsContains bool
	// Expressions holds per-dialect fmt templates applied to the qualified
	// column (e.g. "UNIX_TIMESTAMP(%s)"); see Field.columnExpr.
	Expressions map[DialectName]string
	// AllowedComparisonOps whitelists comparison operators for this field.
	// NOTE(review): enforcement is not visible in this file — confirm the
	// check site.
	AllowedComparisonOps map[ComparisonOperator]bool
}
|
||||
|
||||
// Schema collects CEL environment options and field metadata.
type Schema struct {
	// Name labels the schema (e.g. "memo", "attachment").
	Name string
	// Fields maps CEL identifiers to their storage metadata.
	Fields map[string]Field
	// EnvOptions declares the CEL variables and functions for this schema.
	EnvOptions []cel.EnvOption
}
|
||||
|
||||
// Field returns the field metadata if present.
|
||||
func (s Schema) Field(name string) (Field, bool) {
|
||||
f, ok := s.Fields[name]
|
||||
return f, ok
|
||||
}
|
||||
|
||||
// ResolveAlias resolves a virtual alias to its target field.
|
||||
func (s Schema) ResolveAlias(name string) (Field, bool) {
|
||||
field, ok := s.Fields[name]
|
||||
if !ok {
|
||||
return Field{}, false
|
||||
}
|
||||
if field.Kind == FieldKindVirtualAlias {
|
||||
target, ok := s.Fields[field.AliasFor]
|
||||
if !ok {
|
||||
return Field{}, false
|
||||
}
|
||||
return target, true
|
||||
}
|
||||
return field, true
|
||||
}
|
||||
|
||||
// nowFunction exposes now() in CEL filters, returning the current Unix
// timestamp in seconds as a CEL int.
var nowFunction = cel.Function("now",
	cel.Overload("now",
		[]*cel.Type{},
		cel.IntType,
		cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
			return types.Int(time.Now().Unix())
		}),
	),
)
|
||||
|
||||
// NewSchema constructs the memo filter schema and CEL environment.
//
// The field map drives SQL rendering while envOptions drives CEL
// type-checking; the two must stay in sync when a field is added.
func NewSchema() Schema {
	fields := map[string]Field{
		"content": {
			Name:             "content",
			Kind:             FieldKindScalar,
			Type:             FieldTypeString,
			Column:           Column{Table: "memo", Name: "content"},
			SupportsContains: true,
			Expressions:      map[DialectName]string{},
		},
		"creator_id": {
			Name:        "creator_id",
			Kind:        FieldKindScalar,
			Type:        FieldTypeInt,
			Column:      Column{Table: "memo", Name: "creator_id"},
			Expressions: map[DialectName]string{},
			AllowedComparisonOps: map[ComparisonOperator]bool{
				CompareEq:  true,
				CompareNeq: true,
			},
		},
		"created_ts": {
			Name:   "created_ts",
			Kind:   FieldKindScalar,
			Type:   FieldTypeTimestamp,
			Column: Column{Table: "memo", Name: "created_ts"},
			Expressions: map[DialectName]string{
				// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
				DialectMySQL: "UNIX_TIMESTAMP(%s)",
				// PostgreSQL and SQLite store created_ts as BIGINT (epoch), no conversion needed
				DialectPostgres: "%s",
				DialectSQLite:   "%s",
			},
		},
		"updated_ts": {
			Name:   "updated_ts",
			Kind:   FieldKindScalar,
			Type:   FieldTypeTimestamp,
			Column: Column{Table: "memo", Name: "updated_ts"},
			Expressions: map[DialectName]string{
				// MySQL stores updated_ts as TIMESTAMP, needs conversion to epoch
				DialectMySQL: "UNIX_TIMESTAMP(%s)",
				// PostgreSQL and SQLite store updated_ts as BIGINT (epoch), no conversion needed
				DialectPostgres: "%s",
				DialectSQLite:   "%s",
			},
		},
		"pinned": {
			Name:        "pinned",
			Kind:        FieldKindBoolColumn,
			Type:        FieldTypeBool,
			Column:      Column{Table: "memo", Name: "pinned"},
			Expressions: map[DialectName]string{},
			AllowedComparisonOps: map[ComparisonOperator]bool{
				CompareEq:  true,
				CompareNeq: true,
			},
		},
		"visibility": {
			Name:        "visibility",
			Kind:        FieldKindScalar,
			Type:        FieldTypeString,
			Column:      Column{Table: "memo", Name: "visibility"},
			Expressions: map[DialectName]string{},
			AllowedComparisonOps: map[ComparisonOperator]bool{
				CompareEq:  true,
				CompareNeq: true,
			},
		},
		// tags lives in the JSON payload column under $.tags.
		"tags": {
			Name:     "tags",
			Kind:     FieldKindJSONList,
			Type:     FieldTypeString,
			Column:   Column{Table: "memo", Name: "payload"},
			JSONPath: []string{"tags"},
		},
		// tag is a virtual alias so `tag in [...]` resolves to the tags list.
		"tag": {
			Name:     "tag",
			Kind:     FieldKindVirtualAlias,
			Type:     FieldTypeString,
			AliasFor: "tags",
		},
		"has_task_list": {
			Name:     "has_task_list",
			Kind:     FieldKindJSONBool,
			Type:     FieldTypeBool,
			Column:   Column{Table: "memo", Name: "payload"},
			JSONPath: []string{"property", "hasTaskList"},
			AllowedComparisonOps: map[ComparisonOperator]bool{
				CompareEq:  true,
				CompareNeq: true,
			},
		},
		"has_link": {
			Name:     "has_link",
			Kind:     FieldKindJSONBool,
			Type:     FieldTypeBool,
			Column:   Column{Table: "memo", Name: "payload"},
			JSONPath: []string{"property", "hasLink"},
			AllowedComparisonOps: map[ComparisonOperator]bool{
				CompareEq:  true,
				CompareNeq: true,
			},
		},
		"has_code": {
			Name:     "has_code",
			Kind:     FieldKindJSONBool,
			Type:     FieldTypeBool,
			Column:   Column{Table: "memo", Name: "payload"},
			JSONPath: []string{"property", "hasCode"},
			AllowedComparisonOps: map[ComparisonOperator]bool{
				CompareEq:  true,
				CompareNeq: true,
			},
		},
		"has_incomplete_tasks": {
			Name:     "has_incomplete_tasks",
			Kind:     FieldKindJSONBool,
			Type:     FieldTypeBool,
			Column:   Column{Table: "memo", Name: "payload"},
			JSONPath: []string{"property", "hasIncompleteTasks"},
			AllowedComparisonOps: map[ComparisonOperator]bool{
				CompareEq:  true,
				CompareNeq: true,
			},
		},
	}

	// CEL declarations mirroring the fields above, plus the now() function.
	envOptions := []cel.EnvOption{
		cel.Variable("content", cel.StringType),
		cel.Variable("creator_id", cel.IntType),
		cel.Variable("created_ts", cel.IntType),
		cel.Variable("updated_ts", cel.IntType),
		cel.Variable("pinned", cel.BoolType),
		cel.Variable("tag", cel.StringType),
		cel.Variable("tags", cel.ListType(cel.StringType)),
		cel.Variable("visibility", cel.StringType),
		cel.Variable("has_task_list", cel.BoolType),
		cel.Variable("has_link", cel.BoolType),
		cel.Variable("has_code", cel.BoolType),
		cel.Variable("has_incomplete_tasks", cel.BoolType),
		nowFunction,
	}

	return Schema{
		Name:       "memo",
		Fields:     fields,
		EnvOptions: envOptions,
	}
}
|
||||
|
||||
// NewAttachmentSchema constructs the attachment filter schema and CEL
// environment. The field map drives SQL rendering while envOptions drives
// CEL type-checking; keep them in sync.
func NewAttachmentSchema() Schema {
	fields := map[string]Field{
		"filename": {
			Name:             "filename",
			Kind:             FieldKindScalar,
			Type:             FieldTypeString,
			Column:           Column{Table: "attachment", Name: "filename"},
			SupportsContains: true,
			Expressions:      map[DialectName]string{},
		},
		// mime_type maps to the attachment "type" column.
		"mime_type": {
			Name:        "mime_type",
			Kind:        FieldKindScalar,
			Type:        FieldTypeString,
			Column:      Column{Table: "attachment", Name: "type"},
			Expressions: map[DialectName]string{},
		},
		"create_time": {
			Name:   "create_time",
			Kind:   FieldKindScalar,
			Type:   FieldTypeTimestamp,
			Column: Column{Table: "attachment", Name: "created_ts"},
			Expressions: map[DialectName]string{
				// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
				DialectMySQL: "UNIX_TIMESTAMP(%s)",
				// PostgreSQL and SQLite store created_ts as BIGINT (epoch), no conversion needed
				DialectPostgres: "%s",
				DialectSQLite:   "%s",
			},
		},
		"memo_id": {
			Name:        "memo_id",
			Kind:        FieldKindScalar,
			Type:        FieldTypeInt,
			Column:      Column{Table: "attachment", Name: "memo_id"},
			Expressions: map[DialectName]string{},
			AllowedComparisonOps: map[ComparisonOperator]bool{
				CompareEq:  true,
				CompareNeq: true,
			},
		},
	}

	envOptions := []cel.EnvOption{
		cel.Variable("filename", cel.StringType),
		cel.Variable("mime_type", cel.StringType),
		cel.Variable("create_time", cel.IntType),
		// NOTE(review): memo_id is declared AnyType here although the field
		// is FieldTypeInt (cf. creator_id in NewSchema using IntType) —
		// confirm whether this is intentional.
		cel.Variable("memo_id", cel.AnyType),
		nowFunction,
	}

	return Schema{
		Name:       "attachment",
		Fields:     fields,
		EnvOptions: envOptions,
	}
}
|
||||
|
||||
// columnExpr returns the field expression for the given dialect, applying
|
||||
// any schema-specific overrides (e.g. UNIX timestamp conversions).
|
||||
func (f Field) columnExpr(d DialectName) string {
|
||||
base := qualifyColumn(d, f.Column)
|
||||
if expr, ok := f.Expressions[d]; ok && expr != "" {
|
||||
return fmt.Sprintf(expr, base)
|
||||
}
|
||||
return base
|
||||
}
|
||||
166
plugin/httpgetter/html_meta.go
Normal file
166
plugin/httpgetter/html_meta.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package httpgetter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/html"
|
||||
"golang.org/x/net/html/atom"
|
||||
)
|
||||
|
||||
// ErrInternalIP is the sentinel returned when a requested URL points (or
// redirects) to a loopback, private, or link-local address.
var ErrInternalIP = errors.New("internal IP addresses are not allowed")

// httpClient is the shared client for outbound fetches. It re-validates the
// destination on every redirect hop and caps the redirect chain, so a public
// URL cannot bounce the request onto an internal host.
var httpClient = &http.Client{
	CheckRedirect: func(req *http.Request, via []*http.Request) error {
		// Reject any redirect that lands on an internal address.
		if err := validateURL(req.URL.String()); err != nil {
			return errors.Wrap(err, "redirect to internal IP")
		}
		// Bound the redirect chain length.
		if len(via) >= 10 {
			return errors.New("too many redirects")
		}
		return nil
	},
}
|
||||
|
||||
// HTMLMeta holds the page metadata extracted from an HTML document:
// the <title> (or og:title), the description meta, and the og:image URL.
type HTMLMeta struct {
	Title       string `json:"title"`
	Description string `json:"description"`
	Image       string `json:"image"`
}
|
||||
|
||||
func GetHTMLMeta(urlStr string) (*HTMLMeta, error) {
|
||||
if err := validateURL(urlStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := httpClient.Get(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
mediatype, err := getMediatype(response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if mediatype != "text/html" {
|
||||
return nil, errors.New("not a HTML page")
|
||||
}
|
||||
|
||||
// TODO: limit the size of the response body
|
||||
|
||||
htmlMeta := extractHTMLMeta(response.Body)
|
||||
enrichSiteMeta(response.Request.URL, htmlMeta)
|
||||
return htmlMeta, nil
|
||||
}
|
||||
|
||||
// extractHTMLMeta tokenizes the HTML stream and collects the <title> text and
// relevant <meta> values. Scanning stops at <body> (metadata lives in <head>)
// or at the first tokenizer error (including io.EOF), returning whatever was
// collected so far. When both plain and og:* values appear, the one occurring
// later in the document wins.
func extractHTMLMeta(resp io.Reader) *HTMLMeta {
	tokenizer := html.NewTokenizer(resp)
	htmlMeta := new(HTMLMeta)

	for {
		tokenType := tokenizer.Next()
		if tokenType == html.ErrorToken {
			// EOF or malformed input: stop with the partial result.
			break
		} else if tokenType == html.StartTagToken || tokenType == html.SelfClosingTagToken {
			token := tokenizer.Token()
			if token.DataAtom == atom.Body {
				// Metadata only appears before <body>; stop scanning.
				break
			}

			if token.DataAtom == atom.Title {
				// The text token immediately following <title> is the title.
				tokenizer.Next()
				token := tokenizer.Token()
				htmlMeta.Title = token.Data
			} else if token.DataAtom == atom.Meta {
				// Each <meta> token matches at most one of these properties.
				description, ok := extractMetaProperty(token, "description")
				if ok {
					htmlMeta.Description = description
				}

				ogTitle, ok := extractMetaProperty(token, "og:title")
				if ok {
					htmlMeta.Title = ogTitle
				}

				ogDescription, ok := extractMetaProperty(token, "og:description")
				if ok {
					htmlMeta.Description = ogDescription
				}

				ogImage, ok := extractMetaProperty(token, "og:image")
				if ok {
					htmlMeta.Image = ogImage
				}
			}
		}
	}

	return htmlMeta
}
|
||||
|
||||
func extractMetaProperty(token html.Token, prop string) (content string, ok bool) {
|
||||
content, ok = "", false
|
||||
for _, attr := range token.Attr {
|
||||
if attr.Key == "property" && attr.Val == prop {
|
||||
ok = true
|
||||
}
|
||||
if attr.Key == "content" {
|
||||
content = attr.Val
|
||||
}
|
||||
}
|
||||
return content, ok
|
||||
}
|
||||
|
||||
func validateURL(urlStr string) error {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return errors.New("invalid URL format")
|
||||
}
|
||||
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return errors.New("only http/https protocols are allowed")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return errors.New("empty hostname")
|
||||
}
|
||||
|
||||
// check if the hostname is an IP
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
|
||||
return errors.Wrap(ErrInternalIP, ip.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if it's a hostname, resolve it and check all returned IPs
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to resolve hostname: %v", err)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
|
||||
return errors.Wrapf(ErrInternalIP, "host=%s, ip=%s", host, ip.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enrichSiteMeta(url *url.URL, meta *HTMLMeta) {
|
||||
if url.Hostname() == "www.youtube.com" {
|
||||
if url.Path == "/watch" {
|
||||
vid := url.Query().Get("v")
|
||||
if vid != "" {
|
||||
meta.Image = fmt.Sprintf("https://img.youtube.com/vi/%s/mqdefault.jpg", vid)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
32
plugin/httpgetter/html_meta_test.go
Normal file
32
plugin/httpgetter/html_meta_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package httpgetter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGetHTMLMeta is a table-driven test mapping URLs to expected metadata.
// The table is intentionally empty: there are currently no stable fixtures to
// fetch; add entries (ideally backed by a local httptest server) to exercise it.
func TestGetHTMLMeta(t *testing.T) {
	tests := []struct {
		urlStr   string
		htmlMeta HTMLMeta
	}{}
	for _, test := range tests {
		metadata, err := GetHTMLMeta(test.urlStr)
		require.NoError(t, err)
		require.Equal(t, test.htmlMeta, *metadata)
	}
}
|
||||
|
||||
// TestGetHTMLMetaForInternal verifies that SSRF protection rejects both a
// literal private IP and a hostname that resolves to an internal address.
// NOTE(review): the second case relies on local DNS resolving "localhost" —
// confirm this holds in CI environments.
func TestGetHTMLMetaForInternal(t *testing.T) {
	// test for internal IP
	if _, err := GetHTMLMeta("http://192.168.0.1"); !errors.Is(err, ErrInternalIP) {
		t.Errorf("Expected error for internal IP, got %v", err)
	}

	// test for resolved internal IP
	if _, err := GetHTMLMeta("http://localhost"); !errors.Is(err, ErrInternalIP) {
		t.Errorf("Expected error for resolved internal IP, got %v", err)
	}
}
|
||||
1
plugin/httpgetter/http_getter.go
Normal file
1
plugin/httpgetter/http_getter.go
Normal file
@@ -0,0 +1 @@
|
||||
package httpgetter
|
||||
45
plugin/httpgetter/image.go
Normal file
45
plugin/httpgetter/image.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package httpgetter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Image is a downloaded image payload together with its media type.
type Image struct {
	// Blob is the raw image bytes.
	Blob []byte
	// Mediatype is the parsed Content-Type media type, e.g. "image/png".
	Mediatype string
}
|
||||
|
||||
func GetImage(urlStr string) (*Image, error) {
|
||||
if _, err := url.Parse(urlStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := http.Get(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
mediatype, err := getMediatype(response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !strings.HasPrefix(mediatype, "image/") {
|
||||
return nil, errors.New("wrong image mediatype")
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
image := &Image{
|
||||
Blob: bodyBytes,
|
||||
Mediatype: mediatype,
|
||||
}
|
||||
return image, nil
|
||||
}
|
||||
15
plugin/httpgetter/util.go
Normal file
15
plugin/httpgetter/util.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package httpgetter
|
||||
|
||||
import (
|
||||
"mime"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func getMediatype(response *http.Response) (string, error) {
|
||||
contentType := response.Header.Get("content-type")
|
||||
mediatype, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return mediatype, nil
|
||||
}
|
||||
8
plugin/idp/idp.go
Normal file
8
plugin/idp/idp.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package idp
|
||||
|
||||
// IdentityProviderUserInfo is the normalized user profile produced by an
// identity provider after field mapping has been applied.
type IdentityProviderUserInfo struct {
	// Identifier is the provider-unique ID of the user (required).
	Identifier string
	// DisplayName is the human-readable name; may fall back to Identifier.
	DisplayName string
	// Email is the user's email address, if a mapping is configured.
	Email string
	// AvatarURL points to the user's avatar image, if a mapping is configured.
	AvatarURL string
}
|
||||
134
plugin/idp/oauth2/oauth2.go
Normal file
134
plugin/idp/oauth2/oauth2.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Package oauth2 is the plugin for OAuth2 Identity Provider.
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// IdentityProvider represents an OAuth2 Identity Provider.
type IdentityProvider struct {
	// config holds the provider endpoints, credentials, and claim mapping.
	config *storepb.OAuth2Config
}
|
||||
|
||||
// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration.
|
||||
func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error) {
|
||||
for v, field := range map[string]string{
|
||||
config.ClientId: "clientId",
|
||||
config.ClientSecret: "clientSecret",
|
||||
config.TokenUrl: "tokenUrl",
|
||||
config.UserInfoUrl: "userInfoUrl",
|
||||
config.FieldMapping.Identifier: "fieldMapping.identifier",
|
||||
} {
|
||||
if v == "" {
|
||||
return nil, errors.Errorf(`the field "%s" is empty but required`, field)
|
||||
}
|
||||
}
|
||||
|
||||
return &IdentityProvider{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
|
||||
// If codeVerifier is provided, it will be used for PKCE (Proof Key for Code Exchange) validation.
|
||||
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code, codeVerifier string) (string, error) {
|
||||
conf := &oauth2.Config{
|
||||
ClientID: p.config.ClientId,
|
||||
ClientSecret: p.config.ClientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: p.config.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: p.config.AuthUrl,
|
||||
TokenURL: p.config.TokenUrl,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
}
|
||||
|
||||
// Prepare token exchange options
|
||||
opts := []oauth2.AuthCodeOption{}
|
||||
|
||||
// Add PKCE code_verifier if provided
|
||||
if codeVerifier != "" {
|
||||
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
|
||||
}
|
||||
|
||||
token, err := conf.Exchange(ctx, code, opts...)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to exchange access token")
|
||||
}
|
||||
|
||||
// Use the standard AccessToken field instead of Extra()
|
||||
// This is more reliable across different OAuth providers
|
||||
if token.AccessToken == "" {
|
||||
return "", errors.New("missing access token from authorization response")
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// UserInfo returns the parsed user information using the given OAuth2 token.
|
||||
func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) {
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(http.MethodGet, p.config.UserInfoUrl, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to new http request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user information")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(body, &claims); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
slog.Info("user info claims", "claims", claims)
|
||||
userInfo := &idp.IdentityProviderUserInfo{}
|
||||
if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok {
|
||||
userInfo.Identifier = v
|
||||
}
|
||||
if userInfo.Identifier == "" {
|
||||
return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier)
|
||||
}
|
||||
|
||||
// Best effort to map optional fields
|
||||
if p.config.FieldMapping.DisplayName != "" {
|
||||
if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok {
|
||||
userInfo.DisplayName = v
|
||||
}
|
||||
}
|
||||
if userInfo.DisplayName == "" {
|
||||
userInfo.DisplayName = userInfo.Identifier
|
||||
}
|
||||
if p.config.FieldMapping.Email != "" {
|
||||
if v, ok := claims[p.config.FieldMapping.Email].(string); ok {
|
||||
userInfo.Email = v
|
||||
}
|
||||
}
|
||||
if p.config.FieldMapping.AvatarUrl != "" {
|
||||
if v, ok := claims[p.config.FieldMapping.AvatarUrl].(string); ok {
|
||||
userInfo.AvatarURL = v
|
||||
}
|
||||
}
|
||||
slog.Info("user info", "userInfo", userInfo)
|
||||
return userInfo, nil
|
||||
}
|
||||
164
plugin/idp/oauth2/oauth2_test.go
Normal file
164
plugin/idp/oauth2/oauth2_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// TestNewIdentityProvider verifies that constructing a provider with a missing
// required field fails with an error naming that field.
func TestNewIdentityProvider(t *testing.T) {
	tests := []struct {
		name        string
		config      *storepb.OAuth2Config
		containsErr string
	}{
		{
			name: "no tokenUrl",
			config: &storepb.OAuth2Config{
				ClientId:     "test-client-id",
				ClientSecret: "test-client-secret",
				AuthUrl:      "",
				TokenUrl:     "",
				UserInfoUrl:  "https://example.com/api/user",
				FieldMapping: &storepb.FieldMapping{
					Identifier: "login",
				},
			},
			containsErr: `the field "tokenUrl" is empty but required`,
		},
		{
			name: "no userInfoUrl",
			config: &storepb.OAuth2Config{
				ClientId:     "test-client-id",
				ClientSecret: "test-client-secret",
				AuthUrl:      "",
				TokenUrl:     "https://example.com/token",
				UserInfoUrl:  "",
				FieldMapping: &storepb.FieldMapping{
					Identifier: "login",
				},
			},
			containsErr: `the field "userInfoUrl" is empty but required`,
		},
		{
			name: "no field mapping identifier",
			config: &storepb.OAuth2Config{
				ClientId:     "test-client-id",
				ClientSecret: "test-client-secret",
				AuthUrl:      "",
				TokenUrl:     "https://example.com/token",
				UserInfoUrl:  "https://example.com/api/user",
				FieldMapping: &storepb.FieldMapping{
					Identifier: "",
				},
			},
			containsErr: `the field "fieldMapping.identifier" is empty but required`,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(*testing.T) {
			_, err := NewIdentityProvider(test.config)
			assert.ErrorContains(t, err, test.containsErr)
		})
	}
}
|
||||
|
||||
// newMockServer starts an httptest server exposing an OAuth2 token endpoint
// and a userinfo endpoint. The token handler asserts the request carries the
// expected authorization code and grant type before returning accessToken;
// the userinfo handler simply replies with the provided JSON payload.
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
	mux := http.NewServeMux()

	// NOTE(review): rawIDToken is never assigned, so id_token is always empty
	// in the token response — the plain OAuth2 flow under test does not read it.
	var rawIDToken string
	mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodPost, r.Method)

		body, err := io.ReadAll(r.Body)
		require.NoError(t, err)
		vals, err := url.ParseQuery(string(body))
		require.NoError(t, err)

		// The exchange must send the authorization-code grant parameters.
		require.Equal(t, code, vals.Get("code"))
		require.Equal(t, "authorization_code", vals.Get("grant_type"))

		w.Header().Set("Content-Type", "application/json")
		err = json.NewEncoder(w).Encode(map[string]any{
			"access_token": accessToken,
			"token_type":   "Bearer",
			"expires_in":   3600,
			"id_token":     rawIDToken,
		})
		require.NoError(t, err)
	})
	mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, err := w.Write(userinfo)
		require.NoError(t, err)
	})

	s := httptest.NewServer(mux)

	return s
}
|
||||
|
||||
// TestIdentityProvider runs the full happy path against a mock server:
// exchange an authorization code for an access token (without PKCE), then
// fetch and map the userinfo claims.
func TestIdentityProvider(t *testing.T) {
	ctx := context.Background()

	const (
		testClientID    = "test-client-id"
		testCode        = "test-code"
		testAccessToken = "test-access-token"
		testSubject     = "123456789"
		testName        = "John Doe"
		testEmail       = "john.doe@example.com"
	)
	// Claims payload the mock userinfo endpoint will serve.
	userInfo, err := json.Marshal(
		map[string]any{
			"sub":   testSubject,
			"name":  testName,
			"email": testEmail,
		},
	)
	require.NoError(t, err)

	s := newMockServer(t, testCode, testAccessToken, userInfo)

	oauth2, err := NewIdentityProvider(
		&storepb.OAuth2Config{
			ClientId:     testClientID,
			ClientSecret: "test-client-secret",
			TokenUrl:     fmt.Sprintf("%s/oauth2/token", s.URL),
			UserInfoUrl:  fmt.Sprintf("%s/oauth2/userinfo", s.URL),
			FieldMapping: &storepb.FieldMapping{
				Identifier:  "sub",
				DisplayName: "name",
				Email:       "email",
			},
		},
	)
	require.NoError(t, err)

	redirectURL := "https://example.com/oauth/callback"
	// Test without PKCE (backward compatibility)
	oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode, "")
	require.NoError(t, err)
	require.Equal(t, testAccessToken, oauthToken)

	userInfoResult, err := oauth2.UserInfo(oauthToken)
	require.NoError(t, err)

	// The mapped result should mirror the claims served by the mock.
	wantUserInfo := &idp.IdentityProviderUserInfo{
		Identifier:  testSubject,
		DisplayName: testName,
		Email:       testEmail,
	}
	assert.Equal(t, wantUserInfo, userInfoResult)
}
|
||||
28
plugin/markdown/ast/tag.go
Normal file
28
plugin/markdown/ast/tag.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
gast "github.com/yuin/goldmark/ast"
|
||||
)
|
||||
|
||||
// TagNode represents a #tag in the markdown AST. It is an inline node
// produced by the custom tag parser.
type TagNode struct {
	gast.BaseInline

	// Tag name without the # prefix.
	Tag []byte
}
|
||||
|
||||
// KindTag is the NodeKind for TagNode, registered once with goldmark.
var KindTag = gast.NewNodeKind("Tag")

// Kind returns KindTag.
func (*TagNode) Kind() gast.NodeKind {
	return KindTag
}

// Dump implements Node.Dump for debugging; it prints the tag text alongside
// the standard node information.
func (n *TagNode) Dump(source []byte, level int) {
	gast.DumpHelper(n, source, level, map[string]string{
		"Tag": string(n.Tag),
	}, nil)
}
|
||||
24
plugin/markdown/extensions/tag.go
Normal file
24
plugin/markdown/extensions/tag.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"github.com/yuin/goldmark"
|
||||
"github.com/yuin/goldmark/parser"
|
||||
"github.com/yuin/goldmark/util"
|
||||
|
||||
mparser "github.com/usememos/memos/plugin/markdown/parser"
|
||||
)
|
||||
|
||||
type tagExtension struct{}
|
||||
|
||||
// TagExtension is a goldmark extension for #tag syntax.
|
||||
var TagExtension = &tagExtension{}
|
||||
|
||||
// Extend extends the goldmark parser with tag support.
|
||||
func (*tagExtension) Extend(m goldmark.Markdown) {
|
||||
m.Parser().AddOptions(
|
||||
parser.WithInlineParsers(
|
||||
// Priority 200 - run before standard link parser (500)
|
||||
util.Prioritized(mparser.NewTagParser(), 200),
|
||||
),
|
||||
)
|
||||
}
|
||||
409
plugin/markdown/markdown.go
Normal file
409
plugin/markdown/markdown.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package markdown
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/yuin/goldmark"
|
||||
gast "github.com/yuin/goldmark/ast"
|
||||
"github.com/yuin/goldmark/extension"
|
||||
east "github.com/yuin/goldmark/extension/ast"
|
||||
"github.com/yuin/goldmark/parser"
|
||||
"github.com/yuin/goldmark/text"
|
||||
|
||||
mast "github.com/usememos/memos/plugin/markdown/ast"
|
||||
"github.com/usememos/memos/plugin/markdown/extensions"
|
||||
"github.com/usememos/memos/plugin/markdown/renderer"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// ExtractedData contains all metadata extracted from markdown in a single pass.
type ExtractedData struct {
	// Tags is the deduplicated list of #tags, in order of first appearance.
	Tags []string
	// Property holds boolean facts about the content (links, code, task lists).
	Property *storepb.MemoPayload_Property
}
|
||||
|
||||
// Service handles markdown metadata extraction.
// It uses goldmark to parse markdown and extract tags, properties, and snippets.
// HTML rendering is primarily done on frontend using markdown-it, but backend provides
// RenderHTML for RSS feeds and other server-side rendering needs.
type Service interface {
	// ExtractAll extracts tags, properties, and references in a single parse (most efficient).
	ExtractAll(content []byte) (*ExtractedData, error)

	// ExtractTags returns all #tags found in content, deduplicated.
	ExtractTags(content []byte) ([]string, error)

	// ExtractProperties computes boolean properties (links, code, task lists).
	ExtractProperties(content []byte) (*storepb.MemoPayload_Property, error)

	// RenderMarkdown renders goldmark AST back to markdown text.
	RenderMarkdown(content []byte) (string, error)

	// RenderHTML renders markdown content to HTML.
	RenderHTML(content []byte) (string, error)

	// GenerateSnippet creates a plain text summary of at most maxLength characters.
	GenerateSnippet(content []byte, maxLength int) (string, error)

	// ValidateContent checks for syntax errors.
	ValidateContent(content []byte) error

	// RenameTag renames all occurrences of oldTag to newTag in content.
	RenameTag(content []byte, oldTag, newTag string) (string, error)
}
|
||||
|
||||
// service implements the Service interface.
type service struct {
	// md is the configured goldmark instance shared by all operations.
	md goldmark.Markdown
}
|
||||
|
||||
// Option configures the markdown service.
type Option func(*config)

// config accumulates construction-time settings applied by Options.
type config struct {
	// enableTags turns on the custom #tag inline parser.
	enableTags bool
}

// WithTagExtension enables #tag parsing.
func WithTagExtension() Option {
	return func(c *config) {
		c.enableTags = true
	}
}
|
||||
|
||||
// NewService creates a new markdown service with the given options.
|
||||
func NewService(opts ...Option) Service {
|
||||
cfg := &config{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
exts := []goldmark.Extender{
|
||||
extension.GFM, // GitHub Flavored Markdown (tables, strikethrough, task lists, autolinks)
|
||||
}
|
||||
|
||||
// Add custom extensions based on config
|
||||
if cfg.enableTags {
|
||||
exts = append(exts, extensions.TagExtension)
|
||||
}
|
||||
|
||||
md := goldmark.New(
|
||||
goldmark.WithExtensions(exts...),
|
||||
goldmark.WithParserOptions(
|
||||
parser.WithAutoHeadingID(), // Generate heading IDs
|
||||
),
|
||||
)
|
||||
|
||||
return &service{
|
||||
md: md,
|
||||
}
|
||||
}
|
||||
|
||||
// parse is an internal helper to parse content into AST.
|
||||
func (s *service) parse(content []byte) (gast.Node, error) {
|
||||
reader := text.NewReader(content)
|
||||
doc := s.md.Parser().Parse(reader)
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// ExtractTags returns all #tags found in content.
|
||||
func (s *service) ExtractTags(content []byte) ([]string, error) {
|
||||
root, err := s.parse(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tags []string
|
||||
|
||||
// Walk the AST to find tag nodes
|
||||
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
|
||||
if !entering {
|
||||
return gast.WalkContinue, nil
|
||||
}
|
||||
|
||||
// Check for custom TagNode
|
||||
if tagNode, ok := n.(*mast.TagNode); ok {
|
||||
tags = append(tags, string(tagNode.Tag))
|
||||
}
|
||||
|
||||
return gast.WalkContinue, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Deduplicate tags while preserving original case
|
||||
return uniquePreserveCase(tags), nil
|
||||
}
|
||||
|
||||
// ExtractProperties computes boolean properties about the content.
|
||||
func (s *service) ExtractProperties(content []byte) (*storepb.MemoPayload_Property, error) {
|
||||
root, err := s.parse(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prop := &storepb.MemoPayload_Property{}
|
||||
|
||||
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
|
||||
if !entering {
|
||||
return gast.WalkContinue, nil
|
||||
}
|
||||
|
||||
switch n.Kind() {
|
||||
case gast.KindLink:
|
||||
prop.HasLink = true
|
||||
|
||||
case gast.KindCodeBlock, gast.KindFencedCodeBlock, gast.KindCodeSpan:
|
||||
prop.HasCode = true
|
||||
|
||||
case east.KindTaskCheckBox:
|
||||
prop.HasTaskList = true
|
||||
if checkBox, ok := n.(*east.TaskCheckBox); ok {
|
||||
if !checkBox.IsChecked {
|
||||
prop.HasIncompleteTasks = true
|
||||
}
|
||||
}
|
||||
default:
|
||||
// No special handling for other node types
|
||||
}
|
||||
|
||||
return gast.WalkContinue, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return prop, nil
|
||||
}
|
||||
|
||||
// RenderMarkdown renders goldmark AST back to markdown text.
|
||||
func (s *service) RenderMarkdown(content []byte) (string, error) {
|
||||
root, err := s.parse(content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mdRenderer := renderer.NewMarkdownRenderer()
|
||||
return mdRenderer.Render(root, content), nil
|
||||
}
|
||||
|
||||
// RenderHTML renders markdown content to HTML using goldmark's built-in HTML renderer.
|
||||
func (s *service) RenderHTML(content []byte) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := s.md.Convert(content, &buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// GenerateSnippet creates a plain text summary from markdown content.
// It walks the AST concatenating text nodes (skipping code entirely), inserts
// a single space between block elements, stops early once roughly twice
// maxLength has accumulated, and finally truncates at a word boundary.
func (s *service) GenerateSnippet(content []byte, maxLength int) (string, error) {
	root, err := s.parse(content)
	if err != nil {
		return "", err
	}

	var buf strings.Builder
	// Tracks whether the previously *exited* node was a block element, so the
	// next block's text is separated by a space instead of running together.
	var lastNodeWasBlock bool

	err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
		if entering {
			// Skip code blocks and code spans entirely
			switch n.Kind() {
			case gast.KindCodeBlock, gast.KindFencedCodeBlock, gast.KindCodeSpan:
				return gast.WalkSkipChildren, nil
			default:
				// Continue walking for other node types
			}

			// Add space before block elements (except first)
			switch n.Kind() {
			case gast.KindParagraph, gast.KindHeading, gast.KindListItem:
				if buf.Len() > 0 && lastNodeWasBlock {
					buf.WriteByte(' ')
				}
			default:
				// No space needed for other node types
			}
		}

		if !entering {
			// Mark that we just exited a block element
			switch n.Kind() {
			case gast.KindParagraph, gast.KindHeading, gast.KindListItem:
				lastNodeWasBlock = true
			default:
				// Not a block element
			}
			return gast.WalkContinue, nil
		}

		lastNodeWasBlock = false

		// Only extract plain text nodes
		if textNode, ok := n.(*gast.Text); ok {
			segment := textNode.Segment
			buf.Write(segment.Value(content))

			// Add space if this is a soft line break
			if textNode.SoftLineBreak() {
				buf.WriteByte(' ')
			}
		}

		// Stop walking if we've exceeded double the max length
		// (we'll truncate precisely later); the slack leaves room to cut
		// cleanly at a word boundary below.
		if buf.Len() > maxLength*2 {
			return gast.WalkStop, nil
		}

		return gast.WalkContinue, nil
	})

	if err != nil {
		return "", err
	}

	snippet := buf.String()

	// Truncate at word boundary if needed. Note len() here counts bytes while
	// truncateAtWord counts runes; the early-stop slack absorbs the difference.
	if len(snippet) > maxLength {
		snippet = truncateAtWord(snippet, maxLength)
	}

	return strings.TrimSpace(snippet), nil
}
|
||||
|
||||
// ValidateContent checks if the markdown content is valid.
|
||||
func (s *service) ValidateContent(content []byte) error {
|
||||
// Try to parse the content
|
||||
_, err := s.parse(content)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExtractAll extracts tags, properties, and references in a single parse for efficiency.
|
||||
func (s *service) ExtractAll(content []byte) (*ExtractedData, error) {
|
||||
root, err := s.parse(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := &ExtractedData{
|
||||
Tags: []string{},
|
||||
Property: &storepb.MemoPayload_Property{},
|
||||
}
|
||||
|
||||
// Single walk to collect all data
|
||||
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
|
||||
if !entering {
|
||||
return gast.WalkContinue, nil
|
||||
}
|
||||
|
||||
// Extract tags
|
||||
if tagNode, ok := n.(*mast.TagNode); ok {
|
||||
data.Tags = append(data.Tags, string(tagNode.Tag))
|
||||
}
|
||||
|
||||
// Extract properties based on node kind
|
||||
switch n.Kind() {
|
||||
case gast.KindLink:
|
||||
data.Property.HasLink = true
|
||||
|
||||
case gast.KindCodeBlock, gast.KindFencedCodeBlock, gast.KindCodeSpan:
|
||||
data.Property.HasCode = true
|
||||
|
||||
case east.KindTaskCheckBox:
|
||||
data.Property.HasTaskList = true
|
||||
if checkBox, ok := n.(*east.TaskCheckBox); ok {
|
||||
if !checkBox.IsChecked {
|
||||
data.Property.HasIncompleteTasks = true
|
||||
}
|
||||
}
|
||||
default:
|
||||
// No special handling for other node types
|
||||
}
|
||||
|
||||
return gast.WalkContinue, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Deduplicate tags while preserving original case
|
||||
data.Tags = uniquePreserveCase(data.Tags)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// RenameTag renames all occurrences of oldTag to newTag in content.
|
||||
func (s *service) RenameTag(content []byte, oldTag, newTag string) (string, error) {
|
||||
root, err := s.parse(content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Walk the AST to find and rename tag nodes
|
||||
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
|
||||
if !entering {
|
||||
return gast.WalkContinue, nil
|
||||
}
|
||||
|
||||
// Check for custom TagNode and rename if it matches
|
||||
if tagNode, ok := n.(*mast.TagNode); ok {
|
||||
if string(tagNode.Tag) == oldTag {
|
||||
tagNode.Tag = []byte(newTag)
|
||||
}
|
||||
}
|
||||
|
||||
return gast.WalkContinue, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Render back to markdown using the already-parsed AST
|
||||
mdRenderer := renderer.NewMarkdownRenderer()
|
||||
return mdRenderer.Render(root, content), nil
|
||||
}
|
||||
|
||||
// uniquePreserveCase returns the unique strings from the input, keeping the
// first occurrence's casing and order. Empty input yields a nil slice.
func uniquePreserveCase(strs []string) []string {
	seen := make(map[string]struct{}, len(strs))
	var out []string

	for _, v := range strs {
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}

	return out
}
|
||||
|
||||
// truncateAtWord shortens s to at most maxLength characters, backing up to
// the last word boundary and appending " ...". maxLength counts runes, not
// bytes, so multi-byte UTF-8 characters are never split. Strings at or under
// the limit are returned unchanged.
func truncateAtWord(s string, maxLength int) string {
	runes := []rune(s)
	if len(runes) <= maxLength {
		return s
	}

	// Cut by character count first, then retreat to the last whitespace so a
	// word is never split in the middle (byte index into valid UTF-8 is safe
	// here because whitespace is single-byte).
	cut := string(runes[:maxLength])
	if idx := strings.LastIndexAny(cut, " \t\n\r"); idx > 0 {
		cut = cut[:idx]
	}

	return cut + " ..."
}
|
||||
448
plugin/markdown/markdown_test.go
Normal file
448
plugin/markdown/markdown_test.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package markdown
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewService verifies the constructor returns a non-nil service.
func TestNewService(t *testing.T) {
	svc := NewService()
	assert.NotNil(t, svc)
}
|
||||
|
||||
// TestValidateContent checks that well-formed (and empty) markdown passes
// validation without error.
func TestValidateContent(t *testing.T) {
	svc := NewService()

	tests := []struct {
		name    string
		content string
		wantErr bool
	}{
		{
			name:    "valid markdown",
			content: "# Hello\n\nThis is **bold** text.",
			wantErr: false,
		},
		{
			name:    "empty content",
			content: "",
			wantErr: false,
		},
		{
			name:    "complex markdown",
			content: "# Title\n\n- List item 1\n- List item 2\n\n```go\ncode block\n```",
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := svc.ValidateContent([]byte(tt.content))
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
|
||||
|
||||
// TestGenerateSnippet verifies that snippets strip markdown formatting,
// drop code blocks, flatten block structure into spaces, and truncate
// long content at a word boundary with a " ..." suffix.
func TestGenerateSnippet(t *testing.T) {
	svc := NewService()

	tests := []struct {
		name      string
		content   string
		maxLength int
		expected  string
	}{
		{
			name:      "simple text",
			content:   "Hello world",
			maxLength: 100,
			expected:  "Hello world",
		},
		{
			name:      "text with formatting",
			content:   "This is **bold** and *italic* text.",
			maxLength: 100,
			expected:  "This is bold and italic text.",
		},
		{
			name:      "truncate long text",
			content:   "This is a very long piece of text that should be truncated at a word boundary.",
			maxLength: 30,
			expected:  "This is a very long piece of ...",
		},
		{
			name:      "heading and paragraph",
			content:   "# My Title\n\nThis is the first paragraph.",
			maxLength: 100,
			expected:  "My Title This is the first paragraph.",
		},
		{
			name:      "code block removed",
			content:   "Text before\n\n```go\ncode\n```\n\nText after",
			maxLength: 100,
			expected:  "Text before Text after",
		},
		{
			name:      "list items",
			content:   "- Item 1\n- Item 2\n- Item 3",
			maxLength: 100,
			expected:  "Item 1 Item 2 Item 3",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			snippet, err := svc.GenerateSnippet([]byte(tt.content), tt.maxLength)
			require.NoError(t, err)
			assert.Equal(t, tt.expected, snippet)
		})
	}
}
|
||||
|
||||
// TestExtractProperties verifies detection of links, code (inline and
// fenced), task lists, and incomplete task items across representative
// markdown inputs.
func TestExtractProperties(t *testing.T) {
	tests := []struct {
		name     string
		content  string
		hasLink  bool
		hasCode  bool
		hasTasks bool
		hasInc   bool
	}{
		{
			name:     "plain text",
			content:  "Just plain text",
			hasLink:  false,
			hasCode:  false,
			hasTasks: false,
			hasInc:   false,
		},
		{
			name:     "with link",
			content:  "Check out [this link](https://example.com)",
			hasLink:  true,
			hasCode:  false,
			hasTasks: false,
			hasInc:   false,
		},
		{
			name:     "with inline code",
			content:  "Use `console.log()` to debug",
			hasLink:  false,
			hasCode:  true,
			hasTasks: false,
			hasInc:   false,
		},
		{
			name:     "with code block",
			content:  "```go\nfunc main() {}\n```",
			hasLink:  false,
			hasCode:  true,
			hasTasks: false,
			hasInc:   false,
		},
		{
			name:     "with completed task",
			content:  "- [x] Completed task",
			hasLink:  false,
			hasCode:  false,
			hasTasks: true,
			hasInc:   false,
		},
		{
			name:     "with incomplete task",
			content:  "- [ ] Todo item",
			hasLink:  false,
			hasCode:  false,
			hasTasks: true,
			hasInc:   true,
		},
		{
			name:     "mixed tasks",
			content:  "- [x] Done\n- [ ] Not done",
			hasLink:  false,
			hasCode:  false,
			hasTasks: true,
			hasInc:   true,
		},
		{
			name:     "everything",
			content:  "# Title\n\n[Link](url)\n\n`code`\n\n- [ ] Task",
			hasLink:  true,
			hasCode:  true,
			hasTasks: true,
			hasInc:   true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			svc := NewService()

			props, err := svc.ExtractProperties([]byte(tt.content))
			require.NoError(t, err)
			assert.Equal(t, tt.hasLink, props.HasLink, "HasLink")
			assert.Equal(t, tt.hasCode, props.HasCode, "HasCode")
			assert.Equal(t, tt.hasTasks, props.HasTaskList, "HasTaskList")
			assert.Equal(t, tt.hasInc, props.HasIncompleteTasks, "HasIncompleteTasks")
		})
	}
}
|
||||
|
||||
// TestExtractTags verifies tag extraction with and without the tag
// extension enabled, covering ASCII, hyphen/underscore, hierarchical,
// CJK, and case-sensitivity behavior. Without the extension no tags
// are recognized at all.
func TestExtractTags(t *testing.T) {
	tests := []struct {
		name     string
		content  string
		withExt  bool
		expected []string
	}{
		{
			name:     "no tags",
			content:  "Just plain text",
			withExt:  false,
			expected: []string{},
		},
		{
			name:     "single tag",
			content:  "Text with #tag",
			withExt:  true,
			expected: []string{"tag"},
		},
		{
			name:     "multiple tags",
			content:  "Text with #tag1 and #tag2",
			withExt:  true,
			expected: []string{"tag1", "tag2"},
		},
		{
			// Deduplication is case-sensitive: all three case variants survive.
			name:     "duplicate tags",
			content:  "#work is important. #Work #WORK",
			withExt:  true,
			expected: []string{"work", "Work", "WORK"},
		},
		{
			name:     "tags with hyphens and underscores",
			content:  "Tags: #work-notes #2024_plans",
			withExt:  true,
			expected: []string{"work-notes", "2024_plans"},
		},
		{
			name:     "tags at end of sentence",
			content:  "This is important #urgent.",
			withExt:  true,
			expected: []string{"urgent"},
		},
		{
			name:     "headings not tags",
			content:  "## Heading\n\n# Title\n\nText with #realtag",
			withExt:  true,
			expected: []string{"realtag"},
		},
		{
			name:     "numeric tag",
			content:  "Issue #123",
			withExt:  true,
			expected: []string{"123"},
		},
		{
			name:     "tag in list",
			content:  "- Item 1 #todo\n- Item 2 #done",
			withExt:  true,
			expected: []string{"todo", "done"},
		},
		{
			name:     "no extension enabled",
			content:  "Text with #tag",
			withExt:  false,
			expected: []string{},
		},
		{
			name:     "Chinese tag",
			content:  "Text with #测试",
			withExt:  true,
			expected: []string{"测试"},
		},
		{
			name:     "Chinese tag followed by punctuation",
			content:  "Text #测试。 More text",
			withExt:  true,
			expected: []string{"测试"},
		},
		{
			name:     "mixed Chinese and ASCII tag",
			content:  "#测试test123 content",
			withExt:  true,
			expected: []string{"测试test123"},
		},
		{
			name:     "Japanese tag",
			content:  "#日本語 content",
			withExt:  true,
			expected: []string{"日本語"},
		},
		{
			name:     "Korean tag",
			content:  "#한국어 content",
			withExt:  true,
			expected: []string{"한국어"},
		},
		{
			name:     "hierarchical tag with Chinese",
			content:  "#work/测试/项目",
			withExt:  true,
			expected: []string{"work/测试/项目"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var svc Service
			if tt.withExt {
				svc = NewService(WithTagExtension())
			} else {
				svc = NewService()
			}

			tags, err := svc.ExtractTags([]byte(tt.content))
			require.NoError(t, err)
			assert.ElementsMatch(t, tt.expected, tags)
		})
	}
}
|
||||
|
||||
// TestUniquePreserveCase verifies case-sensitive deduplication: distinct
// casings are kept, exact duplicates are dropped.
func TestUniquePreserveCase(t *testing.T) {
	tests := []struct {
		name     string
		input    []string
		expected []string
	}{
		{
			name:     "empty",
			input:    []string{},
			expected: []string{},
		},
		{
			name:     "unique items",
			input:    []string{"tag1", "tag2", "tag3"},
			expected: []string{"tag1", "tag2", "tag3"},
		},
		{
			name:     "duplicates",
			input:    []string{"tag", "TAG", "Tag"},
			expected: []string{"tag", "TAG", "Tag"},
		},
		{
			name:     "mixed",
			input:    []string{"Work", "work", "Important", "work"},
			expected: []string{"Work", "work", "Important"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := uniquePreserveCase(tt.input)
			assert.ElementsMatch(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// TestTruncateAtWord verifies rune-based truncation at word boundaries,
// including long single words (hard cut) and CJK text with no spaces.
func TestTruncateAtWord(t *testing.T) {
	tests := []struct {
		name      string
		input     string
		maxLength int
		expected  string
	}{
		{
			name:      "no truncation needed",
			input:     "short",
			maxLength: 10,
			expected:  "short",
		},
		{
			name:      "exact length",
			input:     "exactly ten",
			maxLength: 11,
			expected:  "exactly ten",
		},
		{
			name:      "truncate at word",
			input:     "this is a long sentence",
			maxLength: 10,
			expected:  "this is a ...",
		},
		{
			// No space inside the cut - the hard rune cut stands.
			name:      "truncate very long word",
			input:     "supercalifragilisticexpialidocious",
			maxLength: 10,
			expected:  "supercalif ...",
		},
		{
			name:      "CJK characters without spaces",
			input:     "这是一个很长的中文句子没有空格的情况下也要正确处理",
			maxLength: 15,
			expected:  "这是一个很长的中文句子没有空格 ...",
		},
		{
			name:      "mixed CJK and Latin",
			input:     "这是中文mixed with English文字",
			maxLength: 10,
			expected:  "这是中文mixed ...",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := truncateAtWord(tt.input, tt.maxLength)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// Benchmark tests.

// BenchmarkGenerateSnippet measures snippet generation over a
// multi-section document with formatting, links, lists, and code.
func BenchmarkGenerateSnippet(b *testing.B) {
	svc := NewService()
	content := []byte(`# Large Document

This is a large document with multiple paragraphs and formatting.

## Section 1

Here is some **bold** text and *italic* text with [links](https://example.com).

- List item 1
- List item 2
- List item 3

## Section 2

More content here with ` + "`inline code`" + ` and other elements.

` + "```go\nfunc example() {\n return true\n}\n```")

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := svc.GenerateSnippet(content, 200)
		if err != nil {
			b.Fatal(err)
		}
	}
}
|
||||
|
||||
// BenchmarkExtractProperties measures property extraction on a small
// document containing every detectable feature.
func BenchmarkExtractProperties(b *testing.B) {
	svc := NewService()
	content := []byte("# Title\n\n[Link](url)\n\n`code`\n\n- [ ] Task\n- [x] Done")

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := svc.ExtractProperties(content)
		if err != nil {
			b.Fatal(err)
		}
	}
}
|
||||
139
plugin/markdown/parser/tag.go
Normal file
139
plugin/markdown/parser/tag.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
gast "github.com/yuin/goldmark/ast"
|
||||
"github.com/yuin/goldmark/parser"
|
||||
"github.com/yuin/goldmark/text"
|
||||
|
||||
mast "github.com/usememos/memos/plugin/markdown/ast"
|
||||
)
|
||||
|
||||
const (
	// MaxTagLength defines the maximum number of runes allowed in a tag.
	// Parsing stops once a tag reaches this length; the remainder of the
	// line is treated as ordinary text.
	MaxTagLength = 100
)
|
||||
|
||||
// tagParser is a stateless goldmark inline parser that recognizes #tag tokens.
type tagParser struct{}
|
||||
|
||||
// NewTagParser creates a new inline parser for #tag syntax.
// The returned parser is stateless and safe to share.
func NewTagParser() parser.InlineParser {
	return &tagParser{}
}
|
||||
|
||||
// Trigger returns the characters that trigger this parser.
// goldmark only invokes Parse when the reader is positioned at one of these.
func (*tagParser) Trigger() []byte {
	return []byte{'#'}
}
|
||||
|
||||
// isValidTagRune checks if a Unicode rune is valid in a tag.
// Uses Unicode categories for proper international character support.
//
// Accepted runes:
//   - Unicode letters (any script: Latin, CJK, Arabic, Cyrillic, ...)
//   - Unicode digits
//   - Unicode symbols (category S), which covers emoji for social-style tags
//   - the structural ASCII characters '_' (snake_case), '-' (kebab-case),
//     '/' (hierarchical tags), and '&' (compound tags)
func isValidTagRune(r rune) bool {
	// Structural ASCII characters allowed inside tags.
	switch r {
	case '_', '-', '/', '&':
		return true
	}
	// Everything else must be a letter, a digit, or a symbol (incl. emoji).
	return unicode.IsLetter(r) || unicode.IsNumber(r) || unicode.IsSymbol(r)
}
|
||||
|
||||
// Parse parses #tag syntax using Unicode-aware validation.
// Tags support international characters and follow these rules:
// - Must start with # followed by valid tag characters
// - Valid characters: Unicode letters, Unicode digits, underscore (_), hyphen (-), forward slash (/)
// - Maximum length: 100 runes (Unicode characters)
// - Stops at: whitespace, punctuation, or other invalid characters
//
// Returns a *mast.TagNode (Tag holds the name without the leading '#'),
// or nil when the text at the cursor is not a tag (heading, lone '#', etc.).
// On success the reader is advanced past the consumed tag text.
func (*tagParser) Parse(_ gast.Node, block text.Reader, _ parser.Context) gast.Node {
	line, _ := block.PeekLine()

	// Must start with #
	if len(line) == 0 || line[0] != '#' {
		return nil
	}

	// Check if it's a heading (## or space after #)
	if len(line) > 1 {
		if line[1] == '#' {
			// It's a heading (##), not a tag
			return nil
		}
		if line[1] == ' ' {
			// Space after # - heading or just a hash
			return nil
		}
	} else {
		// Just a lone #
		return nil
	}

	// Parse tag using UTF-8 aware rune iteration
	tagStart := 1
	pos := tagStart
	runeCount := 0

	for pos < len(line) {
		r, size := utf8.DecodeRune(line[pos:])

		// Stop at invalid UTF-8 (RuneError with size 1 signals a bad byte)
		if r == utf8.RuneError && size == 1 {
			break
		}

		// Validate character using Unicode categories
		if !isValidTagRune(r) {
			break
		}

		// Enforce max length (by rune count, not byte count);
		// the over-limit rune is not consumed, so the tag keeps
		// exactly MaxTagLength runes.
		runeCount++
		if runeCount > MaxTagLength {
			break
		}

		pos += size
	}

	// Must have at least one character after #
	if pos <= tagStart {
		return nil
	}

	// Extract tag (without #)
	tagName := line[tagStart:pos]

	// Make a copy of the tag name: PeekLine's slice aliases the reader's
	// buffer and must not be retained in the AST.
	tagCopy := make([]byte, len(tagName))
	copy(tagCopy, tagName)

	// Advance reader past the consumed '#' + tag text
	block.Advance(pos)

	// Create node
	node := &mast.TagNode{
		Tag: tagCopy,
	}

	return node
}
|
||||
251
plugin/markdown/parser/tag_test.go
Normal file
251
plugin/markdown/parser/tag_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yuin/goldmark/parser"
|
||||
"github.com/yuin/goldmark/text"
|
||||
|
||||
mast "github.com/usememos/memos/plugin/markdown/ast"
|
||||
)
|
||||
|
||||
// TestTagParser exercises the inline tag parser directly against raw
// input strings: valid ASCII/CJK/emoji/hierarchical tags parse (and stop
// at punctuation or whitespace), while headings and lone hashes do not.
func TestTagParser(t *testing.T) {
	tests := []struct {
		name        string
		input       string
		expectedTag string
		shouldParse bool
	}{
		{
			name:        "basic tag",
			input:       "#tag",
			expectedTag: "tag",
			shouldParse: true,
		},
		{
			name:        "tag with hyphen",
			input:       "#work-notes",
			expectedTag: "work-notes",
			shouldParse: true,
		},
		{
			name:        "tag with ampersand",
			input:       "#science&tech",
			expectedTag: "science&tech",
			shouldParse: true,
		},
		{
			name:        "tag with underscore",
			input:       "#2024_plans",
			expectedTag: "2024_plans",
			shouldParse: true,
		},
		{
			name:        "numeric tag",
			input:       "#123",
			expectedTag: "123",
			shouldParse: true,
		},
		{
			name:        "tag followed by space",
			input:       "#tag ",
			expectedTag: "tag",
			shouldParse: true,
		},
		{
			name:        "tag followed by punctuation",
			input:       "#tag.",
			expectedTag: "tag",
			shouldParse: true,
		},
		{
			name:        "tag in sentence",
			input:       "#important task",
			expectedTag: "important",
			shouldParse: true,
		},
		{
			name:        "heading (##)",
			input:       "## Heading",
			expectedTag: "",
			shouldParse: false,
		},
		{
			name:        "space after hash",
			input:       "# heading",
			expectedTag: "",
			shouldParse: false,
		},
		{
			name:        "lone hash",
			input:       "#",
			expectedTag: "",
			shouldParse: false,
		},
		{
			name:        "hash with space",
			input:       "# ",
			expectedTag: "",
			shouldParse: false,
		},
		{
			// '@' is punctuation, so the tag stops before it.
			name:        "special characters",
			input:       "#tag@special",
			expectedTag: "tag",
			shouldParse: true,
		},
		{
			name:        "mixed case",
			input:       "#WorkNotes",
			expectedTag: "WorkNotes",
			shouldParse: true,
		},
		{
			name:        "hierarchical tag with slash",
			input:       "#tag1/subtag",
			expectedTag: "tag1/subtag",
			shouldParse: true,
		},
		{
			name:        "hierarchical tag with multiple levels",
			input:       "#tag1/subtag/subtag2",
			expectedTag: "tag1/subtag/subtag2",
			shouldParse: true,
		},
		{
			name:        "hierarchical tag followed by space",
			input:       "#work/notes ",
			expectedTag: "work/notes",
			shouldParse: true,
		},
		{
			name:        "hierarchical tag followed by punctuation",
			input:       "#project/2024.",
			expectedTag: "project/2024",
			shouldParse: true,
		},
		{
			name:        "hierarchical tag with numbers and dashes",
			input:       "#work-log/2024/q1",
			expectedTag: "work-log/2024/q1",
			shouldParse: true,
		},
		{
			name:        "Chinese characters",
			input:       "#测试",
			expectedTag: "测试",
			shouldParse: true,
		},
		{
			name:        "Chinese tag followed by space",
			input:       "#测试 some text",
			expectedTag: "测试",
			shouldParse: true,
		},
		{
			// CJK full stop '。' is punctuation and terminates the tag.
			name:        "Chinese tag followed by punctuation",
			input:       "#测试。",
			expectedTag: "测试",
			shouldParse: true,
		},
		{
			name:        "mixed Chinese and ASCII",
			input:       "#测试test123",
			expectedTag: "测试test123",
			shouldParse: true,
		},
		{
			name:        "Japanese characters",
			input:       "#テスト",
			expectedTag: "テスト",
			shouldParse: true,
		},
		{
			name:        "Korean characters",
			input:       "#테스트",
			expectedTag: "테스트",
			shouldParse: true,
		},
		{
			// Emoji are Unicode symbols (category So) and are allowed.
			name:        "emoji",
			input:       "#test🚀",
			expectedTag: "test🚀",
			shouldParse: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := NewTagParser()
			reader := text.NewReader([]byte(tt.input))
			ctx := parser.NewContext()

			node := p.Parse(nil, reader, ctx)

			if tt.shouldParse {
				require.NotNil(t, node, "Expected tag to be parsed")
				require.IsType(t, &mast.TagNode{}, node)

				tagNode, ok := node.(*mast.TagNode)
				require.True(t, ok, "Expected node to be *mast.TagNode")
				assert.Equal(t, tt.expectedTag, string(tagNode.Tag))
			} else {
				assert.Nil(t, node, "Expected tag NOT to be parsed")
			}
		})
	}
}
|
||||
|
||||
// TestTagParser_Trigger verifies that '#' is the only trigger character.
func TestTagParser_Trigger(t *testing.T) {
	p := NewTagParser()
	triggers := p.Trigger()

	assert.Equal(t, []byte{'#'}, triggers)
}
|
||||
|
||||
// TestTagParser_MultipleTags verifies the parser advances the reader
// correctly so consecutive tags on one line parse independently.
func TestTagParser_MultipleTags(t *testing.T) {
	// Test that parser correctly handles multiple tags in sequence
	input := "#tag1 #tag2"

	p := NewTagParser()
	reader := text.NewReader([]byte(input))
	ctx := parser.NewContext()

	// Parse first tag
	node1 := p.Parse(nil, reader, ctx)
	require.NotNil(t, node1)
	tagNode1, ok := node1.(*mast.TagNode)
	require.True(t, ok, "Expected node1 to be *mast.TagNode")
	assert.Equal(t, "tag1", string(tagNode1.Tag))

	// Advance past the space separating the two tags
	reader.Advance(1)

	// Parse second tag
	node2 := p.Parse(nil, reader, ctx)
	require.NotNil(t, node2)
	tagNode2, ok := node2.(*mast.TagNode)
	require.True(t, ok, "Expected node2 to be *mast.TagNode")
	assert.Equal(t, "tag2", string(tagNode2.Tag))
}
|
||||
|
||||
// TestTagNode_Kind verifies a TagNode reports the registered KindTag.
func TestTagNode_Kind(t *testing.T) {
	node := &mast.TagNode{
		Tag: []byte("test"),
	}

	assert.Equal(t, mast.KindTag, node.Kind())
}
|
||||
|
||||
// TestTagNode_Dump verifies the debug Dump method does not panic.
func TestTagNode_Dump(t *testing.T) {
	node := &mast.TagNode{
		Tag: []byte("test"),
	}

	// Should not panic
	assert.NotPanics(t, func() {
		node.Dump([]byte("#test"), 0)
	})
}
|
||||
266
plugin/markdown/renderer/markdown_renderer.go
Normal file
266
plugin/markdown/renderer/markdown_renderer.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package renderer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
gast "github.com/yuin/goldmark/ast"
|
||||
east "github.com/yuin/goldmark/extension/ast"
|
||||
|
||||
mast "github.com/usememos/memos/plugin/markdown/ast"
|
||||
)
|
||||
|
||||
// MarkdownRenderer renders goldmark AST back to markdown text.
// It is NOT safe for concurrent use: a single buffer is reused
// across Render calls.
type MarkdownRenderer struct {
	// buf accumulates the rendered markdown; reset on each Render call.
	buf *bytes.Buffer
}
|
||||
|
||||
// NewMarkdownRenderer creates a new markdown renderer with an empty buffer.
func NewMarkdownRenderer() *MarkdownRenderer {
	return &MarkdownRenderer{
		buf: &bytes.Buffer{},
	}
}
|
||||
|
||||
// Render renders the AST node to markdown and returns the result.
// source must be the original input the AST was parsed from, since
// text segments reference byte ranges within it. The internal buffer
// is reset first, so successive calls are independent.
func (r *MarkdownRenderer) Render(node gast.Node, source []byte) string {
	r.buf.Reset()
	r.renderNode(node, source, 0)
	return r.buf.String()
}
|
||||
|
||||
// renderNode renders a single node and its children.
|
||||
func (r *MarkdownRenderer) renderNode(node gast.Node, source []byte, depth int) {
|
||||
switch n := node.(type) {
|
||||
case *gast.Document:
|
||||
r.renderChildren(n, source, depth)
|
||||
|
||||
case *gast.Paragraph:
|
||||
r.renderChildren(n, source, depth)
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteString("\n\n")
|
||||
}
|
||||
|
||||
case *gast.Text:
|
||||
// Text nodes store their content as segments in the source
|
||||
segment := n.Segment
|
||||
r.buf.Write(segment.Value(source))
|
||||
if n.SoftLineBreak() {
|
||||
r.buf.WriteByte('\n')
|
||||
} else if n.HardLineBreak() {
|
||||
r.buf.WriteString(" \n")
|
||||
}
|
||||
|
||||
case *gast.CodeSpan:
|
||||
r.buf.WriteByte('`')
|
||||
r.renderChildren(n, source, depth)
|
||||
r.buf.WriteByte('`')
|
||||
|
||||
case *gast.Emphasis:
|
||||
symbol := "*"
|
||||
if n.Level == 2 {
|
||||
symbol = "**"
|
||||
}
|
||||
r.buf.WriteString(symbol)
|
||||
r.renderChildren(n, source, depth)
|
||||
r.buf.WriteString(symbol)
|
||||
|
||||
case *gast.Link:
|
||||
r.buf.WriteString("[")
|
||||
r.renderChildren(n, source, depth)
|
||||
r.buf.WriteString("](")
|
||||
r.buf.Write(n.Destination)
|
||||
if len(n.Title) > 0 {
|
||||
r.buf.WriteString(` "`)
|
||||
r.buf.Write(n.Title)
|
||||
r.buf.WriteString(`"`)
|
||||
}
|
||||
r.buf.WriteString(")")
|
||||
|
||||
case *gast.AutoLink:
|
||||
url := n.URL(source)
|
||||
if n.AutoLinkType == gast.AutoLinkEmail {
|
||||
r.buf.WriteString("<")
|
||||
r.buf.Write(url)
|
||||
r.buf.WriteString(">")
|
||||
} else {
|
||||
r.buf.Write(url)
|
||||
}
|
||||
|
||||
case *gast.Image:
|
||||
r.buf.WriteString("
|
||||
r.buf.Write(n.Destination)
|
||||
if len(n.Title) > 0 {
|
||||
r.buf.WriteString(` "`)
|
||||
r.buf.Write(n.Title)
|
||||
r.buf.WriteString(`"`)
|
||||
}
|
||||
r.buf.WriteString(")")
|
||||
|
||||
case *gast.Heading:
|
||||
r.buf.WriteString(strings.Repeat("#", n.Level))
|
||||
r.buf.WriteByte(' ')
|
||||
r.renderChildren(n, source, depth)
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteString("\n\n")
|
||||
}
|
||||
|
||||
case *gast.CodeBlock, *gast.FencedCodeBlock:
|
||||
r.renderCodeBlock(n, source)
|
||||
|
||||
case *gast.Blockquote:
|
||||
// Render each child line with "> " prefix
|
||||
r.renderBlockquote(n, source, depth)
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteString("\n\n")
|
||||
}
|
||||
|
||||
case *gast.List:
|
||||
r.renderChildren(n, source, depth)
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteString("\n\n")
|
||||
}
|
||||
|
||||
case *gast.ListItem:
|
||||
r.renderListItem(n, source, depth)
|
||||
|
||||
case *gast.ThematicBreak:
|
||||
r.buf.WriteString("---")
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteString("\n\n")
|
||||
}
|
||||
|
||||
case *east.Strikethrough:
|
||||
r.buf.WriteString("~~")
|
||||
r.renderChildren(n, source, depth)
|
||||
r.buf.WriteString("~~")
|
||||
|
||||
case *east.TaskCheckBox:
|
||||
if n.IsChecked {
|
||||
r.buf.WriteString("[x] ")
|
||||
} else {
|
||||
r.buf.WriteString("[ ] ")
|
||||
}
|
||||
|
||||
case *east.Table:
|
||||
r.renderTable(n, source)
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteString("\n\n")
|
||||
}
|
||||
|
||||
// Custom Memos nodes
|
||||
case *mast.TagNode:
|
||||
r.buf.WriteByte('#')
|
||||
r.buf.Write(n.Tag)
|
||||
|
||||
default:
|
||||
// For unknown nodes, try to render children
|
||||
r.renderChildren(n, source, depth)
|
||||
}
|
||||
}
|
||||
|
||||
// renderChildren renders all children of a node.
|
||||
func (r *MarkdownRenderer) renderChildren(node gast.Node, source []byte, depth int) {
|
||||
child := node.FirstChild()
|
||||
for child != nil {
|
||||
r.renderNode(child, source, depth+1)
|
||||
child = child.NextSibling()
|
||||
}
|
||||
}
|
||||
|
||||
// renderCodeBlock renders a code block.
|
||||
func (r *MarkdownRenderer) renderCodeBlock(node gast.Node, source []byte) {
|
||||
if fenced, ok := node.(*gast.FencedCodeBlock); ok {
|
||||
// Fenced code block with language
|
||||
r.buf.WriteString("```")
|
||||
if lang := fenced.Language(source); len(lang) > 0 {
|
||||
r.buf.Write(lang)
|
||||
}
|
||||
r.buf.WriteByte('\n')
|
||||
|
||||
// Write all lines
|
||||
lines := fenced.Lines()
|
||||
for i := 0; i < lines.Len(); i++ {
|
||||
line := lines.At(i)
|
||||
r.buf.Write(line.Value(source))
|
||||
}
|
||||
|
||||
r.buf.WriteString("```")
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteString("\n\n")
|
||||
}
|
||||
} else if codeBlock, ok := node.(*gast.CodeBlock); ok {
|
||||
// Indented code block
|
||||
lines := codeBlock.Lines()
|
||||
for i := 0; i < lines.Len(); i++ {
|
||||
line := lines.At(i)
|
||||
r.buf.WriteString(" ")
|
||||
r.buf.Write(line.Value(source))
|
||||
}
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteString("\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renderBlockquote renders a blockquote with "> " prefix.
|
||||
func (r *MarkdownRenderer) renderBlockquote(node *gast.Blockquote, source []byte, depth int) {
|
||||
// Create a temporary buffer for the blockquote content
|
||||
tempBuf := &bytes.Buffer{}
|
||||
tempRenderer := &MarkdownRenderer{buf: tempBuf}
|
||||
tempRenderer.renderChildren(node, source, depth)
|
||||
|
||||
// Add "> " prefix to each line
|
||||
content := tempBuf.String()
|
||||
lines := strings.Split(strings.TrimRight(content, "\n"), "\n")
|
||||
for i, line := range lines {
|
||||
r.buf.WriteString("> ")
|
||||
r.buf.WriteString(line)
|
||||
if i < len(lines)-1 {
|
||||
r.buf.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renderListItem renders a list item with proper indentation and markers.
|
||||
func (r *MarkdownRenderer) renderListItem(node *gast.ListItem, source []byte, depth int) {
|
||||
parent := node.Parent()
|
||||
list, ok := parent.(*gast.List)
|
||||
if !ok {
|
||||
r.renderChildren(node, source, depth)
|
||||
return
|
||||
}
|
||||
|
||||
// Add indentation only for nested lists
|
||||
// Document=0, List=1, ListItem=2 (no indent), nested ListItem=3+ (indent)
|
||||
if depth > 2 {
|
||||
indent := strings.Repeat(" ", depth-2)
|
||||
r.buf.WriteString(indent)
|
||||
}
|
||||
|
||||
// Add list marker
|
||||
if list.IsOrdered() {
|
||||
fmt.Fprintf(r.buf, "%d. ", list.Start)
|
||||
list.Start++ // Increment for next item
|
||||
} else {
|
||||
r.buf.WriteString("- ")
|
||||
}
|
||||
|
||||
// Render content
|
||||
r.renderChildren(node, source, depth)
|
||||
|
||||
// Add newline if there's a next sibling
|
||||
if node.NextSibling() != nil {
|
||||
r.buf.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
// renderTable renders a table in markdown format.
// This is a simplified table renderer: it only renders the cell contents
// in order. A full implementation would need to emit pipes, the header
// separator row, and alignment markers.
func (r *MarkdownRenderer) renderTable(table *east.Table, source []byte) {
	r.renderChildren(table, source, 0)
}
|
||||
176
plugin/markdown/renderer/markdown_renderer_test.go
Normal file
176
plugin/markdown/renderer/markdown_renderer_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package renderer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yuin/goldmark"
|
||||
"github.com/yuin/goldmark/extension"
|
||||
"github.com/yuin/goldmark/parser"
|
||||
"github.com/yuin/goldmark/text"
|
||||
|
||||
"github.com/usememos/memos/plugin/markdown/extensions"
|
||||
)
|
||||
|
||||
func TestMarkdownRenderer(t *testing.T) {
|
||||
// Create goldmark instance with all extensions
|
||||
md := goldmark.New(
|
||||
goldmark.WithExtensions(
|
||||
extension.GFM,
|
||||
extensions.TagExtension,
|
||||
),
|
||||
goldmark.WithParserOptions(
|
||||
parser.WithAutoHeadingID(),
|
||||
),
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple text",
|
||||
input: "Hello world",
|
||||
expected: "Hello world",
|
||||
},
|
||||
{
|
||||
name: "paragraph with newlines",
|
||||
input: "First paragraph\n\nSecond paragraph",
|
||||
expected: "First paragraph\n\nSecond paragraph",
|
||||
},
|
||||
{
|
||||
name: "emphasis",
|
||||
input: "This is *italic* and **bold** text",
|
||||
expected: "This is *italic* and **bold** text",
|
||||
},
|
||||
{
|
||||
name: "headings",
|
||||
input: "# Heading 1\n\n## Heading 2\n\n### Heading 3",
|
||||
expected: "# Heading 1\n\n## Heading 2\n\n### Heading 3",
|
||||
},
|
||||
{
|
||||
name: "link",
|
||||
input: "Check [this link](https://example.com)",
|
||||
expected: "Check [this link](https://example.com)",
|
||||
},
|
||||
{
|
||||
name: "image",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "code inline",
|
||||
input: "This is `inline code` here",
|
||||
expected: "This is `inline code` here",
|
||||
},
|
||||
{
|
||||
name: "code block fenced",
|
||||
input: "```go\nfunc main() {\n}\n```",
|
||||
expected: "```go\nfunc main() {\n}\n```",
|
||||
},
|
||||
{
|
||||
name: "unordered list",
|
||||
input: "- Item 1\n- Item 2\n- Item 3",
|
||||
expected: "- Item 1\n- Item 2\n- Item 3",
|
||||
},
|
||||
{
|
||||
name: "ordered list",
|
||||
input: "1. First\n2. Second\n3. Third",
|
||||
expected: "1. First\n2. Second\n3. Third",
|
||||
},
|
||||
{
|
||||
name: "blockquote",
|
||||
input: "> This is a quote\n> Second line",
|
||||
expected: "> This is a quote\n> Second line",
|
||||
},
|
||||
{
|
||||
name: "horizontal rule",
|
||||
input: "Text before\n\n---\n\nText after",
|
||||
expected: "Text before\n\n---\n\nText after",
|
||||
},
|
||||
{
|
||||
name: "strikethrough",
|
||||
input: "This is ~~deleted~~ text",
|
||||
expected: "This is ~~deleted~~ text",
|
||||
},
|
||||
{
|
||||
name: "task list",
|
||||
input: "- [x] Completed task\n- [ ] Incomplete task",
|
||||
expected: "- [x] Completed task\n- [ ] Incomplete task",
|
||||
},
|
||||
{
|
||||
name: "tag",
|
||||
input: "This has #tag in it",
|
||||
expected: "This has #tag in it",
|
||||
},
|
||||
{
|
||||
name: "multiple tags",
|
||||
input: "#work #important meeting notes",
|
||||
expected: "#work #important meeting notes",
|
||||
},
|
||||
{
|
||||
name: "complex mixed content",
|
||||
input: "# Meeting Notes\n\n**Date**: 2024-01-01\n\n## Attendees\n- Alice\n- Bob\n\n## Discussion\n\nWe discussed #project status.\n\n```python\nprint('hello')\n```",
|
||||
expected: "# Meeting Notes\n\n**Date**: 2024-01-01\n\n## Attendees\n\n- Alice\n- Bob\n\n## Discussion\n\nWe discussed #project status.\n\n```python\nprint('hello')\n```",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Parse the input
|
||||
source := []byte(tt.input)
|
||||
reader := text.NewReader(source)
|
||||
doc := md.Parser().Parse(reader)
|
||||
require.NotNil(t, doc)
|
||||
|
||||
// Render back to markdown
|
||||
renderer := NewMarkdownRenderer()
|
||||
result := renderer.Render(doc, source)
|
||||
|
||||
// For debugging
|
||||
if result != tt.expected {
|
||||
t.Logf("Input: %q", tt.input)
|
||||
t.Logf("Expected: %q", tt.expected)
|
||||
t.Logf("Got: %q", result)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownRendererPreservesStructure(t *testing.T) {
|
||||
// Test that parsing and rendering preserves structure
|
||||
md := goldmark.New(
|
||||
goldmark.WithExtensions(
|
||||
extension.GFM,
|
||||
extensions.TagExtension,
|
||||
),
|
||||
)
|
||||
|
||||
inputs := []string{
|
||||
"# Title\n\nParagraph",
|
||||
"**Bold** and *italic*",
|
||||
"- List\n- Items",
|
||||
"#tag #another",
|
||||
"> Quote",
|
||||
}
|
||||
|
||||
renderer := NewMarkdownRenderer()
|
||||
|
||||
for _, input := range inputs {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
source := []byte(input)
|
||||
reader := text.NewReader(source)
|
||||
doc := md.Parser().Parse(reader)
|
||||
|
||||
result := renderer.Render(doc, source)
|
||||
|
||||
// The result should be structurally similar
|
||||
// (may have minor formatting differences)
|
||||
assert.NotEmpty(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
367
plugin/scheduler/README.md
Normal file
367
plugin/scheduler/README.md
Normal file
@@ -0,0 +1,367 @@
|
||||
# Scheduler Plugin
|
||||
|
||||
A production-ready, GitHub Actions-inspired cron job scheduler for Go.
|
||||
|
||||
## Features
|
||||
|
||||
- **Standard Cron Syntax**: Supports both 5-field and 6-field (with seconds) cron expressions
|
||||
- **Timezone-Aware**: Explicit timezone handling to avoid DST surprises
|
||||
- **Middleware Pattern**: Composable job wrappers for logging, metrics, panic recovery, timeouts
|
||||
- **Graceful Shutdown**: Jobs complete cleanly or cancel when context expires
|
||||
- **Zero Dependencies**: Core functionality uses only the standard library
|
||||
- **Type-Safe**: Strong typing with clear error messages
|
||||
- **Well-Tested**: Comprehensive test coverage
|
||||
|
||||
## Installation
|
||||
|
||||
This package is included with Memos. No separate installation required.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/usememos/memos/plugin/scheduler"
|
||||
)
|
||||
|
||||
func main() {
|
||||
s := scheduler.New()
|
||||
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "daily-cleanup",
|
||||
Schedule: "0 2 * * *", // 2 AM daily
|
||||
Handler: func(ctx context.Context) error {
|
||||
fmt.Println("Running cleanup...")
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
s.Start()
|
||||
defer s.Stop(context.Background())
|
||||
|
||||
// Keep running...
|
||||
select {}
|
||||
}
|
||||
```
|
||||
|
||||
## Cron Expression Format
|
||||
|
||||
### 5-Field Format (Standard)
|
||||
```
|
||||
┌───────────── minute (0 - 59)
|
||||
│ ┌───────────── hour (0 - 23)
|
||||
│ │ ┌───────────── day of month (1 - 31)
|
||||
│ │ │ ┌───────────── month (1 - 12)
|
||||
│ │ │ │ ┌───────────── day of week (0 - 7) (Sunday = 0 or 7)
|
||||
│ │ │ │ │
|
||||
* * * * *
|
||||
```
|
||||
|
||||
### 6-Field Format (With Seconds)
|
||||
```
|
||||
┌───────────── second (0 - 59)
|
||||
│ ┌───────────── minute (0 - 59)
|
||||
│ │ ┌───────────── hour (0 - 23)
|
||||
│ │ │ ┌───────────── day of month (1 - 31)
|
||||
│ │ │ │ ┌───────────── month (1 - 12)
|
||||
│ │ │ │ │ ┌───────────── day of week (0 - 7)
|
||||
│ │ │ │ │ │
|
||||
* * * * * *
|
||||
```
|
||||
|
||||
### Special Characters
|
||||
|
||||
- `*` - Any value (every minute, every hour, etc.)
|
||||
- `,` - List of values: `1,15,30` (1st, 15th, and 30th)
|
||||
- `-` - Range: `9-17` (9 AM through 5 PM)
|
||||
- `/` - Step: `*/15` (every 15 units)
|
||||
|
||||
### Common Examples
|
||||
|
||||
| Schedule | Description |
|
||||
|----------|-------------|
|
||||
| `* * * * *` | Every minute |
|
||||
| `0 * * * *` | Every hour |
|
||||
| `0 0 * * *` | Daily at midnight |
|
||||
| `0 9 * * 1-5` | Weekdays at 9 AM |
|
||||
| `*/15 * * * *` | Every 15 minutes |
|
||||
| `0 0 1 * *` | First day of every month |
|
||||
| `0 0 * * 0` | Every Sunday at midnight |
|
||||
| `30 14 * * *` | Every day at 2:30 PM |
|
||||
|
||||
## Timezone Support
|
||||
|
||||
```go
|
||||
// Global timezone for all jobs
|
||||
s := scheduler.New(
|
||||
scheduler.WithTimezone("America/New_York"),
|
||||
)
|
||||
|
||||
// Per-job timezone (overrides global)
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "tokyo-report",
|
||||
Schedule: "0 9 * * *", // 9 AM Tokyo time
|
||||
Timezone: "Asia/Tokyo",
|
||||
Handler: func(ctx context.Context) error {
|
||||
// Runs at 9 AM in Tokyo
|
||||
return nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
**Important**: Always use IANA timezone names (`America/New_York`, not `EST`).
|
||||
|
||||
## Middleware
|
||||
|
||||
Middleware wraps job handlers to add cross-cutting behavior. Multiple middleware can be chained together.
|
||||
|
||||
### Built-in Middleware
|
||||
|
||||
#### Recovery (Panic Handling)
|
||||
|
||||
```go
|
||||
s := scheduler.New(
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Recovery(func(jobName string, r interface{}) {
|
||||
log.Printf("Job %s panicked: %v", jobName, r)
|
||||
}),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
#### Logging
|
||||
|
||||
```go
|
||||
type Logger interface {
|
||||
Info(msg string, args ...interface{})
|
||||
Error(msg string, args ...interface{})
|
||||
}
|
||||
|
||||
s := scheduler.New(
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Logging(myLogger),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
#### Timeout
|
||||
|
||||
```go
|
||||
s := scheduler.New(
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Timeout(5 * time.Minute),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### Combining Middleware
|
||||
|
||||
```go
|
||||
s := scheduler.New(
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Recovery(panicHandler),
|
||||
scheduler.Logging(logger),
|
||||
scheduler.Timeout(10 * time.Minute),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
**Order matters**: Middleware are applied left-to-right. In the example above:
|
||||
1. Recovery (outermost) catches panics from everything
|
||||
2. Logging logs the execution
|
||||
3. Timeout (innermost) wraps the actual handler
|
||||
|
||||
### Custom Middleware
|
||||
|
||||
```go
|
||||
func Metrics(recorder MetricsRecorder) scheduler.Middleware {
|
||||
return func(next scheduler.JobHandler) scheduler.JobHandler {
|
||||
return func(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
err := next(ctx)
|
||||
duration := time.Since(start)
|
||||
|
||||
jobName := scheduler.GetJobName(ctx)
|
||||
recorder.Record(jobName, duration, err)
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Graceful Shutdown
|
||||
|
||||
Always use `Stop()` with a context to allow jobs to finish cleanly:
|
||||
|
||||
```go
|
||||
// Give jobs up to 30 seconds to complete
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Stop(ctx); err != nil {
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
Jobs should respect context cancellation:
|
||||
|
||||
```go
|
||||
Handler: func(ctx context.Context) error {
|
||||
for i := 0; i < 100; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err() // Canceled
|
||||
default:
|
||||
// Do work
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Always Name Your Jobs
|
||||
|
||||
Names are used for logging, metrics, and debugging:
|
||||
|
||||
```go
|
||||
Name: "user-cleanup-job" // Good
|
||||
Name: "job1" // Bad
|
||||
```
|
||||
|
||||
### 2. Add Descriptions and Tags
|
||||
|
||||
```go
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "stale-session-cleanup",
|
||||
Description: "Removes user sessions older than 30 days",
|
||||
Tags: []string{"maintenance", "security"},
|
||||
Schedule: "0 3 * * *",
|
||||
Handler: cleanupSessions,
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Use Appropriate Middleware
|
||||
|
||||
Always include Recovery and Logging in production:
|
||||
|
||||
```go
|
||||
scheduler.New(
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Recovery(logPanic),
|
||||
scheduler.Logging(logger),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Avoid Scheduling Exactly on the Hour
|
||||
|
||||
Many systems schedule jobs at `:00`, causing load spikes. Stagger your jobs:
|
||||
|
||||
```go
|
||||
"5 2 * * *" // 2:05 AM (good)
|
||||
"0 2 * * *" // 2:00 AM (often overloaded)
|
||||
```
|
||||
|
||||
### 5. Make Jobs Idempotent
|
||||
|
||||
Jobs may run multiple times (crash recovery, etc.). Design them to be safely re-runnable:
|
||||
|
||||
```go
|
||||
Handler: func(ctx context.Context) error {
|
||||
// Use unique constraint or check-before-insert
|
||||
db.Exec("INSERT IGNORE INTO processed_items ...")
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### 6. Handle Timezones Explicitly
|
||||
|
||||
Always specify timezone for business-hour jobs:
|
||||
|
||||
```go
|
||||
Timezone: "America/New_York" // Good
|
||||
// Timezone: "" // Bad (defaults to UTC)
|
||||
```
|
||||
|
||||
### 7. Test Your Cron Expressions
|
||||
|
||||
Use a cron expression calculator before deploying:
|
||||
- [crontab.guru](https://crontab.guru/)
|
||||
- Write unit tests with the parser
|
||||
|
||||
## Testing Jobs
|
||||
|
||||
Test job handlers independently of the scheduler:
|
||||
|
||||
```go
|
||||
func TestCleanupJob(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
err := cleanupHandler(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("cleanup failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify cleanup occurred
|
||||
}
|
||||
```
|
||||
|
||||
Test schedule parsing:
|
||||
|
||||
```go
|
||||
func TestScheduleParsing(t *testing.T) {
|
||||
job := &scheduler.Job{
|
||||
Name: "test",
|
||||
Schedule: "0 2 * * *",
|
||||
Handler: func(ctx context.Context) error { return nil },
|
||||
}
|
||||
|
||||
if err := job.Validate(); err != nil {
|
||||
t.Fatalf("invalid job: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Comparison to Other Solutions
|
||||
|
||||
| Feature | scheduler | robfig/cron | github.com/go-co-op/gocron |
|
||||
|---------|-----------|-------------|----------------------------|
|
||||
| Standard cron syntax | ✅ | ✅ | ✅ |
|
||||
| Seconds support | ✅ | ✅ | ✅ |
|
||||
| Timezone support | ✅ | ✅ | ✅ |
|
||||
| Middleware pattern | ✅ | ⚠️ (basic) | ❌ |
|
||||
| Graceful shutdown | ✅ | ⚠️ (basic) | ✅ |
|
||||
| Zero dependencies | ✅ | ❌ | ❌ |
|
||||
| Job metadata | ✅ | ❌ | ⚠️ (limited) |
|
||||
|
||||
## API Reference
|
||||
|
||||
See [example_test.go](./example_test.go) for comprehensive examples.
|
||||
|
||||
### Core Types
|
||||
|
||||
- `Scheduler` - Manages scheduled jobs
|
||||
- `Job` - Job definition with schedule and handler
|
||||
- `Middleware` - Function that wraps job handlers
|
||||
|
||||
### Functions
|
||||
|
||||
- `New(opts ...Option) *Scheduler` - Create new scheduler
|
||||
- `WithTimezone(tz string) Option` - Set default timezone
|
||||
- `WithMiddleware(mw ...Middleware) Option` - Add middleware
|
||||
|
||||
### Methods
|
||||
|
||||
- `Register(job *Job) error` - Add job to scheduler
|
||||
- `Start() error` - Begin executing jobs
|
||||
- `Stop(ctx context.Context) error` - Graceful shutdown
|
||||
|
||||
## License
|
||||
|
||||
This package is part of the Memos project and shares its license.
|
||||
35
plugin/scheduler/doc.go
Normal file
35
plugin/scheduler/doc.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Package scheduler provides a GitHub Actions-inspired cron job scheduler.
|
||||
//
|
||||
// Features:
|
||||
// - Standard cron expression syntax (5-field and 6-field formats)
|
||||
// - Timezone-aware scheduling
|
||||
// - Middleware pattern for cross-cutting concerns (logging, metrics, recovery)
|
||||
// - Graceful shutdown with context cancellation
|
||||
// - Zero external dependencies
|
||||
//
|
||||
// Basic usage:
|
||||
//
|
||||
// s := scheduler.New()
|
||||
//
|
||||
// s.Register(&scheduler.Job{
|
||||
// Name: "daily-cleanup",
|
||||
// Schedule: "0 2 * * *", // 2 AM daily
|
||||
// Handler: func(ctx context.Context) error {
|
||||
// // Your cleanup logic here
|
||||
// return nil
|
||||
// },
|
||||
// })
|
||||
//
|
||||
// s.Start()
|
||||
// defer s.Stop(context.Background())
|
||||
//
|
||||
// With middleware:
|
||||
//
|
||||
// s := scheduler.New(
|
||||
// scheduler.WithTimezone("America/New_York"),
|
||||
// scheduler.WithMiddleware(
|
||||
// scheduler.Recovery(),
|
||||
// scheduler.Logging(),
|
||||
// ),
|
||||
// )
|
||||
package scheduler
|
||||
165
plugin/scheduler/example_test.go
Normal file
165
plugin/scheduler/example_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package scheduler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/usememos/memos/plugin/scheduler"
|
||||
)
|
||||
|
||||
// Example demonstrates basic scheduler usage.
|
||||
func Example_basic() {
|
||||
s := scheduler.New()
|
||||
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "hello",
|
||||
Schedule: "*/5 * * * *", // Every 5 minutes
|
||||
Description: "Say hello",
|
||||
Handler: func(_ context.Context) error {
|
||||
fmt.Println("Hello from scheduler!")
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
s.Start()
|
||||
defer s.Stop(context.Background())
|
||||
|
||||
// Scheduler runs in background
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Example demonstrates timezone-aware scheduling.
|
||||
func Example_timezone() {
|
||||
s := scheduler.New(
|
||||
scheduler.WithTimezone("America/New_York"),
|
||||
)
|
||||
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "daily-report",
|
||||
Schedule: "0 9 * * *", // 9 AM in New York
|
||||
Handler: func(_ context.Context) error {
|
||||
fmt.Println("Generating daily report...")
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
s.Start()
|
||||
defer s.Stop(context.Background())
|
||||
}
|
||||
|
||||
// Example demonstrates middleware usage.
|
||||
func Example_middleware() {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
s := scheduler.New(
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Recovery(func(jobName string, r interface{}) {
|
||||
logger.Error("Job panicked", "job", jobName, "panic", r)
|
||||
}),
|
||||
scheduler.Logging(&slogAdapter{logger}),
|
||||
scheduler.Timeout(5*time.Minute),
|
||||
),
|
||||
)
|
||||
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "data-sync",
|
||||
Schedule: "0 */2 * * *", // Every 2 hours
|
||||
Handler: func(_ context.Context) error {
|
||||
// Your sync logic here
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
s.Start()
|
||||
defer s.Stop(context.Background())
|
||||
}
|
||||
|
||||
// slogAdapter adapts slog.Logger to scheduler.Logger interface.
|
||||
type slogAdapter struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (a *slogAdapter) Info(msg string, args ...interface{}) {
|
||||
a.logger.Info(msg, args...)
|
||||
}
|
||||
|
||||
func (a *slogAdapter) Error(msg string, args ...interface{}) {
|
||||
a.logger.Error(msg, args...)
|
||||
}
|
||||
|
||||
// Example demonstrates multiple jobs with different schedules.
|
||||
func Example_multipleJobs() {
|
||||
s := scheduler.New()
|
||||
|
||||
// Cleanup old data every night at 2 AM
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "cleanup",
|
||||
Schedule: "0 2 * * *",
|
||||
Tags: []string{"maintenance"},
|
||||
Handler: func(_ context.Context) error {
|
||||
fmt.Println("Cleaning up old data...")
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// Health check every 5 minutes
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "health-check",
|
||||
Schedule: "*/5 * * * *",
|
||||
Tags: []string{"monitoring"},
|
||||
Handler: func(_ context.Context) error {
|
||||
fmt.Println("Running health check...")
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// Weekly backup on Sundays at 1 AM
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "weekly-backup",
|
||||
Schedule: "0 1 * * 0",
|
||||
Tags: []string{"backup"},
|
||||
Handler: func(_ context.Context) error {
|
||||
fmt.Println("Creating weekly backup...")
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
s.Start()
|
||||
defer s.Stop(context.Background())
|
||||
}
|
||||
|
||||
// Example demonstrates graceful shutdown with timeout.
|
||||
func Example_gracefulShutdown() {
|
||||
s := scheduler.New()
|
||||
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "long-running",
|
||||
Schedule: "* * * * *",
|
||||
Handler: func(ctx context.Context) error {
|
||||
select {
|
||||
case <-time.After(30 * time.Second):
|
||||
fmt.Println("Job completed")
|
||||
case <-ctx.Done():
|
||||
fmt.Println("Job canceled, cleaning up...")
|
||||
return ctx.Err()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
s.Start()
|
||||
|
||||
// Simulate shutdown signal
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Give jobs 10 seconds to finish
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Stop(shutdownCtx); err != nil {
|
||||
fmt.Printf("Shutdown error: %v\n", err)
|
||||
}
|
||||
}
|
||||
393
plugin/scheduler/integration_test.go
Normal file
393
plugin/scheduler/integration_test.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package scheduler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/usememos/memos/plugin/scheduler"
|
||||
)
|
||||
|
||||
// TestRealWorldScenario tests a realistic multi-job scenario.
|
||||
func TestRealWorldScenario(t *testing.T) {
|
||||
var (
|
||||
quickJobCount atomic.Int32
|
||||
hourlyJobCount atomic.Int32
|
||||
logEntries []string
|
||||
logMu sync.Mutex
|
||||
)
|
||||
|
||||
logger := &testLogger{
|
||||
onInfo: func(msg string, _ ...interface{}) {
|
||||
logMu.Lock()
|
||||
logEntries = append(logEntries, fmt.Sprintf("INFO: %s", msg))
|
||||
logMu.Unlock()
|
||||
},
|
||||
onError: func(msg string, _ ...interface{}) {
|
||||
logMu.Lock()
|
||||
logEntries = append(logEntries, fmt.Sprintf("ERROR: %s", msg))
|
||||
logMu.Unlock()
|
||||
},
|
||||
}
|
||||
|
||||
s := scheduler.New(
|
||||
scheduler.WithTimezone("UTC"),
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Recovery(func(jobName string, r interface{}) {
|
||||
t.Logf("Job %s panicked: %v", jobName, r)
|
||||
}),
|
||||
scheduler.Logging(logger),
|
||||
scheduler.Timeout(5*time.Second),
|
||||
),
|
||||
)
|
||||
|
||||
// Quick job (every second)
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "quick-check",
|
||||
Schedule: "* * * * * *",
|
||||
Handler: func(_ context.Context) error {
|
||||
quickJobCount.Add(1)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// Slower job (every 2 seconds)
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "slow-process",
|
||||
Schedule: "*/2 * * * * *",
|
||||
Handler: func(_ context.Context) error {
|
||||
hourlyJobCount.Add(1)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// Start scheduler
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("failed to start scheduler: %v", err)
|
||||
}
|
||||
|
||||
// Let it run for 5 seconds
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Graceful shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Stop(ctx); err != nil {
|
||||
t.Fatalf("failed to stop scheduler: %v", err)
|
||||
}
|
||||
|
||||
// Verify execution counts
|
||||
quick := quickJobCount.Load()
|
||||
slow := hourlyJobCount.Load()
|
||||
|
||||
t.Logf("Quick job ran %d times", quick)
|
||||
t.Logf("Slow job ran %d times", slow)
|
||||
|
||||
if quick < 4 {
|
||||
t.Errorf("expected quick job to run at least 4 times, ran %d", quick)
|
||||
}
|
||||
|
||||
if slow < 2 {
|
||||
t.Errorf("expected slow job to run at least 2 times, ran %d", slow)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logMu.Lock()
|
||||
defer logMu.Unlock()
|
||||
|
||||
hasStartLog := false
|
||||
hasCompleteLog := false
|
||||
for _, entry := range logEntries {
|
||||
if contains(entry, "Job started") {
|
||||
hasStartLog = true
|
||||
}
|
||||
if contains(entry, "Job completed") {
|
||||
hasCompleteLog = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasStartLog {
|
||||
t.Error("expected job start logs")
|
||||
}
|
||||
if !hasCompleteLog {
|
||||
t.Error("expected job completion logs")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCancellationDuringExecution verifies jobs can be canceled mid-execution.
|
||||
func TestCancellationDuringExecution(t *testing.T) {
|
||||
var canceled atomic.Bool
|
||||
var started atomic.Bool
|
||||
|
||||
s := scheduler.New()
|
||||
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "long-job",
|
||||
Schedule: "* * * * * *",
|
||||
Handler: func(ctx context.Context) error {
|
||||
started.Store(true)
|
||||
// Simulate long-running work
|
||||
for i := 0; i < 100; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
canceled.Store(true)
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Keep working
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("failed to start: %v", err)
|
||||
}
|
||||
|
||||
// Wait until job starts
|
||||
for i := 0; i < 30; i++ {
|
||||
if started.Load() {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
if !started.Load() {
|
||||
t.Fatal("job did not start within timeout")
|
||||
}
|
||||
|
||||
// Stop with reasonable timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Stop(ctx); err != nil {
|
||||
t.Logf("stop returned error (may be expected): %v", err)
|
||||
}
|
||||
|
||||
if !canceled.Load() {
|
||||
t.Error("expected job to detect cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTimezoneHandling verifies timezone-aware scheduling.
|
||||
func TestTimezoneHandling(t *testing.T) {
|
||||
// Parse a schedule in a specific timezone
|
||||
schedule, err := scheduler.ParseCronExpression("0 9 * * *") // 9 AM
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse schedule: %v", err)
|
||||
}
|
||||
|
||||
// Test in New York timezone
|
||||
nyc, err := time.LoadLocation("America/New_York")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load timezone: %v", err)
|
||||
}
|
||||
|
||||
// Current time: 8:30 AM in New York
|
||||
now := time.Date(2025, 1, 15, 8, 30, 0, 0, nyc)
|
||||
|
||||
// Next run should be 9:00 AM same day
|
||||
next := schedule.Next(now)
|
||||
expected := time.Date(2025, 1, 15, 9, 0, 0, 0, nyc)
|
||||
|
||||
if !next.Equal(expected) {
|
||||
t.Errorf("next = %v, expected %v", next, expected)
|
||||
}
|
||||
|
||||
// If it's already past 9 AM
|
||||
now = time.Date(2025, 1, 15, 9, 30, 0, 0, nyc)
|
||||
next = schedule.Next(now)
|
||||
expected = time.Date(2025, 1, 16, 9, 0, 0, 0, nyc)
|
||||
|
||||
if !next.Equal(expected) {
|
||||
t.Errorf("next = %v, expected %v", next, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorPropagation verifies error handling.
|
||||
func TestErrorPropagation(t *testing.T) {
|
||||
var errorLogged atomic.Bool
|
||||
|
||||
logger := &testLogger{
|
||||
onError: func(msg string, _ ...interface{}) {
|
||||
if msg == "Job failed" {
|
||||
errorLogged.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
s := scheduler.New(
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Logging(logger),
|
||||
),
|
||||
)
|
||||
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "failing-job",
|
||||
Schedule: "* * * * * *",
|
||||
Handler: func(_ context.Context) error {
|
||||
return errors.New("intentional error")
|
||||
},
|
||||
})
|
||||
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("failed to start: %v", err)
|
||||
}
|
||||
|
||||
// Let it run once
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Stop(ctx); err != nil {
|
||||
t.Fatalf("failed to stop: %v", err)
|
||||
}
|
||||
|
||||
if !errorLogged.Load() {
|
||||
t.Error("expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPanicRecovery verifies panic recovery middleware.
|
||||
func TestPanicRecovery(t *testing.T) {
|
||||
var panicRecovered atomic.Bool
|
||||
|
||||
s := scheduler.New(
|
||||
scheduler.WithMiddleware(
|
||||
scheduler.Recovery(func(jobName string, r interface{}) {
|
||||
panicRecovered.Store(true)
|
||||
t.Logf("Recovered from panic in job %s: %v", jobName, r)
|
||||
}),
|
||||
),
|
||||
)
|
||||
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "panicking-job",
|
||||
Schedule: "* * * * * *",
|
||||
Handler: func(_ context.Context) error {
|
||||
panic("intentional panic for testing")
|
||||
},
|
||||
})
|
||||
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("failed to start: %v", err)
|
||||
}
|
||||
|
||||
// Let it run once
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Stop(ctx); err != nil {
|
||||
t.Fatalf("failed to stop: %v", err)
|
||||
}
|
||||
|
||||
if !panicRecovered.Load() {
|
||||
t.Error("expected panic to be recovered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleJobsWithDifferentSchedules verifies concurrent job execution.
|
||||
func TestMultipleJobsWithDifferentSchedules(t *testing.T) {
|
||||
var (
|
||||
job1Count atomic.Int32
|
||||
job2Count atomic.Int32
|
||||
job3Count atomic.Int32
|
||||
)
|
||||
|
||||
s := scheduler.New()
|
||||
|
||||
// Job 1: Every second
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "job-1sec",
|
||||
Schedule: "* * * * * *",
|
||||
Handler: func(_ context.Context) error {
|
||||
job1Count.Add(1)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// Job 2: Every 2 seconds
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "job-2sec",
|
||||
Schedule: "*/2 * * * * *",
|
||||
Handler: func(_ context.Context) error {
|
||||
job2Count.Add(1)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// Job 3: Every 3 seconds
|
||||
s.Register(&scheduler.Job{
|
||||
Name: "job-3sec",
|
||||
Schedule: "*/3 * * * * *",
|
||||
Handler: func(_ context.Context) error {
|
||||
job3Count.Add(1)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("failed to start: %v", err)
|
||||
}
|
||||
|
||||
// Let them run for 6 seconds
|
||||
time.Sleep(6 * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Stop(ctx); err != nil {
|
||||
t.Fatalf("failed to stop: %v", err)
|
||||
}
|
||||
|
||||
// Verify counts (allowing for timing variance)
|
||||
c1 := job1Count.Load()
|
||||
c2 := job2Count.Load()
|
||||
c3 := job3Count.Load()
|
||||
|
||||
t.Logf("Job 1 ran %d times, Job 2 ran %d times, Job 3 ran %d times", c1, c2, c3)
|
||||
|
||||
if c1 < 5 {
|
||||
t.Errorf("expected job1 to run at least 5 times, ran %d", c1)
|
||||
}
|
||||
if c2 < 2 {
|
||||
t.Errorf("expected job2 to run at least 2 times, ran %d", c2)
|
||||
}
|
||||
if c3 < 1 {
|
||||
t.Errorf("expected job3 to run at least 1 time, ran %d", c3)
|
||||
}
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
type testLogger struct {
|
||||
onInfo func(msg string, args ...interface{})
|
||||
onError func(msg string, args ...interface{})
|
||||
}
|
||||
|
||||
func (l *testLogger) Info(msg string, args ...interface{}) {
|
||||
if l.onInfo != nil {
|
||||
l.onInfo(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *testLogger) Error(msg string, args ...interface{}) {
|
||||
if l.onError != nil {
|
||||
l.onError(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
58
plugin/scheduler/job.go
Normal file
58
plugin/scheduler/job.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// JobHandler is the function signature for scheduled job handlers.
// The context passed to the handler will be canceled if the scheduler is shutting down.
type JobHandler func(ctx context.Context) error

// Job represents a scheduled task.
type Job struct {
	// Name is a unique identifier for this job (required).
	// Used for logging and metrics; duplicate names are rejected at registration.
	Name string

	// Schedule is a cron expression defining when this job runs (required).
	// Supports the standard 5-field format "minute hour day month weekday"
	// as well as a 6-field format with a leading seconds field
	// (see ParseCronExpression).
	// Examples: "0 * * * *" (hourly), "0 0 * * *" (daily at midnight),
	// "*/30 * * * * *" (every 30 seconds).
	Schedule string

	// Timezone for schedule evaluation (optional, defaults to UTC).
	// Use IANA timezone names: "America/New_York", "Europe/London", etc.
	Timezone string

	// Handler is the function to execute when the job triggers (required).
	Handler JobHandler

	// Description provides human-readable context about what this job does (optional).
	Description string

	// Tags allow categorizing jobs for filtering/monitoring (optional).
	Tags []string
}
|
||||
|
||||
// Validate checks if the job definition is valid.
|
||||
func (j *Job) Validate() error {
|
||||
if j.Name == "" {
|
||||
return errors.New("job name is required")
|
||||
}
|
||||
|
||||
if j.Schedule == "" {
|
||||
return errors.New("job schedule is required")
|
||||
}
|
||||
|
||||
// Validate cron expression using parser
|
||||
if _, err := ParseCronExpression(j.Schedule); err != nil {
|
||||
return errors.Wrap(err, "invalid cron expression")
|
||||
}
|
||||
|
||||
if j.Handler == nil {
|
||||
return errors.New("job handler is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
90
plugin/scheduler/job_test.go
Normal file
90
plugin/scheduler/job_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJobDefinition(t *testing.T) {
|
||||
callCount := 0
|
||||
job := &Job{
|
||||
Name: "test-job",
|
||||
Handler: func(_ context.Context) error {
|
||||
callCount++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
if job.Name != "test-job" {
|
||||
t.Errorf("expected name 'test-job', got %s", job.Name)
|
||||
}
|
||||
|
||||
// Test handler execution
|
||||
if err := job.Handler(context.Background()); err != nil {
|
||||
t.Fatalf("handler failed: %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("expected handler to be called once, called %d times", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
job *Job
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid job",
|
||||
job: &Job{
|
||||
Name: "valid",
|
||||
Schedule: "0 * * * *",
|
||||
Handler: func(_ context.Context) error { return nil },
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
job: &Job{
|
||||
Schedule: "0 * * * *",
|
||||
Handler: func(_ context.Context) error { return nil },
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing schedule",
|
||||
job: &Job{
|
||||
Name: "test",
|
||||
Handler: func(_ context.Context) error { return nil },
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid cron expression",
|
||||
job: &Job{
|
||||
Name: "test",
|
||||
Schedule: "invalid cron",
|
||||
Handler: func(_ context.Context) error { return nil },
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing handler",
|
||||
job: &Job{
|
||||
Name: "test",
|
||||
Schedule: "0 * * * *",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.job.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
120
plugin/scheduler/middleware.go
Normal file
120
plugin/scheduler/middleware.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Middleware wraps a JobHandler to add cross-cutting behavior.
|
||||
type Middleware func(JobHandler) JobHandler
|
||||
|
||||
// Chain combines multiple middleware into a single middleware.
|
||||
// Middleware are applied in the order they're provided (left to right).
|
||||
func Chain(middlewares ...Middleware) Middleware {
|
||||
return func(handler JobHandler) JobHandler {
|
||||
// Apply middleware in reverse order so first middleware wraps outermost
|
||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||
handler = middlewares[i](handler)
|
||||
}
|
||||
return handler
|
||||
}
|
||||
}
|
||||
|
||||
// Recovery recovers from panics in job handlers and converts them to errors.
|
||||
func Recovery(onPanic func(jobName string, recovered interface{})) Middleware {
|
||||
return func(next JobHandler) JobHandler {
|
||||
return func(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
jobName := getJobName(ctx)
|
||||
if onPanic != nil {
|
||||
onPanic(jobName, r)
|
||||
}
|
||||
err = errors.Errorf("job %q panicked: %v", jobName, r)
|
||||
}
|
||||
}()
|
||||
return next(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Logger is a minimal logging interface.
// Callers pass a message followed by alternating key/value pairs
// (slog-style), e.g. Info("Job started", "job", name).
type Logger interface {
	Info(msg string, args ...interface{})
	Error(msg string, args ...interface{})
}
|
||||
|
||||
// Logging adds execution logging to jobs.
|
||||
func Logging(logger Logger) Middleware {
|
||||
return func(next JobHandler) JobHandler {
|
||||
return func(ctx context.Context) error {
|
||||
jobName := getJobName(ctx)
|
||||
start := time.Now()
|
||||
|
||||
logger.Info("Job started", "job", jobName)
|
||||
|
||||
err := next(ctx)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Job failed", "job", jobName, "duration", duration, "error", err)
|
||||
} else {
|
||||
logger.Info("Job completed", "job", jobName, "duration", duration)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Timeout wraps a job handler with a timeout.
|
||||
func Timeout(duration time.Duration) Middleware {
|
||||
return func(next JobHandler) JobHandler {
|
||||
return func(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, duration)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- next(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return errors.Errorf("job %q timed out after %v", getJobName(ctx), duration)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Context keys for job metadata.
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
jobNameKey contextKey = iota
|
||||
)
|
||||
|
||||
// withJobName adds the job name to the context.
|
||||
func withJobName(ctx context.Context, name string) context.Context {
|
||||
return context.WithValue(ctx, jobNameKey, name)
|
||||
}
|
||||
|
||||
// getJobName retrieves the job name from the context.
|
||||
func getJobName(ctx context.Context) string {
|
||||
if name, ok := ctx.Value(jobNameKey).(string); ok {
|
||||
return name
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// GetJobName retrieves the job name from the context (public API).
|
||||
// Returns empty string if not found.
|
||||
//
|
||||
//nolint:revive // GetJobName is the public API, getJobName is internal
|
||||
func GetJobName(ctx context.Context) string {
|
||||
return getJobName(ctx)
|
||||
}
|
||||
146
plugin/scheduler/middleware_test.go
Normal file
146
plugin/scheduler/middleware_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMiddlewareChaining(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
mw1 := func(next JobHandler) JobHandler {
|
||||
return func(ctx context.Context) error {
|
||||
order = append(order, "before-1")
|
||||
err := next(ctx)
|
||||
order = append(order, "after-1")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
mw2 := func(next JobHandler) JobHandler {
|
||||
return func(ctx context.Context) error {
|
||||
order = append(order, "before-2")
|
||||
err := next(ctx)
|
||||
order = append(order, "after-2")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
handler := func(_ context.Context) error {
|
||||
order = append(order, "handler")
|
||||
return nil
|
||||
}
|
||||
|
||||
chain := Chain(mw1, mw2)
|
||||
wrapped := chain(handler)
|
||||
|
||||
if err := wrapped(context.Background()); err != nil {
|
||||
t.Fatalf("wrapped handler failed: %v", err)
|
||||
}
|
||||
|
||||
expected := []string{"before-1", "before-2", "handler", "after-2", "after-1"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("expected %d calls, got %d", len(expected), len(order))
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
if order[i] != want {
|
||||
t.Errorf("order[%d] = %q, want %q", i, order[i], want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMiddleware(t *testing.T) {
|
||||
var panicRecovered atomic.Bool
|
||||
|
||||
onPanic := func(_ string, _ interface{}) {
|
||||
panicRecovered.Store(true)
|
||||
}
|
||||
|
||||
handler := func(_ context.Context) error {
|
||||
panic("simulated panic")
|
||||
}
|
||||
|
||||
recovery := Recovery(onPanic)
|
||||
wrapped := recovery(handler)
|
||||
|
||||
// Should not panic, error should be returned
|
||||
err := wrapped(withJobName(context.Background(), "test-job"))
|
||||
if err == nil {
|
||||
t.Error("expected error from recovered panic")
|
||||
}
|
||||
|
||||
if !panicRecovered.Load() {
|
||||
t.Error("panic handler was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggingMiddleware(t *testing.T) {
|
||||
var loggedStart, loggedEnd atomic.Bool
|
||||
var loggedError atomic.Bool
|
||||
|
||||
logger := &testLogger{
|
||||
onInfo: func(msg string, _ ...interface{}) {
|
||||
if msg == "Job started" {
|
||||
loggedStart.Store(true)
|
||||
} else if msg == "Job completed" {
|
||||
loggedEnd.Store(true)
|
||||
}
|
||||
},
|
||||
onError: func(msg string, _ ...interface{}) {
|
||||
if msg == "Job failed" {
|
||||
loggedError.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Test successful execution
|
||||
handler := func(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
logging := Logging(logger)
|
||||
wrapped := logging(handler)
|
||||
|
||||
if err := wrapped(withJobName(context.Background(), "test-job")); err != nil {
|
||||
t.Fatalf("handler failed: %v", err)
|
||||
}
|
||||
|
||||
if !loggedStart.Load() {
|
||||
t.Error("start was not logged")
|
||||
}
|
||||
if !loggedEnd.Load() {
|
||||
t.Error("end was not logged")
|
||||
}
|
||||
|
||||
// Test error handling
|
||||
handlerErr := func(_ context.Context) error {
|
||||
return errors.New("job error")
|
||||
}
|
||||
|
||||
wrappedErr := logging(handlerErr)
|
||||
_ = wrappedErr(withJobName(context.Background(), "test-job-error"))
|
||||
|
||||
if !loggedError.Load() {
|
||||
t.Error("error was not logged")
|
||||
}
|
||||
}
|
||||
|
||||
// testLogger is a Logger stub whose behavior is injected per test.
type testLogger struct {
	onInfo  func(msg string, args ...interface{})
	onError func(msg string, args ...interface{})
}

// Info forwards to the injected onInfo callback when one is set.
func (l *testLogger) Info(msg string, args ...interface{}) {
	if fn := l.onInfo; fn != nil {
		fn(msg, args...)
	}
}

// Error forwards to the injected onError callback when one is set.
func (l *testLogger) Error(msg string, args ...interface{}) {
	if fn := l.onError; fn != nil {
		fn(msg, args...)
	}
}
|
||||
229
plugin/scheduler/parser.go
Normal file
229
plugin/scheduler/parser.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Schedule represents a parsed cron expression.
// Each field holds the matcher for one cron column; a time satisfies the
// Schedule when the relevant fields match (see Schedule.matches).
type Schedule struct {
	seconds  fieldMatcher // 0-59 (optional, for 6-field format)
	minutes  fieldMatcher // 0-59
	hours    fieldMatcher // 0-23
	days     fieldMatcher // 1-31
	months   fieldMatcher // 1-12
	weekdays fieldMatcher // 0-7 (0 and 7 are Sunday)
	hasSecs  bool         // true when parsed from the 6-field (seconds) format
}

// fieldMatcher determines if a field value matches.
// Implementations: wildcardMatcher, exactMatcher, rangeMatcher,
// listMatcher, stepMatcher.
type fieldMatcher interface {
	matches(value int) bool
}
|
||||
|
||||
// ParseCronExpression parses a cron expression and returns a Schedule.
|
||||
// Supports both 5-field (minute hour day month weekday) and 6-field (second minute hour day month weekday) formats.
|
||||
func ParseCronExpression(expr string) (*Schedule, error) {
|
||||
if expr == "" {
|
||||
return nil, errors.New("empty cron expression")
|
||||
}
|
||||
|
||||
fields := strings.Fields(expr)
|
||||
if len(fields) != 5 && len(fields) != 6 {
|
||||
return nil, errors.Errorf("invalid cron expression: expected 5 or 6 fields, got %d", len(fields))
|
||||
}
|
||||
|
||||
s := &Schedule{
|
||||
hasSecs: len(fields) == 6,
|
||||
}
|
||||
|
||||
var err error
|
||||
offset := 0
|
||||
|
||||
// Parse seconds (if 6-field format)
|
||||
if s.hasSecs {
|
||||
s.seconds, err = parseField(fields[0], 0, 59)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid seconds field")
|
||||
}
|
||||
offset = 1
|
||||
} else {
|
||||
s.seconds = &exactMatcher{value: 0} // Default to 0 seconds
|
||||
}
|
||||
|
||||
// Parse minutes
|
||||
s.minutes, err = parseField(fields[offset], 0, 59)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid minutes field")
|
||||
}
|
||||
|
||||
// Parse hours
|
||||
s.hours, err = parseField(fields[offset+1], 0, 23)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid hours field")
|
||||
}
|
||||
|
||||
// Parse days
|
||||
s.days, err = parseField(fields[offset+2], 1, 31)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid days field")
|
||||
}
|
||||
|
||||
// Parse months
|
||||
s.months, err = parseField(fields[offset+3], 1, 12)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid months field")
|
||||
}
|
||||
|
||||
// Parse weekdays (0-7, where both 0 and 7 represent Sunday)
|
||||
s.weekdays, err = parseField(fields[offset+4], 0, 7)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid weekdays field")
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Next returns the next time the schedule should run after the given time.
|
||||
func (s *Schedule) Next(from time.Time) time.Time {
|
||||
// Start from the next second/minute
|
||||
if s.hasSecs {
|
||||
from = from.Add(1 * time.Second).Truncate(time.Second)
|
||||
} else {
|
||||
from = from.Add(1 * time.Minute).Truncate(time.Minute)
|
||||
}
|
||||
|
||||
// Cap search at 4 years to prevent infinite loops
|
||||
maxTime := from.AddDate(4, 0, 0)
|
||||
|
||||
for from.Before(maxTime) {
|
||||
if s.matches(from) {
|
||||
return from
|
||||
}
|
||||
|
||||
// Advance to next potential match
|
||||
if s.hasSecs {
|
||||
from = from.Add(1 * time.Second)
|
||||
} else {
|
||||
from = from.Add(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// Should never reach here with valid cron expressions
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// matches checks if the given time matches the schedule.
|
||||
func (s *Schedule) matches(t time.Time) bool {
|
||||
return s.seconds.matches(t.Second()) &&
|
||||
s.minutes.matches(t.Minute()) &&
|
||||
s.hours.matches(t.Hour()) &&
|
||||
s.months.matches(int(t.Month())) &&
|
||||
(s.days.matches(t.Day()) || s.weekdays.matches(int(t.Weekday())))
|
||||
}
|
||||
|
||||
// parseField parses a single cron field (supports *, ranges, lists, steps).
|
||||
func parseField(field string, min, max int) (fieldMatcher, error) {
|
||||
// Wildcard
|
||||
if field == "*" {
|
||||
return &wildcardMatcher{}, nil
|
||||
}
|
||||
|
||||
// Step values (*/N)
|
||||
if strings.HasPrefix(field, "*/") {
|
||||
step, err := strconv.Atoi(field[2:])
|
||||
if err != nil || step < 1 || step > max {
|
||||
return nil, errors.Errorf("invalid step value: %s", field)
|
||||
}
|
||||
return &stepMatcher{step: step, min: min, max: max}, nil
|
||||
}
|
||||
|
||||
// List (1,2,3)
|
||||
if strings.Contains(field, ",") {
|
||||
parts := strings.Split(field, ",")
|
||||
values := make([]int, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
val, err := strconv.Atoi(strings.TrimSpace(p))
|
||||
if err != nil || val < min || val > max {
|
||||
return nil, errors.Errorf("invalid list value: %s", p)
|
||||
}
|
||||
values = append(values, val)
|
||||
}
|
||||
return &listMatcher{values: values}, nil
|
||||
}
|
||||
|
||||
// Range (1-5)
|
||||
if strings.Contains(field, "-") {
|
||||
parts := strings.Split(field, "-")
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.Errorf("invalid range: %s", field)
|
||||
}
|
||||
start, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||
end, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
if err1 != nil || err2 != nil || start < min || end > max || start > end {
|
||||
return nil, errors.Errorf("invalid range: %s", field)
|
||||
}
|
||||
return &rangeMatcher{start: start, end: end}, nil
|
||||
}
|
||||
|
||||
// Exact value
|
||||
val, err := strconv.Atoi(field)
|
||||
if err != nil || val < min || val > max {
|
||||
return nil, errors.Errorf("invalid value: %s (must be between %d and %d)", field, min, max)
|
||||
}
|
||||
return &exactMatcher{value: val}, nil
|
||||
}
|
||||
|
||||
// wildcardMatcher matches any value ("*").
type wildcardMatcher struct{}

func (*wildcardMatcher) matches(_ int) bool { return true }

// exactMatcher matches one specific value (e.g. "30").
type exactMatcher struct {
	value int
}

func (m *exactMatcher) matches(value int) bool { return value == m.value }

// rangeMatcher matches an inclusive range of values (e.g. "9-17").
type rangeMatcher struct {
	start, end int
}

func (m *rangeMatcher) matches(value int) bool {
	return m.start <= value && value <= m.end
}

// listMatcher matches any value in an explicit list (e.g. "8,12,18").
type listMatcher struct {
	values []int
}

func (m *listMatcher) matches(value int) bool {
	for _, candidate := range m.values {
		if candidate == value {
			return true
		}
	}
	return false
}

// stepMatcher matches values at regular intervals (e.g. "*/15"),
// anchored at min and bounded by max.
type stepMatcher struct {
	step, min, max int
}

func (m *stepMatcher) matches(value int) bool {
	inBounds := value >= m.min && value <= m.max
	return inBounds && (value-m.min)%m.step == 0
}
|
||||
127
plugin/scheduler/parser_test.go
Normal file
127
plugin/scheduler/parser_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseCronExpression(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expr string
|
||||
wantErr bool
|
||||
}{
|
||||
// Standard 5-field format
|
||||
{"every minute", "* * * * *", false},
|
||||
{"hourly", "0 * * * *", false},
|
||||
{"daily midnight", "0 0 * * *", false},
|
||||
{"weekly sunday", "0 0 * * 0", false},
|
||||
{"monthly", "0 0 1 * *", false},
|
||||
{"specific time", "30 14 * * *", false}, // 2:30 PM daily
|
||||
{"range", "0 9-17 * * *", false}, // Every hour 9 AM - 5 PM
|
||||
{"step", "*/15 * * * *", false}, // Every 15 minutes
|
||||
{"list", "0 8,12,18 * * *", false}, // 8 AM, 12 PM, 6 PM
|
||||
|
||||
// 6-field format with seconds
|
||||
{"with seconds", "0 * * * * *", false},
|
||||
{"every 30 seconds", "*/30 * * * * *", false},
|
||||
|
||||
// Invalid expressions
|
||||
{"empty", "", true},
|
||||
{"too few fields", "* * *", true},
|
||||
{"too many fields", "* * * * * * *", true},
|
||||
{"invalid minute", "60 * * * *", true},
|
||||
{"invalid hour", "0 24 * * *", true},
|
||||
{"invalid day", "0 0 32 * *", true},
|
||||
{"invalid month", "0 0 1 13 *", true},
|
||||
{"invalid weekday", "0 0 * * 8", true},
|
||||
{"garbage", "not a cron expression", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
schedule, err := ParseCronExpression(tt.expr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseCronExpression(%q) error = %v, wantErr %v", tt.expr, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && schedule == nil {
|
||||
t.Errorf("ParseCronExpression(%q) returned nil schedule without error", tt.expr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleNext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expr string
|
||||
from time.Time
|
||||
expected time.Time
|
||||
}{
|
||||
{
|
||||
name: "every minute from start of hour",
|
||||
expr: "* * * * *",
|
||||
from: time.Date(2025, 1, 1, 10, 0, 0, 0, time.UTC),
|
||||
expected: time.Date(2025, 1, 1, 10, 1, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "hourly at minute 30",
|
||||
expr: "30 * * * *",
|
||||
from: time.Date(2025, 1, 1, 10, 0, 0, 0, time.UTC),
|
||||
expected: time.Date(2025, 1, 1, 10, 30, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "hourly at minute 30 (already past)",
|
||||
expr: "30 * * * *",
|
||||
from: time.Date(2025, 1, 1, 10, 45, 0, 0, time.UTC),
|
||||
expected: time.Date(2025, 1, 1, 11, 30, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "daily at 2 AM",
|
||||
expr: "0 2 * * *",
|
||||
from: time.Date(2025, 1, 1, 10, 0, 0, 0, time.UTC),
|
||||
expected: time.Date(2025, 1, 2, 2, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "every 15 minutes",
|
||||
expr: "*/15 * * * *",
|
||||
from: time.Date(2025, 1, 1, 10, 7, 0, 0, time.UTC),
|
||||
expected: time.Date(2025, 1, 1, 10, 15, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
schedule, err := ParseCronExpression(tt.expr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse expression: %v", err)
|
||||
}
|
||||
|
||||
next := schedule.Next(tt.from)
|
||||
if !next.Equal(tt.expected) {
|
||||
t.Errorf("Next(%v) = %v, expected %v", tt.from, next, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleNextWithTimezone(t *testing.T) {
|
||||
nyc, _ := time.LoadLocation("America/New_York")
|
||||
|
||||
// Schedule for 9 AM in New York
|
||||
schedule, err := ParseCronExpression("0 9 * * *")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse expression: %v", err)
|
||||
}
|
||||
|
||||
// Current time: 8 AM in New York
|
||||
from := time.Date(2025, 1, 1, 8, 0, 0, 0, nyc)
|
||||
next := schedule.Next(from)
|
||||
|
||||
// Should be 9 AM same day in New York
|
||||
expected := time.Date(2025, 1, 1, 9, 0, 0, 0, nyc)
|
||||
if !next.Equal(expected) {
|
||||
t.Errorf("Next(%v) = %v, expected %v", from, next, expected)
|
||||
}
|
||||
}
|
||||
202
plugin/scheduler/scheduler.go
Normal file
202
plugin/scheduler/scheduler.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Scheduler manages scheduled jobs.
//
// A Scheduler is created with New, populated with Register, and driven by
// Start/Stop. The zero value is not usable (Register writes into the jobs
// map); always construct via New.
type Scheduler struct {
	jobs       map[string]*registeredJob // registered jobs keyed by Job.Name
	jobsMu     sync.RWMutex              // guards jobs
	timezone   *time.Location            // default zone for schedule evaluation (UTC unless WithTimezone)
	middleware Middleware                // optional wrapper applied to every handler at run time
	running    bool                      // true between a successful Start and Stop
	runningMu  sync.RWMutex              // guards running
	stopCh     chan struct{}             // closed by Stop to signal every job loop
	wg         sync.WaitGroup            // tracks one goroutine per started job
}

// registeredJob wraps a Job with runtime state.
type registeredJob struct {
	job      *Job
	cancelFn context.CancelFunc // cancels this job's run loop; set by Start
}
|
||||
|
||||
// Option configures a Scheduler.
|
||||
type Option func(*Scheduler)
|
||||
|
||||
// WithTimezone sets the default timezone for all jobs.
|
||||
func WithTimezone(tz string) Option {
|
||||
return func(s *Scheduler) {
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err != nil {
|
||||
// Default to UTC on invalid timezone
|
||||
loc = time.UTC
|
||||
}
|
||||
s.timezone = loc
|
||||
}
|
||||
}
|
||||
|
||||
// WithMiddleware sets middleware to wrap all job handlers.
|
||||
func WithMiddleware(mw ...Middleware) Option {
|
||||
return func(s *Scheduler) {
|
||||
if len(mw) > 0 {
|
||||
s.middleware = Chain(mw...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new Scheduler with optional configuration.
|
||||
func New(opts ...Option) *Scheduler {
|
||||
s := &Scheduler{
|
||||
jobs: make(map[string]*registeredJob),
|
||||
timezone: time.UTC,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Register adds a job to the scheduler.
|
||||
// Jobs must be registered before calling Start().
|
||||
func (s *Scheduler) Register(job *Job) error {
|
||||
if job == nil {
|
||||
return errors.New("job cannot be nil")
|
||||
}
|
||||
|
||||
if err := job.Validate(); err != nil {
|
||||
return errors.Wrap(err, "invalid job")
|
||||
}
|
||||
|
||||
s.jobsMu.Lock()
|
||||
defer s.jobsMu.Unlock()
|
||||
|
||||
if _, exists := s.jobs[job.Name]; exists {
|
||||
return errors.Errorf("job with name %q already registered", job.Name)
|
||||
}
|
||||
|
||||
s.jobs[job.Name] = ®isteredJob{job: job}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start begins executing scheduled jobs.
|
||||
func (s *Scheduler) Start() error {
|
||||
s.runningMu.Lock()
|
||||
defer s.runningMu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return errors.New("scheduler already running")
|
||||
}
|
||||
|
||||
s.jobsMu.RLock()
|
||||
defer s.jobsMu.RUnlock()
|
||||
|
||||
// Parse and schedule all jobs
|
||||
for _, rj := range s.jobs {
|
||||
schedule, err := ParseCronExpression(rj.job.Schedule)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to parse schedule for job %q", rj.job.Name)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
rj.cancelFn = cancel
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.runJobWithSchedule(ctx, rj, schedule)
|
||||
}
|
||||
|
||||
s.running = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// runJobWithSchedule executes a job according to its cron schedule.
|
||||
func (s *Scheduler) runJobWithSchedule(ctx context.Context, rj *registeredJob, schedule *Schedule) {
|
||||
defer s.wg.Done()
|
||||
|
||||
// Apply middleware to handler
|
||||
handler := rj.job.Handler
|
||||
if s.middleware != nil {
|
||||
handler = s.middleware(handler)
|
||||
}
|
||||
|
||||
for {
|
||||
// Calculate next run time
|
||||
now := time.Now()
|
||||
if rj.job.Timezone != "" {
|
||||
loc, err := time.LoadLocation(rj.job.Timezone)
|
||||
if err == nil {
|
||||
now = now.In(loc)
|
||||
}
|
||||
} else if s.timezone != nil {
|
||||
now = now.In(s.timezone)
|
||||
}
|
||||
|
||||
next := schedule.Next(now)
|
||||
duration := time.Until(next)
|
||||
|
||||
timer := time.NewTimer(duration)
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
// Add job name to context and execute
|
||||
jobCtx := withJobName(ctx, rj.job.Name)
|
||||
if err := handler(jobCtx); err != nil {
|
||||
// Error already handled by middleware (if any)
|
||||
_ = err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// Stop the timer to prevent it from firing. The timer will be garbage collected.
|
||||
timer.Stop()
|
||||
return
|
||||
case <-s.stopCh:
|
||||
// Stop the timer to prevent it from firing. The timer will be garbage collected.
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the scheduler.
|
||||
// It waits for all running jobs to complete or until the context is canceled.
|
||||
func (s *Scheduler) Stop(ctx context.Context) error {
|
||||
s.runningMu.Lock()
|
||||
if !s.running {
|
||||
s.runningMu.Unlock()
|
||||
return errors.New("scheduler not running")
|
||||
}
|
||||
s.running = false
|
||||
s.runningMu.Unlock()
|
||||
|
||||
// Cancel all job contexts
|
||||
s.jobsMu.RLock()
|
||||
for _, rj := range s.jobs {
|
||||
if rj.cancelFn != nil {
|
||||
rj.cancelFn()
|
||||
}
|
||||
}
|
||||
s.jobsMu.RUnlock()
|
||||
|
||||
// Signal stop and wait for jobs to finish
|
||||
close(s.stopCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
165
plugin/scheduler/scheduler_test.go
Normal file
165
plugin/scheduler/scheduler_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSchedulerCreation(t *testing.T) {
|
||||
s := New()
|
||||
if s == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerWithTimezone(t *testing.T) {
|
||||
s := New(WithTimezone("America/New_York"))
|
||||
if s == nil {
|
||||
t.Fatal("New() with timezone returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobRegistration(t *testing.T) {
|
||||
s := New()
|
||||
|
||||
job := &Job{
|
||||
Name: "test-registration",
|
||||
Schedule: "0 * * * *",
|
||||
Handler: func(_ context.Context) error { return nil },
|
||||
}
|
||||
|
||||
if err := s.Register(job); err != nil {
|
||||
t.Fatalf("failed to register valid job: %v", err)
|
||||
}
|
||||
|
||||
// Registering duplicate name should fail
|
||||
if err := s.Register(job); err == nil {
|
||||
t.Error("expected error when registering duplicate job name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerStartStop(t *testing.T) {
|
||||
s := New()
|
||||
|
||||
var runCount atomic.Int32
|
||||
job := &Job{
|
||||
Name: "test-start-stop",
|
||||
Schedule: "* * * * * *", // Every second (6-field format)
|
||||
Handler: func(_ context.Context) error {
|
||||
runCount.Add(1)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.Register(job); err != nil {
|
||||
t.Fatalf("failed to register job: %v", err)
|
||||
}
|
||||
|
||||
// Start scheduler
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("failed to start scheduler: %v", err)
|
||||
}
|
||||
|
||||
// Let it run for 2.5 seconds
|
||||
time.Sleep(2500 * time.Millisecond)
|
||||
|
||||
// Stop scheduler
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Stop(ctx); err != nil {
|
||||
t.Fatalf("failed to stop scheduler: %v", err)
|
||||
}
|
||||
|
||||
count := runCount.Load()
|
||||
// Should have run at least twice (at 0s and 1s, maybe 2s)
|
||||
if count < 2 {
|
||||
t.Errorf("expected job to run at least 2 times, ran %d times", count)
|
||||
}
|
||||
|
||||
// Verify it stopped (count shouldn't increase)
|
||||
finalCount := count
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
if runCount.Load() != finalCount {
|
||||
t.Error("scheduler did not stop - job continued running")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerWithMiddleware wires Recovery and Logging middleware into a
// scheduler and verifies, via a capturing test logger, that a successful run
// of a once-per-second job emits both a "Job started" and a "Job completed"
// log entry.
func TestSchedulerWithMiddleware(t *testing.T) {
	// executionLog collects log lines from concurrent callbacks; logMu guards it.
	var executionLog []string
	var logMu sync.Mutex

	logger := &testLogger{
		onInfo: func(msg string, _ ...interface{}) {
			logMu.Lock()
			executionLog = append(executionLog, fmt.Sprintf("INFO: %s", msg))
			logMu.Unlock()
		},
		onError: func(msg string, _ ...interface{}) {
			logMu.Lock()
			executionLog = append(executionLog, fmt.Sprintf("ERROR: %s", msg))
			logMu.Unlock()
		},
	}

	// Recovery runs first in the chain so a panicking handler is captured too
	// (not expected in this test, but it exercises the middleware wiring).
	s := New(WithMiddleware(
		Recovery(func(jobName string, r interface{}) {
			logMu.Lock()
			executionLog = append(executionLog, fmt.Sprintf("PANIC: %s - %v", jobName, r))
			logMu.Unlock()
		}),
		Logging(logger),
	))

	job := &Job{
		Name:     "test-middleware",
		Schedule: "* * * * * *", // Every second
		Handler: func(_ context.Context) error {
			return nil
		},
	}

	if err := s.Register(job); err != nil {
		t.Fatalf("failed to register job: %v", err)
	}

	if err := s.Start(); err != nil {
		t.Fatalf("failed to start: %v", err)
	}

	// Sleep past at least one tick so the job runs once.
	time.Sleep(1500 * time.Millisecond)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := s.Stop(ctx); err != nil {
		t.Fatalf("failed to stop: %v", err)
	}

	logMu.Lock()
	defer logMu.Unlock()

	// Should have at least one start and one completion log
	hasStart := false
	hasCompletion := false
	for _, log := range executionLog {
		if strings.Contains(log, "Job started") {
			hasStart = true
		}
		if strings.Contains(log, "Job completed") {
			hasCompletion = true
		}
	}

	if !hasStart {
		t.Error("expected job start log")
	}
	if !hasCompletion {
		t.Error("expected job completion log")
	}
}
|
||||
118
plugin/storage/s3/s3.go
Normal file
118
plugin/storage/s3/s3.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// Client bundles an AWS S3 API client with the bucket it operates on.
// All object operations on this type target that single bucket.
type Client struct {
	// Client is the underlying AWS SDK v2 S3 client.
	Client *s3.Client
	// Bucket is the target bucket name, pre-wrapped as *string for the SDK's
	// input structs.
	Bucket *string
}
|
||||
|
||||
func NewClient(ctx context.Context, s3Config *storepb.StorageS3Config) (*Client, error) {
|
||||
cfg, err := config.LoadDefaultConfig(ctx,
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(s3Config.AccessKeyId, s3Config.AccessKeySecret, "")),
|
||||
config.WithRegion(s3Config.Region),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load s3 config")
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
|
||||
o.BaseEndpoint = aws.String(s3Config.Endpoint)
|
||||
o.UsePathStyle = s3Config.UsePathStyle
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
|
||||
})
|
||||
return &Client{
|
||||
Client: client,
|
||||
Bucket: aws.String(s3Config.Bucket),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UploadObject uploads an object to S3.
|
||||
func (c *Client) UploadObject(ctx context.Context, key string, fileType string, content io.Reader) (string, error) {
|
||||
uploader := manager.NewUploader(c.Client)
|
||||
putInput := s3.PutObjectInput{
|
||||
Bucket: c.Bucket,
|
||||
Key: aws.String(key),
|
||||
ContentType: aws.String(fileType),
|
||||
Body: content,
|
||||
}
|
||||
result, err := uploader.Upload(ctx, &putInput)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resultKey := result.Key
|
||||
if resultKey == nil || *resultKey == "" {
|
||||
return "", errors.New("failed to get file key")
|
||||
}
|
||||
return *resultKey, nil
|
||||
}
|
||||
|
||||
// PresignGetObject presigns an object in S3.
|
||||
func (c *Client) PresignGetObject(ctx context.Context, key string) (string, error) {
|
||||
presignClient := s3.NewPresignClient(c.Client)
|
||||
presignResult, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(*c.Bucket),
|
||||
Key: aws.String(key),
|
||||
}, func(opts *s3.PresignOptions) {
|
||||
// Set the expiration time of the presigned URL to 5 days.
|
||||
// Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html
|
||||
opts.Expires = time.Duration(5 * 24 * time.Hour)
|
||||
})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to presign put object")
|
||||
}
|
||||
return presignResult.URL, nil
|
||||
}
|
||||
|
||||
// GetObject retrieves an object from S3.
|
||||
func (c *Client) GetObject(ctx context.Context, key string) ([]byte, error) {
|
||||
downloader := manager.NewDownloader(c.Client)
|
||||
buffer := manager.NewWriteAtBuffer([]byte{})
|
||||
_, err := downloader.Download(ctx, buffer, &s3.GetObjectInput{
|
||||
Bucket: c.Bucket,
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to download object")
|
||||
}
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// GetObjectStream retrieves an object from S3 as a stream.
|
||||
func (c *Client) GetObjectStream(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
output, err := c.Client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: c.Bucket,
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get object")
|
||||
}
|
||||
return output.Body, nil
|
||||
}
|
||||
|
||||
// DeleteObject deletes an object in S3.
|
||||
func (c *Client) DeleteObject(ctx context.Context, key string) error {
|
||||
_, err := c.Client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: c.Bucket,
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to delete object")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
75
plugin/webhook/validate.go
Normal file
75
plugin/webhook/validate.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// reservedCIDRs lists IP ranges that must never be targeted by outbound webhook requests.
|
||||
// Covers loopback, RFC-1918 private, link-local (including cloud IMDS at 169.254.169.254),
|
||||
// and their IPv6 equivalents.
|
||||
var reservedCIDRs = []string{
|
||||
"127.0.0.0/8", // IPv4 loopback
|
||||
"10.0.0.0/8", // RFC-1918 class A
|
||||
"172.16.0.0/12", // RFC-1918 class B
|
||||
"192.168.0.0/16", // RFC-1918 class C
|
||||
"169.254.0.0/16", // Link-local / cloud IMDS
|
||||
"::1/128", // IPv6 loopback
|
||||
"fc00::/7", // IPv6 unique local
|
||||
"fe80::/10", // IPv6 link-local
|
||||
}
|
||||
|
||||
// reservedNetworks is the parsed form of reservedCIDRs, built once at startup.
|
||||
var reservedNetworks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
for _, cidr := range reservedCIDRs {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
panic("webhook: invalid reserved CIDR " + cidr + ": " + err.Error())
|
||||
}
|
||||
reservedNetworks = append(reservedNetworks, network)
|
||||
}
|
||||
}
|
||||
|
||||
// isReservedIP reports whether ip falls within any reserved/private range.
|
||||
func isReservedIP(ip net.IP) bool {
|
||||
for _, network := range reservedNetworks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateURL checks that rawURL:
|
||||
// 1. Parses as a valid absolute URL.
|
||||
// 2. Uses the http or https scheme.
|
||||
// 3. Does not resolve to a reserved/private IP address.
|
||||
//
|
||||
// It returns a gRPC InvalidArgument status error so callers can return it directly.
|
||||
func ValidateURL(rawURL string) error {
|
||||
u, err := url.ParseRequestURI(rawURL)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid webhook URL: %v", err)
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return status.Errorf(codes.InvalidArgument, "webhook URL must use http or https scheme, got %q", u.Scheme)
|
||||
}
|
||||
|
||||
ips, err := net.LookupHost(u.Hostname())
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "webhook URL hostname could not be resolved: %v", err)
|
||||
}
|
||||
|
||||
for _, ipStr := range ips {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip != nil && isReservedIP(ip) {
|
||||
return status.Errorf(codes.InvalidArgument, "webhook URL must not resolve to a reserved or private IP address")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
120
plugin/webhook/webhook.go
Normal file
120
plugin/webhook/webhook.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
var (
	// timeout is the timeout for webhook request. Default to 30 seconds.
	timeout = 30 * time.Second

	// safeClient is the shared HTTP client used for all webhook dispatches.
	// Its Transport guards against SSRF by blocking connections to reserved/private
	// IP addresses at dial time, which also defeats DNS rebinding attacks.
	//
	// NOTE(review): a bare &http.Transport{} does not inherit
	// http.DefaultTransport's settings (proxy from environment, idle-connection
	// pooling limits, TLS/handshake timeouts) — confirm that is intentional.
	safeClient = &http.Client{
		Timeout: timeout,
		Transport: &http.Transport{
			DialContext: safeDialContext,
		},
	}
)
|
||||
|
||||
// safeDialContext is a net.Dialer.DialContext replacement that resolves the target
|
||||
// hostname and rejects any address that falls within a reserved/private IP range.
|
||||
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("webhook: invalid address %q", addr)
|
||||
}
|
||||
|
||||
ips, err := net.DefaultResolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "webhook: failed to resolve host %q", host)
|
||||
}
|
||||
|
||||
for _, ipStr := range ips {
|
||||
if ip := net.ParseIP(ipStr); ip != nil && isReservedIP(ip) {
|
||||
return nil, errors.Errorf("webhook: connection to reserved/private IP address is not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
return (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(host, port))
|
||||
}
|
||||
|
||||
// WebhookRequestPayload is the JSON body sent to a webhook endpoint by Post.
// It is serialized verbatim with json.Marshal, so field tags define the wire
// format.
type WebhookRequestPayload struct {
	// The target URL for the webhook request.
	URL string `json:"url"`
	// The type of activity that triggered this webhook.
	ActivityType string `json:"activityType"`
	// The resource name of the creator. Format: users/{user}
	Creator string `json:"creator"`
	// The memo that triggered this webhook (if applicable). May be nil, in
	// which case it is serialized as JSON null.
	Memo *v1pb.Memo `json:"memo"`
}
|
||||
|
||||
// Post posts the message to webhook endpoint.
|
||||
func Post(requestPayload *WebhookRequestPayload) error {
|
||||
body, err := json.Marshal(requestPayload)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to marshal webhook request to %s", requestPayload.URL)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", requestPayload.URL, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to construct webhook request to %s", requestPayload.URL)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := safeClient.Do(req)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to post webhook to %s", requestPayload.URL)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read webhook response from %s", requestPayload.URL)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return errors.Errorf("failed to post webhook %s, status code: %d", requestPayload.URL, resp.StatusCode)
|
||||
}
|
||||
|
||||
response := &struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, response); err != nil {
|
||||
return errors.Wrapf(err, "failed to unmarshal webhook response from %s", requestPayload.URL)
|
||||
}
|
||||
|
||||
if response.Code != 0 {
|
||||
return errors.Errorf("receive error code sent by webhook server, code %d, msg: %s", response.Code, response.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostAsync posts the message to webhook endpoint asynchronously.
|
||||
// It spawns a new goroutine to handle the request and does not wait for the response.
|
||||
func PostAsync(requestPayload *WebhookRequestPayload) {
|
||||
go func() {
|
||||
if err := Post(requestPayload); err != nil {
|
||||
slog.Warn("Failed to dispatch webhook asynchronously",
|
||||
slog.String("url", requestPayload.URL),
|
||||
slog.String("activityType", requestPayload.ActivityType),
|
||||
slog.Any("err", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
1
plugin/webhook/webhook_test.go
Normal file
1
plugin/webhook/webhook_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package webhook
|
||||
Reference in New Issue
Block a user