first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled

This commit is contained in:
2026-03-04 06:30:47 +00:00
commit bb402d4ccc
777 changed files with 135661 additions and 0 deletions

110
plugin/ai/groq/groq.go Normal file
View File

@@ -0,0 +1,110 @@
package ai
import (
"context"
"fmt"
"github.com/sashabaranov/go-openai"
)
type GroqClient struct {
client *openai.Client
config GroqConfig
}
type GroqConfig struct {
APIKey string
BaseURL string
DefaultModel string
}
type CompletionRequest struct {
Model string
Prompt string
Temperature float32
MaxTokens int
}
type CompletionResponse struct {
Text string
Model string
PromptTokens int
CompletionTokens int
TotalTokens int
}
func NewGroqClient(config GroqConfig) *GroqClient {
if config.BaseURL == "" {
config.BaseURL = "https://api.groq.com/openai/v1"
}
clientConfig := openai.DefaultConfig(config.APIKey)
clientConfig.BaseURL = config.BaseURL
return &GroqClient{
client: openai.NewClientWithConfig(clientConfig),
config: config,
}
}
func (g *GroqClient) TestConnection(ctx context.Context) error {
// Test by listing available models
_, err := g.client.ListModels(ctx)
if err != nil {
return fmt.Errorf("failed to connect to Groq API: %w", err)
}
return nil
}
func (g *GroqClient) GenerateCompletion(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
model := req.Model
if model == "" {
model = g.config.DefaultModel
}
if model == "" {
model = "llama-3.1-8b-instant"
}
chatReq := openai.ChatCompletionRequest{
Model: model,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: req.Prompt,
},
},
Temperature: req.Temperature,
MaxTokens: req.MaxTokens,
}
resp, err := g.client.CreateChatCompletion(ctx, chatReq)
if err != nil {
return nil, fmt.Errorf("failed to generate completion: %w", err)
}
if len(resp.Choices) == 0 {
return nil, fmt.Errorf("no completion choices returned")
}
return &CompletionResponse{
Text: resp.Choices[0].Message.Content,
Model: resp.Model,
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}, nil
}
func (g *GroqClient) ListModels(ctx context.Context) ([]string, error) {
models, err := g.client.ListModels(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list models: %w", err)
}
var modelNames []string
for _, model := range models.Models {
modelNames = append(modelNames, model.ID)
}
return modelNames, nil
}

199
plugin/ai/ollama/ollama.go Normal file
View File

@@ -0,0 +1,199 @@
package ai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type OllamaClient struct {
httpClient *http.Client
config OllamaConfig
}
type OllamaConfig struct {
Host string
DefaultModel string
Timeout time.Duration
}
// CompletionRequest is the input for generating completions.
type CompletionRequest struct {
Model string
Prompt string
Temperature float32
MaxTokens int
}
// CompletionResponse is the output from generating completions.
type CompletionResponse struct {
Text string
Model string
PromptTokens int
CompletionTokens int
TotalTokens int
}
type OllamaGenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Options map[string]interface{} `json:"options,omitempty"`
}
type OllamaGenerateResponse struct {
Model string `json:"model"`
Response string `json:"response"`
Done bool `json:"done"`
Context []int `json:"context,omitempty"`
TotalDuration int64 `json:"total_duration,omitempty"`
LoadDuration int64 `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int64 `json:"eval_duration,omitempty"`
}
type OllamaTagsResponse struct {
Models []struct {
Name string `json:"name"`
Model string `json:"model"`
Size int64 `json:"size"`
Digest string `json:"digest"`
} `json:"models"`
}
func NewOllamaClient(config OllamaConfig) *OllamaClient {
if config.Timeout == 0 {
config.Timeout = 120 * time.Second // Increase to 2 minutes for generation
}
return &OllamaClient{
httpClient: &http.Client{
Timeout: config.Timeout,
},
config: config,
}
}
func (o *OllamaClient) TestConnection(ctx context.Context) error {
url := fmt.Sprintf("%s/api/tags", o.config.Host)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := o.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to connect to Ollama: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
}
return nil
}
func (o *OllamaClient) GenerateCompletion(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
model := req.Model
if model == "" {
model = o.config.DefaultModel
}
if model == "" {
model = "llama3"
}
ollamaReq := OllamaGenerateRequest{
Model: model,
Prompt: req.Prompt,
Stream: false,
Options: map[string]interface{}{
"temperature": req.Temperature,
},
}
if req.MaxTokens > 0 {
ollamaReq.Options["num_predict"] = req.MaxTokens
}
jsonData, err := json.Marshal(ollamaReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
url := fmt.Sprintf("%s/api/generate", o.config.Host)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient.Do(httpReq)
if err != nil {
// Check if it's a timeout error
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("Ollama request timed out after %.0f seconds. Try reducing the max tokens or using a smaller model", o.config.Timeout.Seconds())
}
return nil, fmt.Errorf("failed to call Ollama API: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
}
var ollamaResp OllamaGenerateResponse
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return &CompletionResponse{
Text: ollamaResp.Response,
Model: ollamaResp.Model,
PromptTokens: ollamaResp.PromptEvalCount,
CompletionTokens: ollamaResp.EvalCount,
TotalTokens: ollamaResp.PromptEvalCount + ollamaResp.EvalCount,
}, nil
}
func (o *OllamaClient) ListModels(ctx context.Context) ([]string, error) {
url := fmt.Sprintf("%s/api/tags", o.config.Host)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get models: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("Ollama returned status %d: %s", resp.StatusCode, string(body))
}
var tagsResp OllamaTagsResponse
if err := json.NewDecoder(resp.Body).Decode(&tagsResp); err != nil {
return nil, fmt.Errorf("failed to decode models response: %w", err)
}
var modelNames []string
for _, model := range tagsResp.Models {
modelNames = append(modelNames, model.Name)
}
return modelNames, nil
}

1
plugin/cron/README.md Normal file
View File

@@ -0,0 +1 @@
Fork from https://github.com/robfig/cron

96
plugin/cron/chain.go Normal file
View File

@@ -0,0 +1,96 @@
package cron
import (
"errors"
"fmt"
"runtime"
"sync"
"time"
)
// JobWrapper decorates the given Job with some behavior.
type JobWrapper func(Job) Job
// Chain is a sequence of JobWrappers that decorates submitted jobs with
// cross-cutting behaviors like logging or synchronization.
type Chain struct {
wrappers []JobWrapper
}
// NewChain returns a Chain consisting of the given JobWrappers.
func NewChain(c ...JobWrapper) Chain {
return Chain{c}
}
// Then decorates the given job with all JobWrappers in the chain.
//
// This:
//
// NewChain(m1, m2, m3).Then(job)
//
// is equivalent to:
//
// m1(m2(m3(job)))
func (c Chain) Then(j Job) Job {
for i := range c.wrappers {
j = c.wrappers[len(c.wrappers)-i-1](j)
}
return j
}
// Recover panics in wrapped jobs and log them with the provided logger.
func Recover(logger Logger) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
defer func() {
if r := recover(); r != nil {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
err, ok := r.(error)
if !ok {
err = errors.New("panic: " + fmt.Sprint(r))
}
logger.Error(err, "panic", "stack", "...\n"+string(buf))
}
}()
j.Run()
})
}
}
// DelayIfStillRunning serializes jobs, delaying subsequent runs until the
// previous one is complete. Jobs running after a delay of more than a minute
// have the delay logged at Info.
func DelayIfStillRunning(logger Logger) JobWrapper {
return func(j Job) Job {
var mu sync.Mutex
return FuncJob(func() {
start := time.Now()
mu.Lock()
defer mu.Unlock()
if dur := time.Since(start); dur > time.Minute {
logger.Info("delay", "duration", dur)
}
j.Run()
})
}
}
// SkipIfStillRunning skips an invocation of the Job if a previous invocation is
// still running. It logs skips to the given logger at Info level.
func SkipIfStillRunning(logger Logger) JobWrapper {
return func(j Job) Job {
var ch = make(chan struct{}, 1)
ch <- struct{}{}
return FuncJob(func() {
select {
case v := <-ch:
defer func() { ch <- v }()
j.Run()
default:
logger.Info("skip")
}
})
}
}

239
plugin/cron/chain_test.go Normal file
View File

@@ -0,0 +1,239 @@
//nolint:all
package cron
import (
"io"
"log"
"reflect"
"sync"
"testing"
"time"
)
func appendingJob(slice *[]int, value int) Job {
var m sync.Mutex
return FuncJob(func() {
m.Lock()
*slice = append(*slice, value)
m.Unlock()
})
}
func appendingWrapper(slice *[]int, value int) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
appendingJob(slice, value).Run()
j.Run()
})
}
}
func TestChain(t *testing.T) {
var nums []int
var (
append1 = appendingWrapper(&nums, 1)
append2 = appendingWrapper(&nums, 2)
append3 = appendingWrapper(&nums, 3)
append4 = appendingJob(&nums, 4)
)
NewChain(append1, append2, append3).Then(append4).Run()
if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) {
t.Error("unexpected order of calls:", nums)
}
}
func TestChainRecover(t *testing.T) {
panickingJob := FuncJob(func() {
panic("panickingJob panics")
})
t.Run("panic exits job by default", func(*testing.T) {
defer func() {
if err := recover(); err == nil {
t.Errorf("panic expected, but none received")
}
}()
NewChain().Then(panickingJob).
Run()
})
t.Run("Recovering JobWrapper recovers", func(*testing.T) {
NewChain(Recover(PrintfLogger(log.New(io.Discard, "", 0)))).
Then(panickingJob).
Run()
})
t.Run("composed with the *IfStillRunning wrappers", func(*testing.T) {
NewChain(Recover(PrintfLogger(log.New(io.Discard, "", 0)))).
Then(panickingJob).
Run()
})
}
type countJob struct {
m sync.Mutex
started int
done int
delay time.Duration
}
func (j *countJob) Run() {
j.m.Lock()
j.started++
j.m.Unlock()
time.Sleep(j.delay)
j.m.Lock()
j.done++
j.m.Unlock()
}
func (j *countJob) Started() int {
defer j.m.Unlock()
j.m.Lock()
return j.started
}
func (j *countJob) Done() int {
defer j.m.Unlock()
j.m.Lock()
return j.done
}
func TestChainDelayIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(*testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
}
})
t.Run("second run immediate if first done", func(*testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(time.Millisecond)
go wrappedJob.Run()
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
t.Errorf("expected job run twice, immediately, got %d", c)
}
})
t.Run("second run delayed if first not done", func(*testing.T) {
var j countJob
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(time.Millisecond)
go wrappedJob.Run()
}()
// After 5ms, the first job is still in progress, and the second job was
// run but should be waiting for it to finish.
time.Sleep(5 * time.Millisecond)
started, done := j.Started(), j.Done()
if started != 1 || done != 0 {
t.Error("expected first job started, but not finished, got", started, done)
}
// Verify that the second job completes.
time.Sleep(25 * time.Millisecond)
started, done = j.Started(), j.Done()
if started != 2 || done != 2 {
t.Error("expected both jobs done, got", started, done)
}
})
}
func TestChainSkipIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(*testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
}
})
t.Run("second run immediate if first done", func(*testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(time.Millisecond)
go wrappedJob.Run()
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
t.Errorf("expected job run twice, immediately, got %d", c)
}
})
t.Run("second run skipped if first not done", func(*testing.T) {
var j countJob
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(time.Millisecond)
go wrappedJob.Run()
}()
// After 5ms, the first job is still in progress, and the second job was
// already skipped.
time.Sleep(5 * time.Millisecond)
started, done := j.Started(), j.Done()
if started != 1 || done != 0 {
t.Error("expected first job started, but not finished, got", started, done)
}
// Verify that the first job completes and second does not run.
time.Sleep(25 * time.Millisecond)
started, done = j.Started(), j.Done()
if started != 1 || done != 1 {
t.Error("expected second job skipped, got", started, done)
}
})
t.Run("skip 10 jobs on rapid fire", func(*testing.T) {
var j countJob
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
for i := 0; i < 11; i++ {
go wrappedJob.Run()
}
time.Sleep(200 * time.Millisecond)
done := j.Done()
if done != 1 {
t.Error("expected 1 jobs executed, 10 jobs dropped, got", done)
}
})
t.Run("different jobs independent", func(*testing.T) {
var j1, j2 countJob
j1.delay = 10 * time.Millisecond
j2.delay = 10 * time.Millisecond
chain := NewChain(SkipIfStillRunning(DiscardLogger))
wrappedJob1 := chain.Then(&j1)
wrappedJob2 := chain.Then(&j2)
for i := 0; i < 11; i++ {
go wrappedJob1.Run()
go wrappedJob2.Run()
}
time.Sleep(100 * time.Millisecond)
var (
done1 = j1.Done()
done2 = j2.Done()
)
if done1 != 1 || done2 != 1 {
t.Error("expected both jobs executed once, got", done1, "and", done2)
}
})
}

View File

@@ -0,0 +1,27 @@
package cron
import "time"
// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes".
// It does not support jobs more frequent than once a second.
type ConstantDelaySchedule struct {
Delay time.Duration
}
// Every returns a crontab Schedule that activates once every duration.
// Delays of less than a second are not supported (will round up to 1 second).
// Any fields less than a Second are truncated.
func Every(duration time.Duration) ConstantDelaySchedule {
if duration < time.Second {
duration = time.Second
}
return ConstantDelaySchedule{
Delay: duration - time.Duration(duration.Nanoseconds())%time.Second,
}
}
// Next returns the next time this should be run.
// This rounds so that the next activation time will be on the second.
func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time {
return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond)
}

View File

@@ -0,0 +1,55 @@
//nolint:all
package cron
import (
"testing"
"time"
)
func TestConstantDelayNext(t *testing.T) {
tests := []struct {
time string
delay time.Duration
expected string
}{
// Simple cases
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
{"Mon Jul 9 14:59 2012", 15 * time.Minute, "Mon Jul 9 15:14 2012"},
{"Mon Jul 9 14:59:59 2012", 15 * time.Minute, "Mon Jul 9 15:14:59 2012"},
// Wrap around hours
{"Mon Jul 9 15:45 2012", 35 * time.Minute, "Mon Jul 9 16:20 2012"},
// Wrap around days
{"Mon Jul 9 23:46 2012", 14 * time.Minute, "Tue Jul 10 00:00 2012"},
{"Mon Jul 9 23:45 2012", 35 * time.Minute, "Tue Jul 10 00:20 2012"},
{"Mon Jul 9 23:35:51 2012", 44*time.Minute + 24*time.Second, "Tue Jul 10 00:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", 25*time.Hour + 44*time.Minute + 24*time.Second, "Thu Jul 11 01:20:15 2012"},
// Wrap around months
{"Mon Jul 9 23:35 2012", 91*24*time.Hour + 25*time.Minute, "Thu Oct 9 00:00 2012"},
// Wrap around minute, hour, day, month, and year
{"Mon Dec 31 23:59:45 2012", 15 * time.Second, "Tue Jan 1 00:00:00 2013"},
// Round to nearest second on the delay
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
// Round up to 1 second if the duration is less.
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:01 2012"},
// Round to nearest second when calculating the next time.
{"Mon Jul 9 14:45:00.005 2012", 15 * time.Minute, "Mon Jul 9 15:00 2012"},
// Round to nearest second for both.
{"Mon Jul 9 14:45:00.005 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
}
for _, c := range tests {
actual := Every(c.delay).Next(getTime(c.time))
expected := getTime(c.expected)
if actual != expected {
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.delay, expected, actual)
}
}
}

353
plugin/cron/cron.go Normal file
View File

@@ -0,0 +1,353 @@
package cron
import (
"context"
"sort"
"sync"
"time"
)
// Cron keeps track of any number of entries, invoking the associated func as
// specified by the schedule. It may be started, stopped, and the entries may
// be inspected while running.
type Cron struct {
entries []*Entry
chain Chain
stop chan struct{}
add chan *Entry
remove chan EntryID
snapshot chan chan []Entry
running bool
logger Logger
runningMu sync.Mutex
location *time.Location
parser ScheduleParser
nextID EntryID
jobWaiter sync.WaitGroup
}
// ScheduleParser is an interface for schedule spec parsers that return a Schedule.
type ScheduleParser interface {
Parse(spec string) (Schedule, error)
}
// Job is an interface for submitted cron jobs.
type Job interface {
Run()
}
// Schedule describes a job's duty cycle.
type Schedule interface {
// Next returns the next activation time, later than the given time.
// Next is invoked initially, and then each time the job is run.
Next(time.Time) time.Time
}
// EntryID identifies an entry within a Cron instance.
type EntryID int
// Entry consists of a schedule and the func to execute on that schedule.
type Entry struct {
// ID is the cron-assigned ID of this entry, which may be used to look up a
// snapshot or remove it.
ID EntryID
// Schedule on which this job should be run.
Schedule Schedule
// Next time the job will run, or the zero time if Cron has not been
// started or this entry's schedule is unsatisfiable
Next time.Time
// Prev is the last time this job was run, or the zero time if never.
Prev time.Time
// WrappedJob is the thing to run when the Schedule is activated.
WrappedJob Job
// Job is the thing that was submitted to cron.
// It is kept around so that user code that needs to get at the job later,
// e.g. via Entries() can do so.
Job Job
}
// Valid returns true if this is not the zero entry.
func (e Entry) Valid() bool { return e.ID != 0 }
// byTime is a wrapper for sorting the entry array by time
// (with zero time at the end).
type byTime []*Entry
func (s byTime) Len() int { return len(s) }
func (s byTime) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s byTime) Less(i, j int) bool {
// Two zero times should return false.
// Otherwise, zero is "greater" than any other time.
// (To sort it at the end of the list.)
if s[i].Next.IsZero() {
return false
}
if s[j].Next.IsZero() {
return true
}
return s[i].Next.Before(s[j].Next)
}
// New returns a new Cron job runner, modified by the given options.
//
// Available Settings
//
// Time Zone
// Description: The time zone in which schedules are interpreted
// Default: time.Local
//
// Parser
// Description: Parser converts cron spec strings into cron.Schedules.
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
//
// Chain
// Description: Wrap submitted jobs to customize behavior.
// Default: A chain that recovers panics and logs them to stderr.
//
// See "cron.With*" to modify the default behavior.
func New(opts ...Option) *Cron {
c := &Cron{
entries: nil,
chain: NewChain(),
add: make(chan *Entry),
stop: make(chan struct{}),
snapshot: make(chan chan []Entry),
remove: make(chan EntryID),
running: false,
runningMu: sync.Mutex{},
logger: DefaultLogger,
location: time.Local,
parser: standardParser,
}
for _, opt := range opts {
opt(c)
}
return c
}
// FuncJob is a wrapper that turns a func() into a cron.Job.
type FuncJob func()
func (f FuncJob) Run() { f() }
// AddFunc adds a func to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) {
return c.AddJob(spec, FuncJob(cmd))
}
// AddJob adds a Job to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddJob(spec string, cmd Job) (EntryID, error) {
schedule, err := c.parser.Parse(spec)
if err != nil {
return 0, err
}
return c.Schedule(schedule, cmd), nil
}
// Schedule adds a Job to the Cron to be run on the given schedule.
// The job is wrapped with the configured Chain.
func (c *Cron) Schedule(schedule Schedule, cmd Job) EntryID {
c.runningMu.Lock()
defer c.runningMu.Unlock()
c.nextID++
entry := &Entry{
ID: c.nextID,
Schedule: schedule,
WrappedJob: c.chain.Then(cmd),
Job: cmd,
}
if !c.running {
c.entries = append(c.entries, entry)
} else {
c.add <- entry
}
return entry.ID
}
// Entries returns a snapshot of the cron entries.
func (c *Cron) Entries() []Entry {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
replyChan := make(chan []Entry, 1)
c.snapshot <- replyChan
return <-replyChan
}
return c.entrySnapshot()
}
// Location gets the time zone location.
func (c *Cron) Location() *time.Location {
return c.location
}
// Entry returns a snapshot of the given entry, or nil if it couldn't be found.
func (c *Cron) Entry(id EntryID) Entry {
for _, entry := range c.Entries() {
if id == entry.ID {
return entry
}
}
return Entry{}
}
// Remove an entry from being run in the future.
func (c *Cron) Remove(id EntryID) {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
c.remove <- id
} else {
c.removeEntry(id)
}
}
// Start the cron scheduler in its own goroutine, or no-op if already started.
func (c *Cron) Start() {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
return
}
c.running = true
go c.runScheduler()
}
// Run the cron scheduler, or no-op if already running.
func (c *Cron) Run() {
c.runningMu.Lock()
if c.running {
c.runningMu.Unlock()
return
}
c.running = true
c.runningMu.Unlock()
c.runScheduler()
}
// runScheduler runs the scheduler.. this is private just due to the need to synchronize
// access to the 'running' state variable.
func (c *Cron) runScheduler() {
c.logger.Info("start")
// Figure out the next activation times for each entry.
now := c.now()
for _, entry := range c.entries {
entry.Next = entry.Schedule.Next(now)
c.logger.Info("schedule", "now", now, "entry", entry.ID, "next", entry.Next)
}
for {
// Determine the next entry to run.
sort.Sort(byTime(c.entries))
var timer *time.Timer
if len(c.entries) == 0 || c.entries[0].Next.IsZero() {
// If there are no entries yet, just sleep - it still handles new entries
// and stop requests.
timer = time.NewTimer(100000 * time.Hour)
} else {
timer = time.NewTimer(c.entries[0].Next.Sub(now))
}
for {
select {
case now = <-timer.C:
now = now.In(c.location)
c.logger.Info("wake", "now", now)
// Run every entry whose next time was less than now
for _, e := range c.entries {
if e.Next.After(now) || e.Next.IsZero() {
break
}
c.startJob(e.WrappedJob)
e.Prev = e.Next
e.Next = e.Schedule.Next(now)
c.logger.Info("run", "now", now, "entry", e.ID, "next", e.Next)
}
case newEntry := <-c.add:
timer.Stop()
now = c.now()
newEntry.Next = newEntry.Schedule.Next(now)
c.entries = append(c.entries, newEntry)
c.logger.Info("added", "now", now, "entry", newEntry.ID, "next", newEntry.Next)
case replyChan := <-c.snapshot:
replyChan <- c.entrySnapshot()
continue
case <-c.stop:
timer.Stop()
c.logger.Info("stop")
return
case id := <-c.remove:
timer.Stop()
now = c.now()
c.removeEntry(id)
c.logger.Info("removed", "entry", id)
}
break
}
}
}
// startJob runs the given job in a new goroutine.
func (c *Cron) startJob(j Job) {
c.jobWaiter.Go(func() {
j.Run()
})
}
// now returns current time in c location.
func (c *Cron) now() time.Time {
return time.Now().In(c.location)
}
// Stop stops the cron scheduler if it is running; otherwise it does nothing.
// A context is returned so the caller can wait for running jobs to complete.
func (c *Cron) Stop() context.Context {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
c.stop <- struct{}{}
c.running = false
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
c.jobWaiter.Wait()
cancel()
}()
return ctx
}
// entrySnapshot returns a copy of the current cron entry list.
func (c *Cron) entrySnapshot() []Entry {
var entries = make([]Entry, len(c.entries))
for i, e := range c.entries {
entries[i] = *e
}
return entries
}
func (c *Cron) removeEntry(id EntryID) {
var entries []*Entry
for _, e := range c.entries {
if e.ID != id {
entries = append(entries, e)
}
}
c.entries = entries
}

702
plugin/cron/cron_test.go Normal file
View File

@@ -0,0 +1,702 @@
//nolint:all
package cron
import (
"bytes"
"fmt"
"log"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// Many tests schedule a job for every second, and then wait at most a second
// for it to run. This amount is just slightly larger than 1 second to
// compensate for a few milliseconds of runtime.
const OneSecond = 1*time.Second + 50*time.Millisecond
type syncWriter struct {
wr bytes.Buffer
m sync.Mutex
}
func (sw *syncWriter) Write(data []byte) (n int, err error) {
sw.m.Lock()
n, err = sw.wr.Write(data)
sw.m.Unlock()
return
}
func (sw *syncWriter) String() string {
sw.m.Lock()
defer sw.m.Unlock()
return sw.wr.String()
}
func newBufLogger(sw *syncWriter) Logger {
return PrintfLogger(log.New(sw, "", log.LstdFlags))
}
func TestFuncPanicRecovery(t *testing.T) {
var buf syncWriter
cron := New(WithParser(secondParser),
WithChain(Recover(newBufLogger(&buf))))
cron.Start()
defer cron.Stop()
cron.AddFunc("* * * * * ?", func() {
panic("YOLO")
})
select {
case <-time.After(OneSecond):
if !strings.Contains(buf.String(), "YOLO") {
t.Error("expected a panic to be logged, got none")
}
return
}
}
type DummyJob struct{}
func (DummyJob) Run() {
panic("YOLO")
}
func TestJobPanicRecovery(t *testing.T) {
var job DummyJob
var buf syncWriter
cron := New(WithParser(secondParser),
WithChain(Recover(newBufLogger(&buf))))
cron.Start()
defer cron.Stop()
cron.AddJob("* * * * * ?", job)
select {
case <-time.After(OneSecond):
if !strings.Contains(buf.String(), "YOLO") {
t.Error("expected a panic to be logged, got none")
}
return
}
}
// Start and stop cron with no entries.
func TestNoEntries(t *testing.T) {
cron := newWithSeconds()
cron.Start()
select {
case <-time.After(OneSecond):
t.Fatal("expected cron will be stopped immediately")
case <-stop(cron):
}
}
// Start, stop, then add an entry. Verify entry doesn't run.
func TestStopCausesJobsToNotRun(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.Start()
cron.Stop()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
select {
case <-time.After(OneSecond):
// No job ran!
case <-wait(wg):
t.Fatal("expected stopped cron does not run any job")
}
}
// Add a job, start cron, expect it runs.
func TestAddBeforeRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Start()
defer cron.Stop()
// Give cron 2 seconds to run our job (which is always activated).
select {
case <-time.After(OneSecond):
t.Fatal("expected job runs")
case <-wait(wg):
}
}
// Start cron, add a job, expect it runs.
func TestAddWhileRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.Start()
defer cron.Stop()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
select {
case <-time.After(OneSecond):
t.Fatal("expected job runs")
case <-wait(wg):
}
}
// Test for #34. Adding a job after calling start results in multiple job invocations
func TestAddWhileRunningWithDelay(t *testing.T) {
cron := newWithSeconds()
cron.Start()
defer cron.Stop()
time.Sleep(5 * time.Second)
var calls int64
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
<-time.After(OneSecond)
if atomic.LoadInt64(&calls) != 1 {
t.Errorf("called %d times, expected 1\n", calls)
}
}
// Add a job, remove a job, start cron, expect nothing runs.
func TestRemoveBeforeRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Remove(id)
cron.Start()
defer cron.Stop()
select {
case <-time.After(OneSecond):
// Success, shouldn't run
case <-wait(wg):
t.FailNow()
}
}
// Start cron, add a job, remove it, expect it doesn't run.
func TestRemoveWhileRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.Start()
defer cron.Stop()
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Remove(id)
select {
case <-time.After(OneSecond):
case <-wait(wg):
t.FailNow()
}
}
// Test timing with Entries.
func TestSnapshotEntries(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := New()
cron.AddFunc("@every 2s", func() { wg.Done() })
cron.Start()
defer cron.Stop()
// Cron should fire in 2 seconds. After 1 second, call Entries.
select {
case <-time.After(OneSecond):
cron.Entries()
}
// Even though Entries was called, the cron should fire at the 2 second mark.
select {
case <-time.After(OneSecond):
t.Error("expected job runs at 2 second mark")
case <-wait(wg):
}
}
// Test that the entries are correctly sorted.
// Add a bunch of long-in-the-future entries, and an immediate entry, and ensure
// that the immediate entry runs immediately.
// Also: Test that multiple jobs run in the same instant.
func TestMultipleEntries(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
id1, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
id2, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
cron.AddFunc("0 0 0 31 12 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Remove(id1)
cron.Start()
cron.Remove(id2)
defer cron.Stop()
select {
case <-time.After(OneSecond):
t.Error("expected job run in proper order")
case <-wait(wg):
}
}
// Test running the same job twice.
func TestRunningJobTwice(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("0 0 0 31 12 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Start()
defer cron.Stop()
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
case <-wait(wg):
}
}
func TestRunningMultipleSchedules(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("0 0 0 31 12 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Schedule(Every(time.Minute), FuncJob(func() {}))
cron.Schedule(Every(time.Second), FuncJob(func() { wg.Done() }))
cron.Schedule(Every(time.Hour), FuncJob(func() {}))
cron.Start()
defer cron.Stop()
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
case <-wait(wg):
}
}
// Test that the cron is run in the local time zone (as opposed to UTC).
func TestLocalTimezone(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
now := time.Now()
// FIX: Issue #205
// This calculation doesn't work in seconds 58 or 59.
// Take the easy way out and sleep.
if now.Second() >= 58 {
time.Sleep(2 * time.Second)
now = time.Now()
}
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
cron := newWithSeconds()
cron.AddFunc(spec, func() { wg.Done() })
cron.Start()
defer cron.Stop()
select {
case <-time.After(OneSecond * 2):
t.Error("expected job fires 2 times")
case <-wait(wg):
}
}
// Test that the cron is run in the given time zone (as opposed to local).
func TestNonLocalTimezone(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
loc, err := time.LoadLocation("Atlantic/Cape_Verde")
if err != nil {
fmt.Printf("Failed to load time zone Atlantic/Cape_Verde: %+v", err)
t.Fail()
}
now := time.Now().In(loc)
// FIX: Issue #205
// This calculation doesn't work in seconds 58 or 59.
// Take the easy way out and sleep.
if now.Second() >= 58 {
time.Sleep(2 * time.Second)
now = time.Now().In(loc)
}
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
cron := New(WithLocation(loc), WithParser(secondParser))
cron.AddFunc(spec, func() { wg.Done() })
cron.Start()
defer cron.Stop()
select {
case <-time.After(OneSecond * 2):
t.Error("expected job fires 2 times")
case <-wait(wg):
}
}
// Test that calling stop before start silently returns without
// blocking the stop channel.
func TestStopWithoutStart(t *testing.T) {
cron := New()
cron.Stop()
}
type testJob struct {
wg *sync.WaitGroup
name string
}
func (t testJob) Run() {
t.wg.Done()
}
// Test that adding an invalid job spec returns an error
func TestInvalidJobSpec(t *testing.T) {
cron := New()
_, err := cron.AddJob("this will not parse", nil)
if err == nil {
t.Errorf("expected an error with invalid spec, got nil")
}
}
// Test blocking run method behaves as Start()
func TestBlockingRun(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
var unblockChan = make(chan struct{})
go func() {
cron.Run()
close(unblockChan)
}()
defer cron.Stop()
select {
case <-time.After(OneSecond):
t.Error("expected job fires")
case <-unblockChan:
t.Error("expected that Run() blocks")
case <-wait(wg):
}
}
// Test that double-running is a no-op
func TestStartNoop(t *testing.T) {
var tickChan = make(chan struct{}, 2)
cron := newWithSeconds()
cron.AddFunc("* * * * * ?", func() {
tickChan <- struct{}{}
})
cron.Start()
defer cron.Stop()
// Wait for the first firing to ensure the runner is going
<-tickChan
cron.Start()
<-tickChan
// Fail if this job fires again in a short period, indicating a double-run
select {
case <-time.After(time.Millisecond):
case <-tickChan:
t.Error("expected job fires exactly twice")
}
}
// Simple test using Runnables.
func TestJob(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.AddJob("0 0 0 30 Feb ?", testJob{wg, "job0"})
cron.AddJob("0 0 0 1 1 ?", testJob{wg, "job1"})
job2, _ := cron.AddJob("* * * * * ?", testJob{wg, "job2"})
cron.AddJob("1 0 0 1 1 ?", testJob{wg, "job3"})
cron.Schedule(Every(5*time.Second+5*time.Nanosecond), testJob{wg, "job4"})
job5 := cron.Schedule(Every(5*time.Minute), testJob{wg, "job5"})
// Test getting an Entry pre-Start.
if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" {
t.Error("wrong job retrieved:", actualName)
}
if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" {
t.Error("wrong job retrieved:", actualName)
}
cron.Start()
defer cron.Stop()
select {
case <-time.After(OneSecond):
t.FailNow()
case <-wait(wg):
}
// Ensure the entries are in the right order.
expecteds := []string{"job2", "job4", "job5", "job1", "job3", "job0"}
var actuals []string
for _, entry := range cron.Entries() {
actuals = append(actuals, entry.Job.(testJob).name)
}
for i, expected := range expecteds {
if actuals[i] != expected {
t.Fatalf("Jobs not in the right order. (expected) %s != %s (actual)", expecteds, actuals)
}
}
// Test getting Entries.
if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" {
t.Error("wrong job retrieved:", actualName)
}
if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" {
t.Error("wrong job retrieved:", actualName)
}
}
// Issue #206
// Ensure that the next run of a job after removing an entry is accurate.
func TestScheduleAfterRemoval(t *testing.T) {
var wg1 sync.WaitGroup
var wg2 sync.WaitGroup
wg1.Add(1)
wg2.Add(1)
// The first time this job is run, set a timer and remove the other job
// 750ms later. Correct behavior would be to still run the job again in
// 250ms, but the bug would cause it to run instead 1s later.
var calls int
var mu sync.Mutex
cron := newWithSeconds()
hourJob := cron.Schedule(Every(time.Hour), FuncJob(func() {}))
cron.Schedule(Every(time.Second), FuncJob(func() {
mu.Lock()
defer mu.Unlock()
switch calls {
case 0:
wg1.Done()
calls++
case 1:
time.Sleep(750 * time.Millisecond)
cron.Remove(hourJob)
calls++
case 2:
calls++
wg2.Done()
case 3:
panic("unexpected 3rd call")
}
}))
cron.Start()
defer cron.Stop()
// the first run might be any length of time 0 - 1s, since the schedule
// rounds to the second. wait for the first run to true up.
wg1.Wait()
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
case <-wait(&wg2):
}
}
type ZeroSchedule struct{}
func (*ZeroSchedule) Next(time.Time) time.Time {
return time.Time{}
}
// Tests that job without time does not run
func TestJobWithZeroTimeDoesNotRun(t *testing.T) {
cron := newWithSeconds()
var calls int64
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
cron.Schedule(new(ZeroSchedule), FuncJob(func() { t.Error("expected zero task will not run") }))
cron.Start()
defer cron.Stop()
<-time.After(OneSecond)
if atomic.LoadInt64(&calls) != 1 {
t.Errorf("called %d times, expected 1\n", calls)
}
}
func TestStopAndWait(t *testing.T) {
t.Run("nothing running, returns immediately", func(*testing.T) {
cron := newWithSeconds()
cron.Start()
ctx := cron.Stop()
select {
case <-ctx.Done():
case <-time.After(time.Millisecond):
t.Error("context was not done immediately")
}
})
t.Run("repeated calls to Stop", func(*testing.T) {
cron := newWithSeconds()
cron.Start()
_ = cron.Stop()
time.Sleep(time.Millisecond)
ctx := cron.Stop()
select {
case <-ctx.Done():
case <-time.After(time.Millisecond):
t.Error("context was not done immediately")
}
})
t.Run("a couple fast jobs added, still returns immediately", func(*testing.T) {
cron := newWithSeconds()
cron.AddFunc("* * * * * *", func() {})
cron.Start()
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
ctx := cron.Stop()
select {
case <-ctx.Done():
case <-time.After(time.Millisecond):
t.Error("context was not done immediately")
}
})
t.Run("a couple fast jobs and a slow job added, waits for slow job", func(*testing.T) {
cron := newWithSeconds()
cron.AddFunc("* * * * * *", func() {})
cron.Start()
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
ctx := cron.Stop()
// Verify that it is not done for at least 750ms
select {
case <-ctx.Done():
t.Error("context was done too quickly immediately")
case <-time.After(750 * time.Millisecond):
// expected, because the job sleeping for 1 second is still running
}
// Verify that it IS done in the next 500ms (giving 250ms buffer)
select {
case <-ctx.Done():
// expected
case <-time.After(1500 * time.Millisecond):
t.Error("context not done after job should have completed")
}
})
t.Run("repeated calls to stop, waiting for completion and after", func(*testing.T) {
cron := newWithSeconds()
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
cron.Start()
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
ctx := cron.Stop()
ctx2 := cron.Stop()
// Verify that it is not done for at least 1500ms
select {
case <-ctx.Done():
t.Error("context was done too quickly immediately")
case <-ctx2.Done():
t.Error("context2 was done too quickly immediately")
case <-time.After(1500 * time.Millisecond):
// expected, because the job sleeping for 2 seconds is still running
}
// Verify that it IS done in the next 1s (giving 500ms buffer)
select {
case <-ctx.Done():
// expected
case <-time.After(time.Second):
t.Error("context not done after job should have completed")
}
// Verify that ctx2 is also done.
select {
case <-ctx2.Done():
// expected
case <-time.After(time.Millisecond):
t.Error("context2 not done even though context1 is")
}
// Verify that a new context retrieved from stop is immediately done.
ctx3 := cron.Stop()
select {
case <-ctx3.Done():
// expected
case <-time.After(time.Millisecond):
t.Error("context not done even when cron Stop is completed")
}
})
}
func TestMultiThreadedStartAndStop(t *testing.T) {
cron := New()
go cron.Run()
time.Sleep(2 * time.Millisecond)
cron.Stop()
}
func wait(wg *sync.WaitGroup) chan bool {
ch := make(chan bool)
go func() {
wg.Wait()
ch <- true
}()
return ch
}
func stop(cron *Cron) chan bool {
ch := make(chan bool)
go func() {
cron.Stop()
ch <- true
}()
return ch
}
// newWithSeconds returns a Cron with the seconds field enabled.
func newWithSeconds() *Cron {
return New(WithParser(secondParser), WithChain())
}

86
plugin/cron/logger.go Normal file
View File

@@ -0,0 +1,86 @@
package cron
import (
"io"
"log"
"os"
"strings"
"time"
)
// DefaultLogger is used by Cron if none is specified.
var DefaultLogger = PrintfLogger(log.New(os.Stdout, "cron: ", log.LstdFlags))
// DiscardLogger can be used by callers to discard all log messages.
var DiscardLogger = PrintfLogger(log.New(io.Discard, "", 0))
// Logger is the interface used in this package for logging, so that any backend
// can be plugged in. It is a subset of the github.com/go-logr/logr interface.
type Logger interface {
// Info logs routine messages about cron's operation.
Info(msg string, keysAndValues ...interface{})
// Error logs an error condition.
Error(err error, msg string, keysAndValues ...interface{})
}
// PrintfLogger wraps a Printf-based logger (such as the standard library "log")
// into an implementation of the Logger interface which logs errors only.
func PrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger {
return printfLogger{l, false}
}
// VerbosePrintfLogger wraps a Printf-based logger (such as the standard library
// "log") into an implementation of the Logger interface which logs everything.
func VerbosePrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger {
return printfLogger{l, true}
}
type printfLogger struct {
logger interface{ Printf(string, ...interface{}) }
logInfo bool
}
func (pl printfLogger) Info(msg string, keysAndValues ...interface{}) {
if pl.logInfo {
keysAndValues = formatTimes(keysAndValues)
pl.logger.Printf(
formatString(len(keysAndValues)),
append([]interface{}{msg}, keysAndValues...)...)
}
}
func (pl printfLogger) Error(err error, msg string, keysAndValues ...interface{}) {
keysAndValues = formatTimes(keysAndValues)
pl.logger.Printf(
formatString(len(keysAndValues)+2),
append([]interface{}{msg, "error", err}, keysAndValues...)...)
}
// formatString returns a logfmt-like format string for the number of
// key/values.
func formatString(numKeysAndValues int) string {
var sb strings.Builder
sb.WriteString("%s")
if numKeysAndValues > 0 {
sb.WriteString(", ")
}
for i := 0; i < numKeysAndValues/2; i++ {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString("%v=%v")
}
return sb.String()
}
// formatTimes formats any time.Time values as RFC3339.
func formatTimes(keysAndValues []interface{}) []interface{} {
var formattedArgs []interface{}
for _, arg := range keysAndValues {
if t, ok := arg.(time.Time); ok {
arg = t.Format(time.RFC3339)
}
formattedArgs = append(formattedArgs, arg)
}
return formattedArgs
}

45
plugin/cron/option.go Normal file
View File

@@ -0,0 +1,45 @@
package cron
import (
"time"
)
// Option represents a modification to the default behavior of a Cron.
type Option func(*Cron)
// WithLocation overrides the timezone of the cron instance.
func WithLocation(loc *time.Location) Option {
return func(c *Cron) {
c.location = loc
}
}
// WithSeconds overrides the parser used for interpreting job schedules to
// include a seconds field as the first one.
func WithSeconds() Option {
return WithParser(NewParser(
Second | Minute | Hour | Dom | Month | Dow | Descriptor,
))
}
// WithParser overrides the parser used for interpreting job schedules.
func WithParser(p ScheduleParser) Option {
return func(c *Cron) {
c.parser = p
}
}
// WithChain specifies Job wrappers to apply to all jobs added to this cron.
// Refer to the Chain* functions in this package for provided wrappers.
func WithChain(wrappers ...JobWrapper) Option {
return func(c *Cron) {
c.chain = NewChain(wrappers...)
}
}
// WithLogger uses the provided logger.
func WithLogger(logger Logger) Option {
return func(c *Cron) {
c.logger = logger
}
}

View File

@@ -0,0 +1,43 @@
//nolint:all
package cron
import (
"log"
"strings"
"testing"
"time"
)
func TestWithLocation(t *testing.T) {
c := New(WithLocation(time.UTC))
if c.location != time.UTC {
t.Errorf("expected UTC, got %v", c.location)
}
}
func TestWithParser(t *testing.T) {
var parser = NewParser(Dow)
c := New(WithParser(parser))
if c.parser != parser {
t.Error("expected provided parser")
}
}
func TestWithVerboseLogger(t *testing.T) {
var buf syncWriter
var logger = log.New(&buf, "", log.LstdFlags)
c := New(WithLogger(VerbosePrintfLogger(logger)))
if c.logger.(printfLogger).logger != logger {
t.Error("expected provided logger")
}
c.AddFunc("@every 1s", func() {})
c.Start()
time.Sleep(OneSecond)
c.Stop()
out := buf.String()
if !strings.Contains(out, "schedule,") ||
!strings.Contains(out, "run,") {
t.Error("expected to see some actions, got:", out)
}
}

437
plugin/cron/parser.go Normal file
View File

@@ -0,0 +1,437 @@
package cron
import (
"math"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
)
// Configuration options for creating a parser. Most options specify which
// fields should be included, while others enable features. If a field is not
// included the parser will assume a default value. These options do not change
// the order fields are parse in.
type ParseOption int
const (
Second ParseOption = 1 << iota // Seconds field, default 0
SecondOptional // Optional seconds field, default 0
Minute // Minutes field, default 0
Hour // Hours field, default 0
Dom // Day of month field, default *
Month // Month field, default *
Dow // Day of week field, default *
DowOptional // Optional day of week field, default *
Descriptor // Allow descriptors such as @monthly, @weekly, etc.
)
var places = []ParseOption{
Second,
Minute,
Hour,
Dom,
Month,
Dow,
}
var defaults = []string{
"0",
"0",
"0",
"*",
"*",
"*",
}
// A custom Parser that can be configured.
type Parser struct {
options ParseOption
}
// NewParser creates a Parser with custom options.
//
// It panics if more than one Optional is given, since it would be impossible to
// correctly infer which optional is provided or missing in general.
//
// Examples
//
// // Standard parser without descriptors
// specParser := NewParser(Minute | Hour | Dom | Month | Dow)
// sched, err := specParser.Parse("0 0 15 */3 *")
//
// // Same as above, just excludes time fields
// specParser := NewParser(Dom | Month | Dow)
// sched, err := specParser.Parse("15 */3 *")
//
// // Same as above, just makes Dow optional
// specParser := NewParser(Dom | Month | DowOptional)
// sched, err := specParser.Parse("15 */3")
func NewParser(options ParseOption) Parser {
optionals := 0
if options&DowOptional > 0 {
optionals++
}
if options&SecondOptional > 0 {
optionals++
}
if optionals > 1 {
panic("multiple optionals may not be configured")
}
return Parser{options}
}
// Parse returns a new crontab schedule representing the given spec.
// It returns a descriptive error if the spec is not valid.
// It accepts crontab specs and features configured by NewParser.
func (p Parser) Parse(spec string) (Schedule, error) {
if len(spec) == 0 {
return nil, errors.New("empty spec string")
}
// Extract timezone if present
var loc = time.Local
if strings.HasPrefix(spec, "TZ=") || strings.HasPrefix(spec, "CRON_TZ=") {
var err error
i := strings.Index(spec, " ")
eq := strings.Index(spec, "=")
if loc, err = time.LoadLocation(spec[eq+1 : i]); err != nil {
return nil, errors.Wrap(err, "provided bad location")
}
spec = strings.TrimSpace(spec[i:])
}
// Handle named schedules (descriptors), if configured
if strings.HasPrefix(spec, "@") {
if p.options&Descriptor == 0 {
return nil, errors.New("descriptors not enabled")
}
return parseDescriptor(spec, loc)
}
// Split on whitespace.
fields := strings.Fields(spec)
// Validate & fill in any omitted or optional fields
var err error
fields, err = normalizeFields(fields, p.options)
if err != nil {
return nil, err
}
field := func(field string, r bounds) uint64 {
if err != nil {
return 0
}
var bits uint64
bits, err = getField(field, r)
return bits
}
var (
second = field(fields[0], seconds)
minute = field(fields[1], minutes)
hour = field(fields[2], hours)
dayofmonth = field(fields[3], dom)
month = field(fields[4], months)
dayofweek = field(fields[5], dow)
)
if err != nil {
return nil, err
}
return &SpecSchedule{
Second: second,
Minute: minute,
Hour: hour,
Dom: dayofmonth,
Month: month,
Dow: dayofweek,
Location: loc,
}, nil
}
// normalizeFields takes a subset set of the time fields and returns the full set
// with defaults (zeroes) populated for unset fields.
//
// As part of performing this function, it also validates that the provided
// fields are compatible with the configured options.
func normalizeFields(fields []string, options ParseOption) ([]string, error) {
// Validate optionals & add their field to options
optionals := 0
if options&SecondOptional > 0 {
options |= Second
optionals++
}
if options&DowOptional > 0 {
options |= Dow
optionals++
}
if optionals > 1 {
return nil, errors.New("multiple optionals may not be configured")
}
// Figure out how many fields we need
max := 0
for _, place := range places {
if options&place > 0 {
max++
}
}
min := max - optionals
// Validate number of fields
if count := len(fields); count < min || count > max {
if min == max {
return nil, errors.New("incorrect number of fields")
}
return nil, errors.New("incorrect number of fields, expected " + strconv.Itoa(min) + "-" + strconv.Itoa(max))
}
// Populate the optional field if not provided
if min < max && len(fields) == min {
switch {
case options&DowOptional > 0:
fields = append(fields, defaults[5]) // TODO: improve access to default
case options&SecondOptional > 0:
fields = append([]string{defaults[0]}, fields...)
default:
return nil, errors.New("unexpected optional field")
}
}
// Populate all fields not part of options with their defaults
n := 0
expandedFields := make([]string, len(places))
copy(expandedFields, defaults)
for i, place := range places {
if options&place > 0 {
expandedFields[i] = fields[n]
n++
}
}
return expandedFields, nil
}
var standardParser = NewParser(
Minute | Hour | Dom | Month | Dow | Descriptor,
)
// ParseStandard returns a new crontab schedule representing the given
// standardSpec (https://en.wikipedia.org/wiki/Cron). It requires 5 entries
// representing: minute, hour, day of month, month and day of week, in that
// order. It returns a descriptive error if the spec is not valid.
//
// It accepts
// - Standard crontab specs, e.g. "* * * * ?"
// - Descriptors, e.g. "@midnight", "@every 1h30m"
func ParseStandard(standardSpec string) (Schedule, error) {
return standardParser.Parse(standardSpec)
}
// getField returns an Int with the bits set representing all of the times that
// the field represents or error parsing field value. A "field" is a comma-separated
// list of "ranges".
func getField(field string, r bounds) (uint64, error) {
var bits uint64
ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' })
for _, expr := range ranges {
bit, err := getRange(expr, r)
if err != nil {
return bits, err
}
bits |= bit
}
return bits, nil
}
// getRange returns the bits indicated by the given expression:
//
// number | number "-" number [ "/" number ]
//
// or error parsing range.
func getRange(expr string, r bounds) (uint64, error) {
var (
start, end, step uint
rangeAndStep = strings.Split(expr, "/")
lowAndHigh = strings.Split(rangeAndStep[0], "-")
singleDigit = len(lowAndHigh) == 1
err error
)
var extra uint64
if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" {
start = r.min
end = r.max
extra = starBit
} else {
start, err = parseIntOrName(lowAndHigh[0], r.names)
if err != nil {
return 0, err
}
switch len(lowAndHigh) {
case 1:
end = start
case 2:
end, err = parseIntOrName(lowAndHigh[1], r.names)
if err != nil {
return 0, err
}
default:
return 0, errors.New("too many hyphens: " + expr)
}
}
switch len(rangeAndStep) {
case 1:
step = 1
case 2:
step, err = mustParseInt(rangeAndStep[1])
if err != nil {
return 0, err
}
// Special handling: "N/step" means "N-max/step".
if singleDigit {
end = r.max
}
if step > 1 {
extra = 0
}
default:
return 0, errors.New("too many slashes: " + expr)
}
if start < r.min {
return 0, errors.New("beginning of range below minimum: " + expr)
}
if end > r.max {
return 0, errors.New("end of range above maximum: " + expr)
}
if start > end {
return 0, errors.New("beginning of range after end: " + expr)
}
if step == 0 {
return 0, errors.New("step cannot be zero: " + expr)
}
return getBits(start, end, step) | extra, nil
}
// parseIntOrName returns the (possibly-named) integer contained in expr.
func parseIntOrName(expr string, names map[string]uint) (uint, error) {
if names != nil {
if namedInt, ok := names[strings.ToLower(expr)]; ok {
return namedInt, nil
}
}
return mustParseInt(expr)
}
// mustParseInt parses the given expression as an int or returns an error.
func mustParseInt(expr string) (uint, error) {
num, err := strconv.Atoi(expr)
if err != nil {
return 0, errors.Wrap(err, "failed to parse number")
}
if num < 0 {
return 0, errors.New("number must be positive")
}
return uint(num), nil
}
// getBits sets all bits in the range [min, max], modulo the given step size.
func getBits(min, max, step uint) uint64 {
var bits uint64
// If step is 1, use shifts.
if step == 1 {
return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min)
}
// Else, use a simple loop.
for i := min; i <= max; i += step {
bits |= 1 << i
}
return bits
}
// all returns all bits within the given bounds.
func all(r bounds) uint64 {
return getBits(r.min, r.max, 1) | starBit
}
// parseDescriptor returns a predefined schedule for the expression, or error if none matches.
func parseDescriptor(descriptor string, loc *time.Location) (Schedule, error) {
switch descriptor {
case "@yearly", "@annually":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: 1 << dom.min,
Month: 1 << months.min,
Dow: all(dow),
Location: loc,
}, nil
case "@monthly":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: 1 << dom.min,
Month: all(months),
Dow: all(dow),
Location: loc,
}, nil
case "@weekly":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: all(dom),
Month: all(months),
Dow: 1 << dow.min,
Location: loc,
}, nil
case "@daily", "@midnight":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: all(dom),
Month: all(months),
Dow: all(dow),
Location: loc,
}, nil
case "@hourly":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: all(hours),
Dom: all(dom),
Month: all(months),
Dow: all(dow),
Location: loc,
}, nil
default:
// Continue to check @every prefix below
}
const every = "@every "
if strings.HasPrefix(descriptor, every) {
duration, err := time.ParseDuration(descriptor[len(every):])
if err != nil {
return nil, errors.Wrap(err, "failed to parse duration")
}
return Every(duration), nil
}
return nil, errors.New("unrecognized descriptor: " + descriptor)
}

384
plugin/cron/parser_test.go Normal file
View File

@@ -0,0 +1,384 @@
//nolint:all
package cron
import (
"reflect"
"strings"
"testing"
"time"
)
var secondParser = NewParser(Second | Minute | Hour | Dom | Month | DowOptional | Descriptor)
func TestRange(t *testing.T) {
zero := uint64(0)
ranges := []struct {
expr string
min, max uint
expected uint64
err string
}{
{"5", 0, 7, 1 << 5, ""},
{"0", 0, 7, 1 << 0, ""},
{"7", 0, 7, 1 << 7, ""},
{"5-5", 0, 7, 1 << 5, ""},
{"5-6", 0, 7, 1<<5 | 1<<6, ""},
{"5-7", 0, 7, 1<<5 | 1<<6 | 1<<7, ""},
{"5-6/2", 0, 7, 1 << 5, ""},
{"5-7/2", 0, 7, 1<<5 | 1<<7, ""},
{"5-7/1", 0, 7, 1<<5 | 1<<6 | 1<<7, ""},
{"*", 1, 3, 1<<1 | 1<<2 | 1<<3 | starBit, ""},
{"*/2", 1, 3, 1<<1 | 1<<3, ""},
{"5--5", 0, 0, zero, "too many hyphens"},
{"jan-x", 0, 0, zero, `failed to parse number: strconv.Atoi: parsing "jan": invalid syntax`},
{"2-x", 1, 5, zero, `failed to parse number: strconv.Atoi: parsing "x": invalid syntax`},
{"*/-12", 0, 0, zero, "number must be positive"},
{"*//2", 0, 0, zero, "too many slashes"},
{"1", 3, 5, zero, "below minimum"},
{"6", 3, 5, zero, "above maximum"},
{"5-3", 3, 5, zero, "beginning of range after end: 5-3"},
{"*/0", 0, 0, zero, "step cannot be zero: */0"},
}
for _, c := range ranges {
actual, err := getRange(c.expr, bounds{c.min, c.max, nil})
if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) {
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
}
if len(c.err) == 0 && err != nil {
t.Errorf("%s => unexpected error %v", c.expr, err)
}
if actual != c.expected {
t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual)
}
}
}
func TestField(t *testing.T) {
fields := []struct {
expr string
min, max uint
expected uint64
}{
{"5", 1, 7, 1 << 5},
{"5,6", 1, 7, 1<<5 | 1<<6},
{"5,6,7", 1, 7, 1<<5 | 1<<6 | 1<<7},
{"1,5-7/2,3", 1, 7, 1<<1 | 1<<5 | 1<<7 | 1<<3},
}
for _, c := range fields {
actual, _ := getField(c.expr, bounds{c.min, c.max, nil})
if actual != c.expected {
t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual)
}
}
}
func TestAll(t *testing.T) {
allBits := []struct {
r bounds
expected uint64
}{
{minutes, 0xfffffffffffffff}, // 0-59: 60 ones
{hours, 0xffffff}, // 0-23: 24 ones
{dom, 0xfffffffe}, // 1-31: 31 ones, 1 zero
{months, 0x1ffe}, // 1-12: 12 ones, 1 zero
{dow, 0x7f}, // 0-6: 7 ones
}
for _, c := range allBits {
actual := all(c.r) // all() adds the starBit, so compensate for that..
if c.expected|starBit != actual {
t.Errorf("%d-%d/%d => expected %b, got %b",
c.r.min, c.r.max, 1, c.expected|starBit, actual)
}
}
}
func TestBits(t *testing.T) {
bits := []struct {
min, max, step uint
expected uint64
}{
{0, 0, 1, 0x1},
{1, 1, 1, 0x2},
{1, 5, 2, 0x2a}, // 101010
{1, 4, 2, 0xa}, // 1010
}
for _, c := range bits {
actual := getBits(c.min, c.max, c.step)
if c.expected != actual {
t.Errorf("%d-%d/%d => expected %b, got %b",
c.min, c.max, c.step, c.expected, actual)
}
}
}
func TestParseScheduleErrors(t *testing.T) {
var tests = []struct{ expr, err string }{
{"* 5 j * * *", `failed to parse number: strconv.Atoi: parsing "j": invalid syntax`},
{"@every Xm", "failed to parse duration"},
{"@unrecognized", "unrecognized descriptor"},
{"* * * *", "incorrect number of fields, expected 5-6"},
{"", "empty spec string"},
}
for _, c := range tests {
actual, err := secondParser.Parse(c.expr)
if err == nil || !strings.Contains(err.Error(), c.err) {
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
}
if actual != nil {
t.Errorf("expected nil schedule on error, got %v", actual)
}
}
}
func TestParseSchedule(t *testing.T) {
tokyo, _ := time.LoadLocation("Asia/Tokyo")
entries := []struct {
parser Parser
expr string
expected Schedule
}{
{secondParser, "0 5 * * * *", every5min(time.Local)},
{standardParser, "5 * * * *", every5min(time.Local)},
{secondParser, "CRON_TZ=UTC 0 5 * * * *", every5min(time.UTC)},
{standardParser, "CRON_TZ=UTC 5 * * * *", every5min(time.UTC)},
{secondParser, "CRON_TZ=Asia/Tokyo 0 5 * * * *", every5min(tokyo)},
{secondParser, "@every 5m", ConstantDelaySchedule{5 * time.Minute}},
{secondParser, "@midnight", midnight(time.Local)},
{secondParser, "TZ=UTC @midnight", midnight(time.UTC)},
{secondParser, "TZ=Asia/Tokyo @midnight", midnight(tokyo)},
{secondParser, "@yearly", annual(time.Local)},
{secondParser, "@annually", annual(time.Local)},
{
parser: secondParser,
expr: "* 5 * * * *",
expected: &SpecSchedule{
Second: all(seconds),
Minute: 1 << 5,
Hour: all(hours),
Dom: all(dom),
Month: all(months),
Dow: all(dow),
Location: time.Local,
},
},
}
for _, c := range entries {
actual, err := c.parser.Parse(c.expr)
if err != nil {
t.Errorf("%s => unexpected error %v", c.expr, err)
}
if !reflect.DeepEqual(actual, c.expected) {
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
}
}
}
func TestOptionalSecondSchedule(t *testing.T) {
parser := NewParser(SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor)
entries := []struct {
expr string
expected Schedule
}{
{"0 5 * * * *", every5min(time.Local)},
{"5 5 * * * *", every5min5s(time.Local)},
{"5 * * * *", every5min(time.Local)},
}
for _, c := range entries {
actual, err := parser.Parse(c.expr)
if err != nil {
t.Errorf("%s => unexpected error %v", c.expr, err)
}
if !reflect.DeepEqual(actual, c.expected) {
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
}
}
}
func TestNormalizeFields(t *testing.T) {
tests := []struct {
name string
input []string
options ParseOption
expected []string
}{
{
"AllFields_NoOptional",
[]string{"0", "5", "*", "*", "*", "*"},
Second | Minute | Hour | Dom | Month | Dow | Descriptor,
[]string{"0", "5", "*", "*", "*", "*"},
},
{
"AllFields_SecondOptional_Provided",
[]string{"0", "5", "*", "*", "*", "*"},
SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor,
[]string{"0", "5", "*", "*", "*", "*"},
},
{
"AllFields_SecondOptional_NotProvided",
[]string{"5", "*", "*", "*", "*"},
SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor,
[]string{"0", "5", "*", "*", "*", "*"},
},
{
"SubsetFields_NoOptional",
[]string{"5", "15", "*"},
Hour | Dom | Month,
[]string{"0", "0", "5", "15", "*", "*"},
},
{
"SubsetFields_DowOptional_Provided",
[]string{"5", "15", "*", "4"},
Hour | Dom | Month | DowOptional,
[]string{"0", "0", "5", "15", "*", "4"},
},
{
"SubsetFields_DowOptional_NotProvided",
[]string{"5", "15", "*"},
Hour | Dom | Month | DowOptional,
[]string{"0", "0", "5", "15", "*", "*"},
},
{
"SubsetFields_SecondOptional_NotProvided",
[]string{"5", "15", "*"},
SecondOptional | Hour | Dom | Month,
[]string{"0", "0", "5", "15", "*", "*"},
},
}
for _, test := range tests {
t.Run(test.name, func(*testing.T) {
actual, err := normalizeFields(test.input, test.options)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !reflect.DeepEqual(actual, test.expected) {
t.Errorf("expected %v, got %v", test.expected, actual)
}
})
}
}
func TestNormalizeFields_Errors(t *testing.T) {
tests := []struct {
name string
input []string
options ParseOption
err string
}{
{
"TwoOptionals",
[]string{"0", "5", "*", "*", "*", "*"},
SecondOptional | Minute | Hour | Dom | Month | DowOptional,
"",
},
{
"TooManyFields",
[]string{"0", "5", "*", "*"},
SecondOptional | Minute | Hour,
"",
},
{
"NoFields",
[]string{},
SecondOptional | Minute | Hour,
"",
},
{
"TooFewFields",
[]string{"*"},
SecondOptional | Minute | Hour,
"",
},
}
for _, test := range tests {
t.Run(test.name, func(*testing.T) {
actual, err := normalizeFields(test.input, test.options)
if err == nil {
t.Errorf("expected an error, got none. results: %v", actual)
}
if !strings.Contains(err.Error(), test.err) {
t.Errorf("expected error %q, got %q", test.err, err.Error())
}
})
}
}
func TestStandardSpecSchedule(t *testing.T) {
entries := []struct {
expr string
expected Schedule
err string
}{
{
expr: "5 * * * *",
expected: &SpecSchedule{1 << seconds.min, 1 << 5, all(hours), all(dom), all(months), all(dow), time.Local},
},
{
expr: "@every 5m",
expected: ConstantDelaySchedule{time.Duration(5) * time.Minute},
},
{
expr: "5 j * * *",
err: `failed to parse number: strconv.Atoi: parsing "j": invalid syntax`,
},
{
expr: "* * * *",
err: "incorrect number of fields",
},
}
for _, c := range entries {
actual, err := ParseStandard(c.expr)
if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) {
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
}
if len(c.err) == 0 && err != nil {
t.Errorf("%s => unexpected error %v", c.expr, err)
}
if !reflect.DeepEqual(actual, c.expected) {
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
}
}
}
func TestNoDescriptorParser(t *testing.T) {
parser := NewParser(Minute | Hour)
_, err := parser.Parse("@every 1m")
if err == nil {
t.Error("expected an error, got none")
}
}
func every5min(loc *time.Location) *SpecSchedule {
return &SpecSchedule{1 << 0, 1 << 5, all(hours), all(dom), all(months), all(dow), loc}
}
func every5min5s(loc *time.Location) *SpecSchedule {
return &SpecSchedule{1 << 5, 1 << 5, all(hours), all(dom), all(months), all(dow), loc}
}
func midnight(loc *time.Location) *SpecSchedule {
return &SpecSchedule{1, 1, 1, all(dom), all(months), all(dow), loc}
}
func annual(loc *time.Location) *SpecSchedule {
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: 1 << dom.min,
Month: 1 << months.min,
Dow: all(dow),
Location: loc,
}
}

188
plugin/cron/spec.go Normal file
View File

@@ -0,0 +1,188 @@
package cron
import "time"
// SpecSchedule specifies a duty cycle (to the second granularity), based on a
// traditional crontab specification. It is computed initially and stored as bit sets.
type SpecSchedule struct {
Second, Minute, Hour, Dom, Month, Dow uint64
// Override location for this schedule.
Location *time.Location
}
// bounds provides a range of acceptable values (plus a map of name to value).
type bounds struct {
min, max uint
names map[string]uint
}
// The bounds for each field.
var (
seconds = bounds{0, 59, nil}
minutes = bounds{0, 59, nil}
hours = bounds{0, 23, nil}
dom = bounds{1, 31, nil}
months = bounds{1, 12, map[string]uint{
"jan": 1,
"feb": 2,
"mar": 3,
"apr": 4,
"may": 5,
"jun": 6,
"jul": 7,
"aug": 8,
"sep": 9,
"oct": 10,
"nov": 11,
"dec": 12,
}}
dow = bounds{0, 6, map[string]uint{
"sun": 0,
"mon": 1,
"tue": 2,
"wed": 3,
"thu": 4,
"fri": 5,
"sat": 6,
}}
)
const (
// Set the top bit if a star was included in the expression.
starBit = 1 << 63
)
// Next returns the next time this schedule is activated, greater than the given
// time. If no time can be found to satisfy the schedule, return the zero time.
func (s *SpecSchedule) Next(t time.Time) time.Time {
// General approach
//
// For Month, Day, Hour, Minute, Second:
// Check if the time value matches. If yes, continue to the next field.
// If the field doesn't match the schedule, then increment the field until it matches.
// While incrementing the field, a wrap-around brings it back to the beginning
// of the field list (since it is necessary to re-verify previous field
// values)
// Convert the given time into the schedule's timezone, if one is specified.
// Save the original timezone so we can convert back after we find a time.
// Note that schedules without a time zone specified (time.Local) are treated
// as local to the time provided.
origLocation := t.Location()
loc := s.Location
if loc == time.Local {
loc = t.Location()
}
if s.Location != time.Local {
t = t.In(s.Location)
}
// Start at the earliest possible time (the upcoming second).
t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond)
// This flag indicates whether a field has been incremented.
added := false
// If no time is found within five years, return zero.
yearLimit := t.Year() + 5
WRAP:
if t.Year() > yearLimit {
return time.Time{}
}
// Find the first applicable month.
// If it's this month, then do nothing.
for 1<<uint(t.Month())&s.Month == 0 {
// If we have to add a month, reset the other parts to 0.
if !added {
added = true
// Otherwise, set the date at the beginning (since the current time is irrelevant).
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
}
t = t.AddDate(0, 1, 0)
// Wrapped around.
if t.Month() == time.January {
goto WRAP
}
}
// Now get a day in that month.
//
// NOTE: This causes issues for daylight savings regimes where midnight does
// not exist. For example: Sao Paulo has DST that transforms midnight on
// 11/3 into 1am. Handle that by noticing when the Hour ends up != 0.
for !dayMatches(s, t) {
if !added {
added = true
t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
t = t.AddDate(0, 0, 1)
// Notice if the hour is no longer midnight due to DST.
// Add an hour if it's 23, subtract an hour if it's 1.
if t.Hour() != 0 {
if t.Hour() > 12 {
t = t.Add(time.Duration(24-t.Hour()) * time.Hour)
} else {
t = t.Add(time.Duration(-t.Hour()) * time.Hour)
}
}
if t.Day() == 1 {
goto WRAP
}
}
for 1<<uint(t.Hour())&s.Hour == 0 {
if !added {
added = true
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 0, loc)
}
t = t.Add(1 * time.Hour)
if t.Hour() == 0 {
goto WRAP
}
}
for 1<<uint(t.Minute())&s.Minute == 0 {
if !added {
added = true
t = t.Truncate(time.Minute)
}
t = t.Add(1 * time.Minute)
if t.Minute() == 0 {
goto WRAP
}
}
for 1<<uint(t.Second())&s.Second == 0 {
if !added {
added = true
t = t.Truncate(time.Second)
}
t = t.Add(1 * time.Second)
if t.Second() == 0 {
goto WRAP
}
}
return t.In(origLocation)
}
// dayMatches returns true if the schedule's day-of-week and day-of-month
// restrictions are satisfied by the given time.
func dayMatches(s *SpecSchedule, t time.Time) bool {
var (
domMatch = 1<<uint(t.Day())&s.Dom > 0
dowMatch = 1<<uint(t.Weekday())&s.Dow > 0
)
if s.Dom&starBit > 0 || s.Dow&starBit > 0 {
return domMatch && dowMatch
}
return domMatch || dowMatch
}

301
plugin/cron/spec_test.go Normal file
View File

@@ -0,0 +1,301 @@
//nolint:all
package cron
import (
"strings"
"testing"
"time"
)
func TestActivation(t *testing.T) {
tests := []struct {
time, spec string
expected bool
}{
// Every fifteen minutes.
{"Mon Jul 9 15:00 2012", "0/15 * * * *", true},
{"Mon Jul 9 15:45 2012", "0/15 * * * *", true},
{"Mon Jul 9 15:40 2012", "0/15 * * * *", false},
// Every fifteen minutes, starting at 5 minutes.
{"Mon Jul 9 15:05 2012", "5/15 * * * *", true},
{"Mon Jul 9 15:20 2012", "5/15 * * * *", true},
{"Mon Jul 9 15:50 2012", "5/15 * * * *", true},
// Named months
{"Sun Jul 15 15:00 2012", "0/15 * * Jul *", true},
{"Sun Jul 15 15:00 2012", "0/15 * * Jun *", false},
// Everything set.
{"Sun Jul 15 08:30 2012", "30 08 ? Jul Sun", true},
{"Sun Jul 15 08:30 2012", "30 08 15 Jul ?", true},
{"Mon Jul 16 08:30 2012", "30 08 ? Jul Sun", false},
{"Mon Jul 16 08:30 2012", "30 08 15 Jul ?", false},
// Predefined schedules
{"Mon Jul 9 15:00 2012", "@hourly", true},
{"Mon Jul 9 15:04 2012", "@hourly", false},
{"Mon Jul 9 15:00 2012", "@daily", false},
{"Mon Jul 9 00:00 2012", "@daily", true},
{"Mon Jul 9 00:00 2012", "@weekly", false},
{"Sun Jul 8 00:00 2012", "@weekly", true},
{"Sun Jul 8 01:00 2012", "@weekly", false},
{"Sun Jul 8 00:00 2012", "@monthly", false},
{"Sun Jul 1 00:00 2012", "@monthly", true},
// Test interaction of DOW and DOM.
// If both are restricted, then only one needs to match.
{"Sun Jul 15 00:00 2012", "* * 1,15 * Sun", true},
{"Fri Jun 15 00:00 2012", "* * 1,15 * Sun", true},
{"Wed Aug 1 00:00 2012", "* * 1,15 * Sun", true},
{"Sun Jul 15 00:00 2012", "* * */10 * Sun", true}, // verifies #70
// However, if one has a star, then both need to match.
{"Sun Jul 15 00:00 2012", "* * * * Mon", false},
{"Mon Jul 9 00:00 2012", "* * 1,15 * *", false},
{"Sun Jul 15 00:00 2012", "* * 1,15 * *", true},
{"Sun Jul 15 00:00 2012", "* * */2 * Sun", true},
}
for _, test := range tests {
sched, err := ParseStandard(test.spec)
if err != nil {
t.Error(err)
continue
}
actual := sched.Next(getTime(test.time).Add(-1 * time.Second))
expected := getTime(test.time)
if test.expected && expected != actual || !test.expected && expected == actual {
t.Errorf("Fail evaluating %s on %s: (expected) %s != %s (actual)",
test.spec, test.time, expected, actual)
}
}
}
func TestNext(t *testing.T) {
runs := []struct {
time, spec string
expected string
}{
// Simple cases
{"Mon Jul 9 14:45 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
{"Mon Jul 9 14:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
{"Mon Jul 9 14:59:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
// Wrap around hours
{"Mon Jul 9 15:45 2012", "0 20-35/15 * * * *", "Mon Jul 9 16:20 2012"},
// Wrap around days
{"Mon Jul 9 23:46 2012", "0 */15 * * * *", "Tue Jul 10 00:00 2012"},
{"Mon Jul 9 23:45 2012", "0 20-35/15 * * * *", "Tue Jul 10 00:20 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * * * *", "Tue Jul 10 00:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 * * *", "Tue Jul 10 01:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 10-12 * * *", "Tue Jul 10 10:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 */2 * *", "Thu Jul 11 01:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 * *", "Wed Jul 10 00:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 Jul *", "Wed Jul 10 00:20:15 2012"},
// Wrap around months
{"Mon Jul 9 23:35 2012", "0 0 0 9 Apr-Oct ?", "Thu Aug 9 00:00 2012"},
{"Mon Jul 9 23:35 2012", "0 0 0 */5 Apr,Aug,Oct Mon", "Tue Aug 1 00:00 2012"},
{"Mon Jul 9 23:35 2012", "0 0 0 */5 Oct Mon", "Mon Oct 1 00:00 2012"},
// Wrap around years
{"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon", "Mon Feb 4 00:00 2013"},
{"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon/2", "Fri Feb 1 00:00 2013"},
// Wrap around minute, hour, day, month, and year
{"Mon Dec 31 23:59:45 2012", "0 * * * * *", "Tue Jan 1 00:00:00 2013"},
// Leap year
{"Mon Jul 9 23:35 2012", "0 0 0 29 Feb ?", "Mon Feb 29 00:00 2016"},
// Daylight savings time 2am EST (-5) -> 3am EDT (-4)
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 30 2 11 Mar ?", "2013-03-11T02:30:00-0400"},
// hourly job
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"},
{"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"},
{"2012-03-11T03:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"},
{"2012-03-11T04:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"},
// hourly job using CRON_TZ
{"2012-03-11T00:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"},
{"2012-03-11T01:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"},
{"2012-03-11T03:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"},
{"2012-03-11T04:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"},
// 1am nightly job
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-11T01:00:00-0500"},
{"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-12T01:00:00-0400"},
// 2am nightly job (skipped)
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-03-12T02:00:00-0400"},
// Daylight savings time 2am EDT (-4) => 1am EST (-5)
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 30 2 04 Nov ?", "2012-11-04T02:30:00-0500"},
{"2012-11-04T01:45:00-0400", "TZ=America/New_York 0 30 1 04 Nov ?", "2012-11-04T01:30:00-0500"},
// hourly job
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0400"},
{"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0500"},
{"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T02:00:00-0500"},
// 1am nightly job (runs twice)
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0400"},
{"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0500"},
{"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-11-05T01:00:00-0500"},
// 2am nightly job
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 2 * * ?", "2012-11-04T02:00:00-0500"},
{"2012-11-04T02:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-11-05T02:00:00-0500"},
// 3am nightly job
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 3 * * ?", "2012-11-04T03:00:00-0500"},
{"2012-11-04T03:00:00-0500", "TZ=America/New_York 0 0 3 * * ?", "2012-11-05T03:00:00-0500"},
// hourly job
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0400"},
{"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0500"},
{"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 * * * ?", "2012-11-04T02:00:00-0500"},
// 1am nightly job (runs twice)
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0400"},
{"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0500"},
{"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 1 * * ?", "2012-11-05T01:00:00-0500"},
// 2am nightly job
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 2 * * ?", "2012-11-04T02:00:00-0500"},
{"TZ=America/New_York 2012-11-04T02:00:00-0500", "0 0 2 * * ?", "2012-11-05T02:00:00-0500"},
// 3am nightly job
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 * * ?", "2012-11-04T03:00:00-0500"},
{"TZ=America/New_York 2012-11-04T03:00:00-0500", "0 0 3 * * ?", "2012-11-05T03:00:00-0500"},
// Unsatisfiable
{"Mon Jul 9 23:35 2012", "0 0 0 30 Feb ?", ""},
{"Mon Jul 9 23:35 2012", "0 0 0 31 Apr ?", ""},
// Monthly job
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 3 * ?", "2012-12-03T03:00:00-0500"},
// Test the scenario of DST resulting in midnight not being a valid time.
// https://github.com/robfig/cron/issues/157
{"2018-10-17T05:00:00-0400", "TZ=America/Sao_Paulo 0 0 9 10 * ?", "2018-11-10T06:00:00-0500"},
{"2018-02-14T05:00:00-0500", "TZ=America/Sao_Paulo 0 0 9 22 * ?", "2018-02-22T07:00:00-0500"},
}
for _, c := range runs {
sched, err := secondParser.Parse(c.spec)
if err != nil {
t.Error(err)
continue
}
actual := sched.Next(getTime(c.time))
expected := getTime(c.expected)
if !actual.Equal(expected) {
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual)
}
}
}
func TestErrors(t *testing.T) {
invalidSpecs := []string{
"xyz",
"60 0 * * *",
"0 60 * * *",
"0 0 * * XYZ",
}
for _, spec := range invalidSpecs {
_, err := ParseStandard(spec)
if err == nil {
t.Error("expected an error parsing: ", spec)
}
}
}
func getTime(value string) time.Time {
if value == "" {
return time.Time{}
}
var location = time.Local
if strings.HasPrefix(value, "TZ=") {
parts := strings.Fields(value)
loc, err := time.LoadLocation(parts[0][len("TZ="):])
if err != nil {
panic("could not parse location:" + err.Error())
}
location = loc
value = parts[1]
}
var layouts = []string{
"Mon Jan 2 15:04 2006",
"Mon Jan 2 15:04:05 2006",
}
for _, layout := range layouts {
if t, err := time.ParseInLocation(layout, value, location); err == nil {
return t
}
}
if t, err := time.ParseInLocation("2006-01-02T15:04:05-0700", value, location); err == nil {
return t
}
panic("could not parse time value " + value)
}
func TestNextWithTz(t *testing.T) {
runs := []struct {
time, spec string
expected string
}{
// Failing tests
{"2016-01-03T13:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"},
{"2016-01-03T04:09:03+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"},
// Passing tests
{"2016-01-03T14:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"},
{"2016-01-03T14:00:00+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"},
}
for _, c := range runs {
sched, err := ParseStandard(c.spec)
if err != nil {
t.Error(err)
continue
}
actual := sched.Next(getTimeTZ(c.time))
expected := getTimeTZ(c.expected)
if !actual.Equal(expected) {
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual)
}
}
}
func getTimeTZ(value string) time.Time {
if value == "" {
return time.Time{}
}
t, err := time.Parse("Mon Jan 2 15:04 2006", value)
if err != nil {
t, err = time.Parse("Mon Jan 2 15:04:05 2006", value)
if err != nil {
t, err = time.Parse("2006-01-02T15:04:05-0700", value)
if err != nil {
panic(err)
}
}
}
return t
}
// https://github.com/robfig/cron/issues/144
func TestSlash0NoHang(t *testing.T) {
schedule := "TZ=America/New_York 15/0 * * * *"
_, err := ParseStandard(schedule)
if err == nil {
t.Error("expected an error on 0 increment")
}
}

507
plugin/email/README.md Normal file
View File

@@ -0,0 +1,507 @@
# Email Plugin
SMTP email sending functionality for self-hosted Memos instances.
## Overview
This plugin provides a simple, reliable email sending interface following industry-standard SMTP protocols. It's designed for self-hosted environments where instance administrators configure their own email service, similar to platforms like GitHub, GitLab, and Discourse.
## Features
- Standard SMTP protocol support
- TLS/STARTTLS and SSL/TLS encryption
- HTML and plain text emails
- Multiple recipients (To, Cc, Bcc)
- Synchronous and asynchronous sending
- Detailed error reporting with context
- Works with all major email providers
- Reply-To header support
- RFC 5322 compliant message formatting
## Quick Start
### 1. Configure SMTP Settings
```go
import "github.com/usememos/memos/plugin/email"
config := &email.Config{
SMTPHost: "smtp.gmail.com",
SMTPPort: 587,
SMTPUsername: "your-email@gmail.com",
SMTPPassword: "your-app-password",
FromEmail: "noreply@yourdomain.com",
FromName: "Memos",
UseTLS: true,
}
```
### 2. Create and Send Email
```go
message := &email.Message{
To: []string{"user@example.com"},
Subject: "Welcome to Memos!",
Body: "Thanks for signing up.",
IsHTML: false,
}
// Synchronous send (waits for result)
err := email.Send(config, message)
if err != nil {
log.Printf("Failed to send email: %v", err)
}
// Asynchronous send (returns immediately)
email.SendAsync(config, message)
```
## Provider Configuration
### Gmail
Requires an [App Password](https://support.google.com/accounts/answer/185833) (2FA must be enabled):
```go
config := &email.Config{
SMTPHost: "smtp.gmail.com",
SMTPPort: 587,
SMTPUsername: "your-email@gmail.com",
SMTPPassword: "your-16-char-app-password",
FromEmail: "your-email@gmail.com",
FromName: "Memos",
UseTLS: true,
}
```
**Alternative (SSL):**
```go
config := &email.Config{
SMTPHost: "smtp.gmail.com",
SMTPPort: 465,
SMTPUsername: "your-email@gmail.com",
SMTPPassword: "your-16-char-app-password",
FromEmail: "your-email@gmail.com",
FromName: "Memos",
UseSSL: true,
}
```
### SendGrid
```go
config := &email.Config{
SMTPHost: "smtp.sendgrid.net",
SMTPPort: 587,
SMTPUsername: "apikey",
SMTPPassword: "your-sendgrid-api-key",
FromEmail: "noreply@yourdomain.com",
FromName: "Memos",
UseTLS: true,
}
```
### AWS SES
```go
config := &email.Config{
SMTPHost: "email-smtp.us-east-1.amazonaws.com",
SMTPPort: 587,
SMTPUsername: "your-smtp-username",
SMTPPassword: "your-smtp-password",
FromEmail: "verified@yourdomain.com",
FromName: "Memos",
UseTLS: true,
}
```
**Note:** Replace `us-east-1` with your AWS region. Email must be verified in SES.
### Mailgun
```go
config := &email.Config{
SMTPHost: "smtp.mailgun.org",
SMTPPort: 587,
SMTPUsername: "postmaster@yourdomain.com",
SMTPPassword: "your-mailgun-smtp-password",
FromEmail: "noreply@yourdomain.com",
FromName: "Memos",
UseTLS: true,
}
```
### Self-Hosted SMTP (Postfix, Exim, etc.)
```go
config := &email.Config{
SMTPHost: "mail.yourdomain.com",
SMTPPort: 587,
SMTPUsername: "username",
SMTPPassword: "password",
FromEmail: "noreply@yourdomain.com",
FromName: "Memos",
UseTLS: true,
}
```
## HTML Emails
```go
message := &email.Message{
To: []string{"user@example.com"},
Subject: "Welcome to Memos!",
Body: `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
</head>
<body style="font-family: Arial, sans-serif;">
<h1 style="color: #333;">Welcome to Memos!</h1>
<p>We're excited to have you on board.</p>
<a href="https://yourdomain.com" style="background-color: #4CAF50; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">Get Started</a>
</body>
</html>
`,
IsHTML: true,
}
email.Send(config, message)
```
## Multiple Recipients
```go
message := &email.Message{
To: []string{"user1@example.com", "user2@example.com"},
Cc: []string{"manager@example.com"},
Bcc: []string{"admin@example.com"},
Subject: "Team Update",
Body: "Important team announcement...",
ReplyTo: "support@yourdomain.com",
}
email.Send(config, message)
```
## Testing
### Run Tests
```bash
# All tests
go test ./plugin/email/... -v
# With coverage
go test ./plugin/email/... -v -cover
# With race detector
go test ./plugin/email/... -race
```
### Manual Testing
Create a simple test program:
```go
package main
import (
"log"
"github.com/usememos/memos/plugin/email"
)
func main() {
config := &email.Config{
SMTPHost: "smtp.gmail.com",
SMTPPort: 587,
SMTPUsername: "your-email@gmail.com",
SMTPPassword: "your-app-password",
FromEmail: "your-email@gmail.com",
FromName: "Test",
UseTLS: true,
}
message := &email.Message{
To: []string{"recipient@example.com"},
Subject: "Test Email",
Body: "This is a test email from Memos email plugin.",
}
if err := email.Send(config, message); err != nil {
log.Fatalf("Failed to send email: %v", err)
}
log.Println("Email sent successfully!")
}
```
## Security Best Practices
### 1. Use TLS/SSL Encryption
Always enable encryption in production:
```go
// STARTTLS (port 587) - Recommended
config.UseTLS = true
// SSL/TLS (port 465)
config.UseSSL = true
```
### 2. Secure Credential Storage
Never hardcode credentials. Use environment variables:
```go
import "os"
config := &email.Config{
SMTPHost: os.Getenv("SMTP_HOST"),
SMTPPort: 587,
SMTPUsername: os.Getenv("SMTP_USERNAME"),
SMTPPassword: os.Getenv("SMTP_PASSWORD"),
FromEmail: os.Getenv("SMTP_FROM_EMAIL"),
UseTLS: true,
}
```
### 3. Use App-Specific Passwords
For Gmail and similar services, use app passwords instead of your main account password.
### 4. Validate and Sanitize Input
Always validate email addresses and sanitize content:
```go
// Validate before sending
if err := message.Validate(); err != nil {
return err
}
```
### 5. Implement Rate Limiting
Prevent abuse by limiting email sending:
```go
// Example using golang.org/x/time/rate
limiter := rate.NewLimiter(rate.Every(time.Second), 10) // 10 emails per second
if !limiter.Allow() {
return errors.New("rate limit exceeded")
}
```
### 6. Monitor and Log
Log email sending activity for security monitoring:
```go
if err := email.Send(config, message); err != nil {
slog.Error("Email send failed",
slog.String("recipient", message.To[0]),
slog.Any("error", err))
}
```
## Common Ports
| Port | Protocol | Security | Use Case |
|------|----------|----------|----------|
| **587** | SMTP + STARTTLS | Encrypted | **Recommended** for most providers |
| **465** | SMTP over SSL/TLS | Encrypted | Alternative secure option |
| **25** | SMTP | Unencrypted | Legacy, often blocked by ISPs |
| **2525** | SMTP + STARTTLS | Encrypted | Alternative when 587 is blocked |
**Port 587 (STARTTLS)** is the recommended standard for modern SMTP:
```go
config := &email.Config{
SMTPPort: 587,
UseTLS: true,
}
```
**Port 465 (SSL/TLS)** is the alternative:
```go
config := &email.Config{
SMTPPort: 465,
UseSSL: true,
}
```
## Error Handling
The package provides detailed, contextual errors:
```go
err := email.Send(config, message)
if err != nil {
// Error messages include context:
switch {
case strings.Contains(err.Error(), "invalid email configuration"):
// Configuration error (missing host, invalid port, etc.)
log.Printf("Configuration error: %v", err)
case strings.Contains(err.Error(), "invalid email message"):
// Message validation error (missing recipients, subject, body)
log.Printf("Message error: %v", err)
case strings.Contains(err.Error(), "authentication failed"):
// SMTP authentication failed (wrong credentials)
log.Printf("Auth error: %v", err)
case strings.Contains(err.Error(), "failed to connect"):
// Network/connection error
log.Printf("Connection error: %v", err)
case strings.Contains(err.Error(), "recipient rejected"):
// SMTP server rejected recipient
log.Printf("Recipient error: %v", err)
default:
log.Printf("Unknown error: %v", err)
}
}
```
### Common Error Messages
```
❌ "invalid email configuration: SMTP host is required"
→ Fix: Set config.SMTPHost
❌ "invalid email configuration: SMTP port must be between 1 and 65535"
→ Fix: Set valid config.SMTPPort (usually 587 or 465)
❌ "invalid email configuration: from email is required"
→ Fix: Set config.FromEmail
❌ "invalid email message: at least one recipient is required"
→ Fix: Set message.To with at least one email address
❌ "invalid email message: subject is required"
→ Fix: Set message.Subject
❌ "invalid email message: body is required"
→ Fix: Set message.Body
❌ "SMTP authentication failed"
→ Fix: Check credentials (username/password)
❌ "failed to connect to SMTP server"
→ Fix: Verify host/port, check firewall, ensure TLS/SSL settings match server
```
### Async Error Handling
For async sending, errors are logged automatically:
```go
email.SendAsync(config, message)
// Errors logged as:
// [WARN] Failed to send email asynchronously recipients=user@example.com error=...
```
## Dependencies
### Required
- **Go 1.25+**
- Standard library: `net/smtp`, `crypto/tls`
- `github.com/pkg/errors` - Error wrapping with context
### No External SMTP Libraries
This plugin uses Go's standard `net/smtp` library for maximum compatibility and minimal dependencies.
## API Reference
### Types
#### `Config`
```go
type Config struct {
SMTPHost string // SMTP server hostname
SMTPPort int // SMTP server port
SMTPUsername string // SMTP auth username
SMTPPassword string // SMTP auth password
FromEmail string // From email address
FromName string // From display name (optional)
UseTLS bool // Enable STARTTLS (port 587)
UseSSL bool // Enable SSL/TLS (port 465)
}
```
#### `Message`
```go
type Message struct {
To []string // Recipients
Cc []string // CC recipients (optional)
Bcc []string // BCC recipients (optional)
Subject string // Email subject
Body string // Email body (plain text or HTML)
IsHTML bool // true for HTML, false for plain text
ReplyTo string // Reply-To address (optional)
}
```
### Functions
#### `Send(config *Config, message *Message) error`
Sends an email synchronously. Blocks until email is sent or error occurs.
#### `SendAsync(config *Config, message *Message)`
Sends an email asynchronously in a goroutine. Returns immediately. Errors are logged.
#### `NewClient(config *Config) *Client`
Creates a new SMTP client for advanced usage.
#### `Client.Send(message *Message) error`
Sends email using the client's configuration.
## Architecture
```
plugin/email/
├── config.go # SMTP configuration types
├── message.go # Email message types and formatting
├── client.go # SMTP client implementation
├── email.go # High-level Send/SendAsync API
├── doc.go # Package documentation
└── *_test.go # Unit tests
```
## License
Part of the Memos project. See main repository for license details.
## Contributing
This plugin follows the Memos contribution guidelines. Please ensure:
1. All code is tested (TDD approach)
2. Tests pass: `go test ./plugin/email/... -v`
3. Code is formatted: `go fmt ./plugin/email/...`
4. No linting errors: `golangci-lint run ./plugin/email/...`
## Support
For issues and questions:
- Memos GitHub Issues: https://github.com/usememos/memos/issues
- Memos Documentation: https://usememos.com/docs
## Roadmap
Future enhancements may include:
- Email template system
- Attachment support
- Inline image embedding
- Email queuing system
- Delivery status tracking
- Bounce handling

143
plugin/email/client.go Normal file
View File

@@ -0,0 +1,143 @@
package email
import (
"crypto/tls"
"net/smtp"
"github.com/pkg/errors"
)
// Client represents an SMTP email client.
type Client struct {
config *Config
}
// NewClient creates a new email client with the given configuration.
func NewClient(config *Config) *Client {
return &Client{
config: config,
}
}
// validateConfig validates the client configuration.
func (c *Client) validateConfig() error {
if c.config == nil {
return errors.New("email configuration is required")
}
return c.config.Validate()
}
// createAuth creates an SMTP auth mechanism if credentials are provided.
func (c *Client) createAuth() smtp.Auth {
if c.config.SMTPUsername == "" && c.config.SMTPPassword == "" {
return nil
}
return smtp.PlainAuth("", c.config.SMTPUsername, c.config.SMTPPassword, c.config.SMTPHost)
}
// createTLSConfig creates a TLS configuration for secure connections.
func (c *Client) createTLSConfig() *tls.Config {
return &tls.Config{
ServerName: c.config.SMTPHost,
MinVersion: tls.VersionTLS12,
}
}
// Send sends an email message via SMTP.
func (c *Client) Send(message *Message) error {
// Validate configuration
if err := c.validateConfig(); err != nil {
return errors.Wrap(err, "invalid email configuration")
}
// Validate message
if message == nil {
return errors.New("message is required")
}
if err := message.Validate(); err != nil {
return errors.Wrap(err, "invalid email message")
}
// Format the message
body := message.Format(c.config.FromEmail, c.config.FromName)
// Get all recipients
recipients := message.GetAllRecipients()
// Create auth
auth := c.createAuth()
// Send based on encryption type
if c.config.UseSSL {
return c.sendWithSSL(auth, recipients, body)
}
return c.sendWithTLS(auth, recipients, body)
}
// sendWithTLS sends email using STARTTLS (port 587).
func (c *Client) sendWithTLS(auth smtp.Auth, recipients []string, body string) error {
serverAddr := c.config.GetServerAddress()
if c.config.UseTLS {
// Use STARTTLS
return smtp.SendMail(serverAddr, auth, c.config.FromEmail, recipients, []byte(body))
}
// Send without encryption (not recommended)
return smtp.SendMail(serverAddr, auth, c.config.FromEmail, recipients, []byte(body))
}
// sendWithSSL sends email using SSL/TLS (port 465).
func (c *Client) sendWithSSL(auth smtp.Auth, recipients []string, body string) error {
serverAddr := c.config.GetServerAddress()
// Create TLS connection
tlsConfig := c.createTLSConfig()
conn, err := tls.Dial("tcp", serverAddr, tlsConfig)
if err != nil {
return errors.Wrapf(err, "failed to connect to SMTP server with SSL: %s", serverAddr)
}
defer conn.Close()
// Create SMTP client
client, err := smtp.NewClient(conn, c.config.SMTPHost)
if err != nil {
return errors.Wrap(err, "failed to create SMTP client")
}
defer client.Quit()
// Authenticate
if auth != nil {
if err := client.Auth(auth); err != nil {
return errors.Wrap(err, "SMTP authentication failed")
}
}
// Set sender
if err := client.Mail(c.config.FromEmail); err != nil {
return errors.Wrap(err, "failed to set sender")
}
// Set recipients
for _, recipient := range recipients {
if err := client.Rcpt(recipient); err != nil {
return errors.Wrapf(err, "failed to set recipient: %s", recipient)
}
}
// Send message body
writer, err := client.Data()
if err != nil {
return errors.Wrap(err, "failed to send DATA command")
}
if _, err := writer.Write([]byte(body)); err != nil {
return errors.Wrap(err, "failed to write message body")
}
if err := writer.Close(); err != nil {
return errors.Wrap(err, "failed to close message writer")
}
return nil
}

121
plugin/email/client_test.go Normal file
View File

@@ -0,0 +1,121 @@
package email
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewClient(t *testing.T) {
config := &Config{
SMTPHost: "smtp.example.com",
SMTPPort: 587,
SMTPUsername: "user@example.com",
SMTPPassword: "password",
FromEmail: "noreply@example.com",
FromName: "Test App",
UseTLS: true,
}
client := NewClient(config)
assert.NotNil(t, client)
assert.Equal(t, config, client.config)
}
func TestClientValidateConfig(t *testing.T) {
tests := []struct {
name string
config *Config
wantErr bool
}{
{
name: "valid config",
config: &Config{
SMTPHost: "smtp.example.com",
SMTPPort: 587,
FromEmail: "test@example.com",
},
wantErr: false,
},
{
name: "nil config",
config: nil,
wantErr: true,
},
{
name: "invalid config",
config: &Config{
SMTPHost: "",
SMTPPort: 587,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := NewClient(tt.config)
err := client.validateConfig()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestClientSendValidation(t *testing.T) {
config := &Config{
SMTPHost: "smtp.example.com",
SMTPPort: 587,
FromEmail: "test@example.com",
}
client := NewClient(config)
tests := []struct {
name string
message *Message
wantErr bool
}{
{
name: "valid message",
message: &Message{
To: []string{"recipient@example.com"},
Subject: "Test",
Body: "Test body",
},
wantErr: false, // Will fail on actual send, but passes validation
},
{
name: "nil message",
message: nil,
wantErr: true,
},
{
name: "invalid message",
message: &Message{
To: []string{},
Subject: "Test",
Body: "Test",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := client.Send(tt.message)
// We expect validation errors for invalid messages
// For valid messages, we'll get connection errors (which is expected in tests)
if tt.wantErr {
assert.Error(t, err)
// Should fail validation before attempting connection
assert.NotContains(t, err.Error(), "dial")
}
// Note: We don't assert NoError for valid messages because
// we don't have a real SMTP server in tests
})
}
}

47
plugin/email/config.go Normal file
View File

@@ -0,0 +1,47 @@
package email
import (
"fmt"
"github.com/pkg/errors"
)
// Config represents the SMTP configuration for email sending.
// These settings should be provided by the self-hosted instance administrator.
type Config struct {
// SMTPHost is the SMTP server hostname (e.g., "smtp.gmail.com")
SMTPHost string
// SMTPPort is the SMTP server port (common: 587 for TLS, 465 for SSL, 25 for unencrypted)
SMTPPort int
// SMTPUsername is the SMTP authentication username (usually the email address)
SMTPUsername string
// SMTPPassword is the SMTP authentication password or app-specific password
SMTPPassword string
// FromEmail is the email address that will appear in the "From" field
FromEmail string
// FromName is the display name that will appear in the "From" field
FromName string
// UseTLS enables STARTTLS encryption (recommended for port 587)
UseTLS bool
// UseSSL enables SSL/TLS encryption (for port 465)
UseSSL bool
}
// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
if c.SMTPHost == "" {
return errors.New("SMTP host is required")
}
if c.SMTPPort <= 0 || c.SMTPPort > 65535 {
return errors.New("SMTP port must be between 1 and 65535")
}
if c.FromEmail == "" {
return errors.New("from email is required")
}
return nil
}
// GetServerAddress returns the SMTP server address in the format "host:port".
func (c *Config) GetServerAddress() string {
return fmt.Sprintf("%s:%d", c.SMTPHost, c.SMTPPort)
}

View File

@@ -0,0 +1,80 @@
package email
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfigValidation(t *testing.T) {
tests := []struct {
name string
config *Config
wantErr bool
}{
{
name: "valid config",
config: &Config{
SMTPHost: "smtp.gmail.com",
SMTPPort: 587,
SMTPUsername: "user@example.com",
SMTPPassword: "password",
FromEmail: "noreply@example.com",
FromName: "Memos",
},
wantErr: false,
},
{
name: "missing host",
config: &Config{
SMTPPort: 587,
SMTPUsername: "user@example.com",
SMTPPassword: "password",
FromEmail: "noreply@example.com",
},
wantErr: true,
},
{
name: "invalid port",
config: &Config{
SMTPHost: "smtp.gmail.com",
SMTPPort: 0,
SMTPUsername: "user@example.com",
SMTPPassword: "password",
FromEmail: "noreply@example.com",
},
wantErr: true,
},
{
name: "missing from email",
config: &Config{
SMTPHost: "smtp.gmail.com",
SMTPPort: 587,
SMTPUsername: "user@example.com",
SMTPPassword: "password",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestConfigGetServerAddress(t *testing.T) {
config := &Config{
SMTPHost: "smtp.gmail.com",
SMTPPort: 587,
}
expected := "smtp.gmail.com:587"
assert.Equal(t, expected, config.GetServerAddress())
}

98
plugin/email/doc.go Normal file
View File

@@ -0,0 +1,98 @@
// Package email provides SMTP email sending functionality for self-hosted Memos instances.
//
// This package is designed for self-hosted environments where instance administrators
// configure their own SMTP servers. It follows industry-standard patterns used by
// platforms like GitHub, GitLab, and Discourse.
//
// # Configuration
//
// The package requires SMTP server configuration provided by the instance administrator:
//
// config := &email.Config{
// SMTPHost: "smtp.gmail.com",
// SMTPPort: 587,
// SMTPUsername: "your-email@gmail.com",
// SMTPPassword: "your-app-password",
// FromEmail: "noreply@yourdomain.com",
// FromName: "Memos Notifications",
// UseTLS: true,
// }
//
// # Common SMTP Settings
//
// Gmail (requires App Password):
// - Host: smtp.gmail.com
// - Port: 587 (TLS) or 465 (SSL)
// - Username: your-email@gmail.com
// - UseTLS: true (for port 587) or UseSSL: true (for port 465)
//
// SendGrid:
// - Host: smtp.sendgrid.net
// - Port: 587
// - Username: apikey
// - Password: your-sendgrid-api-key
// - UseTLS: true
//
// AWS SES:
// - Host: email-smtp.[region].amazonaws.com
// - Port: 587
// - Username: your-smtp-username
// - Password: your-smtp-password
// - UseTLS: true
//
// Mailgun:
// - Host: smtp.mailgun.org
// - Port: 587
// - Username: your-mailgun-smtp-username
// - Password: your-mailgun-smtp-password
// - UseTLS: true
//
// # Sending Email
//
// Synchronous (waits for completion):
//
// message := &email.Message{
// To: []string{"user@example.com"},
// Subject: "Welcome to Memos",
// Body: "Thank you for joining!",
// IsHTML: false,
// }
//
// err := email.Send(config, message)
// if err != nil {
// // Handle error
// }
//
// Asynchronous (returns immediately):
//
// email.SendAsync(config, message)
// // Errors are logged but not returned
//
// # HTML Email
//
// message := &email.Message{
// To: []string{"user@example.com"},
// Subject: "Welcome!",
// Body: "<html><body><h1>Welcome to Memos!</h1></body></html>",
// IsHTML: true,
// }
//
// # Security Considerations
//
// - Always use TLS (port 587) or SSL (port 465) for production
// - Store SMTP credentials securely (environment variables or secrets management)
// - Use app-specific passwords for services like Gmail
// - Validate and sanitize email content to prevent injection attacks
// - Rate limit email sending to prevent abuse
//
// # Error Handling
//
// The package returns descriptive errors for common issues:
// - Configuration validation errors (missing host, invalid port, etc.)
// - Message validation errors (missing recipients, subject, or body)
// - Connection errors (cannot reach SMTP server)
// - Authentication errors (invalid credentials)
// - SMTP protocol errors (recipient rejected, etc.)
//
// All errors are wrapped with context using github.com/pkg/errors for better debugging.
package email

43
plugin/email/email.go Normal file
View File

@@ -0,0 +1,43 @@
package email
import (
"log/slog"
"github.com/pkg/errors"
)
// Send sends an email synchronously.
// Returns an error if the email fails to send.
func Send(config *Config, message *Message) error {
if config == nil {
return errors.New("email configuration is required")
}
if message == nil {
return errors.New("email message is required")
}
client := NewClient(config)
return client.Send(message)
}
// SendAsync sends an email asynchronously.
// It spawns a new goroutine to handle the sending and does not wait for the response.
// Any errors are logged but not returned.
func SendAsync(config *Config, message *Message) {
go func() {
if err := Send(config, message); err != nil {
// Since we're in a goroutine, we can only log the error
recipients := ""
if message != nil && len(message.To) > 0 {
recipients = message.To[0]
if len(message.To) > 1 {
recipients += " and others"
}
}
slog.Warn("Failed to send email asynchronously",
slog.String("recipients", recipients),
slog.Any("error", err))
}
}()
}

127
plugin/email/email_test.go Normal file
View File

@@ -0,0 +1,127 @@
package email
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)
func TestSend(t *testing.T) {
config := &Config{
SMTPHost: "smtp.example.com",
SMTPPort: 587,
FromEmail: "test@example.com",
}
message := &Message{
To: []string{"recipient@example.com"},
Subject: "Test",
Body: "Test body",
}
// This will fail to connect (no real server), but should validate inputs
err := Send(config, message)
// We expect an error because there's no real SMTP server
// But it should be a connection error, not a validation error
assert.Error(t, err)
assert.Contains(t, err.Error(), "dial")
}
func TestSendValidation(t *testing.T) {
tests := []struct {
name string
config *Config
message *Message
wantErr bool
errMsg string
}{
{
name: "nil config",
config: nil,
message: &Message{To: []string{"test@example.com"}, Subject: "Test", Body: "Test"},
wantErr: true,
errMsg: "configuration is required",
},
{
name: "nil message",
config: &Config{SMTPHost: "smtp.example.com", SMTPPort: 587, FromEmail: "from@example.com"},
message: nil,
wantErr: true,
errMsg: "message is required",
},
{
name: "invalid config",
config: &Config{
SMTPHost: "",
SMTPPort: 587,
},
message: &Message{To: []string{"test@example.com"}, Subject: "Test", Body: "Test"},
wantErr: true,
errMsg: "invalid email configuration",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := Send(tt.config, tt.message)
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
}
})
}
}
func TestSendAsync(t *testing.T) {
config := &Config{
SMTPHost: "smtp.example.com",
SMTPPort: 587,
FromEmail: "test@example.com",
}
message := &Message{
To: []string{"recipient@example.com"},
Subject: "Test Async",
Body: "Test async body",
}
// SendAsync should not block
start := time.Now()
SendAsync(config, message)
duration := time.Since(start)
// Should return almost immediately (< 100ms)
assert.Less(t, duration, 100*time.Millisecond)
// Give goroutine time to start
time.Sleep(50 * time.Millisecond)
}
func TestSendAsyncConcurrent(t *testing.T) {
config := &Config{
SMTPHost: "smtp.example.com",
SMTPPort: 587,
FromEmail: "test@example.com",
}
g := errgroup.Group{}
count := 5
for i := 0; i < count; i++ {
g.Go(func() error {
message := &Message{
To: []string{"recipient@example.com"},
Subject: "Concurrent Test",
Body: "Test body",
}
SendAsync(config, message)
return nil
})
}
if err := g.Wait(); err != nil {
t.Fatalf("SendAsync calls failed: %v", err)
}
}

91
plugin/email/message.go Normal file
View File

@@ -0,0 +1,91 @@
package email
import (
"errors"
"fmt"
"strings"
"time"
)
// Message represents an email message to be sent.
type Message struct {
To []string // Required: recipient email addresses
Cc []string // Optional: carbon copy recipients
Bcc []string // Optional: blind carbon copy recipients
Subject string // Required: email subject
Body string // Required: email body content
IsHTML bool // Whether the body is HTML (default: false for plain text)
ReplyTo string // Optional: reply-to address
}
// Validate checks that the message has all required fields.
func (m *Message) Validate() error {
if len(m.To) == 0 {
return errors.New("at least one recipient is required")
}
if m.Subject == "" {
return errors.New("subject is required")
}
if m.Body == "" {
return errors.New("body is required")
}
return nil
}
// Format creates an RFC 5322 formatted email message.
func (m *Message) Format(fromEmail, fromName string) string {
var sb strings.Builder
// From header
if fromName != "" {
sb.WriteString(fmt.Sprintf("From: %s <%s>\r\n", fromName, fromEmail))
} else {
sb.WriteString(fmt.Sprintf("From: %s\r\n", fromEmail))
}
// To header
sb.WriteString(fmt.Sprintf("To: %s\r\n", strings.Join(m.To, ", ")))
// Cc header (optional)
if len(m.Cc) > 0 {
sb.WriteString(fmt.Sprintf("Cc: %s\r\n", strings.Join(m.Cc, ", ")))
}
// Reply-To header (optional)
if m.ReplyTo != "" {
sb.WriteString(fmt.Sprintf("Reply-To: %s\r\n", m.ReplyTo))
}
// Subject header
sb.WriteString(fmt.Sprintf("Subject: %s\r\n", m.Subject))
// Date header (RFC 5322 format)
sb.WriteString(fmt.Sprintf("Date: %s\r\n", time.Now().Format(time.RFC1123Z)))
// MIME headers
sb.WriteString("MIME-Version: 1.0\r\n")
// Content-Type header
if m.IsHTML {
sb.WriteString("Content-Type: text/html; charset=utf-8\r\n")
} else {
sb.WriteString("Content-Type: text/plain; charset=utf-8\r\n")
}
// Empty line separating headers from body
sb.WriteString("\r\n")
// Body
sb.WriteString(m.Body)
return sb.String()
}
// GetAllRecipients returns all recipients (To, Cc, Bcc) as a single slice.
func (m *Message) GetAllRecipients() []string {
var recipients []string
recipients = append(recipients, m.To...)
recipients = append(recipients, m.Cc...)
recipients = append(recipients, m.Bcc...)
return recipients
}

View File

@@ -0,0 +1,181 @@
package email
import (
"strings"
"testing"
)
func TestMessageValidation(t *testing.T) {
tests := []struct {
name string
msg Message
wantErr bool
}{
{
name: "valid message",
msg: Message{
To: []string{"user@example.com"},
Subject: "Test Subject",
Body: "Test Body",
},
wantErr: false,
},
{
name: "no recipients",
msg: Message{
To: []string{},
Subject: "Test Subject",
Body: "Test Body",
},
wantErr: true,
},
{
name: "no subject",
msg: Message{
To: []string{"user@example.com"},
Subject: "",
Body: "Test Body",
},
wantErr: true,
},
{
name: "no body",
msg: Message{
To: []string{"user@example.com"},
Subject: "Test Subject",
Body: "",
},
wantErr: true,
},
{
name: "multiple recipients",
msg: Message{
To: []string{"user1@example.com", "user2@example.com"},
Cc: []string{"cc@example.com"},
Subject: "Test Subject",
Body: "Test Body",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.msg.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestMessageFormatPlainText(t *testing.T) {
msg := Message{
To: []string{"user@example.com"},
Subject: "Test Subject",
Body: "Test Body",
IsHTML: false,
}
formatted := msg.Format("sender@example.com", "Sender Name")
// Check required headers
if !strings.Contains(formatted, "From: Sender Name <sender@example.com>") {
t.Error("Missing or incorrect From header")
}
if !strings.Contains(formatted, "To: user@example.com") {
t.Error("Missing or incorrect To header")
}
if !strings.Contains(formatted, "Subject: Test Subject") {
t.Error("Missing or incorrect Subject header")
}
if !strings.Contains(formatted, "Content-Type: text/plain; charset=utf-8") {
t.Error("Missing or incorrect Content-Type header for plain text")
}
if !strings.Contains(formatted, "Test Body") {
t.Error("Missing message body")
}
}
func TestMessageFormatHTML(t *testing.T) {
msg := Message{
To: []string{"user@example.com"},
Subject: "Test Subject",
Body: "<html><body>Test Body</body></html>",
IsHTML: true,
}
formatted := msg.Format("sender@example.com", "Sender Name")
// Check HTML content-type
if !strings.Contains(formatted, "Content-Type: text/html; charset=utf-8") {
t.Error("Missing or incorrect Content-Type header for HTML")
}
if !strings.Contains(formatted, "<html><body>Test Body</body></html>") {
t.Error("Missing HTML body")
}
}
func TestMessageFormatMultipleRecipients(t *testing.T) {
msg := Message{
To: []string{"user1@example.com", "user2@example.com"},
Cc: []string{"cc1@example.com", "cc2@example.com"},
Bcc: []string{"bcc@example.com"},
Subject: "Test Subject",
Body: "Test Body",
ReplyTo: "reply@example.com",
}
formatted := msg.Format("sender@example.com", "Sender Name")
// Check To header formatting
if !strings.Contains(formatted, "To: user1@example.com, user2@example.com") {
t.Error("Missing or incorrect To header with multiple recipients")
}
// Check Cc header formatting
if !strings.Contains(formatted, "Cc: cc1@example.com, cc2@example.com") {
t.Error("Missing or incorrect Cc header")
}
// Bcc should NOT appear in the formatted message
if strings.Contains(formatted, "Bcc:") {
t.Error("Bcc header should not appear in formatted message")
}
// Check Reply-To header
if !strings.Contains(formatted, "Reply-To: reply@example.com") {
t.Error("Missing or incorrect Reply-To header")
}
}
func TestGetAllRecipients(t *testing.T) {
msg := Message{
To: []string{"user1@example.com", "user2@example.com"},
Cc: []string{"cc@example.com"},
Bcc: []string{"bcc@example.com"},
}
recipients := msg.GetAllRecipients()
// Should have all 4 recipients
if len(recipients) != 4 {
t.Errorf("GetAllRecipients() returned %d recipients, want 4", len(recipients))
}
// Check all recipients are present
expectedRecipients := map[string]bool{
"user1@example.com": true,
"user2@example.com": true,
"cc@example.com": true,
"bcc@example.com": true,
}
for _, recipient := range recipients {
if !expectedRecipients[recipient] {
t.Errorf("Unexpected recipient: %s", recipient)
}
delete(expectedRecipients, recipient)
}
if len(expectedRecipients) > 0 {
t.Error("Not all expected recipients were returned")
}
}

View File

@@ -0,0 +1,50 @@
# Maintaining the Memo Filter Engine
The engine is memo-specific; any future field or behavior changes must stay
consistent with the memo schema and store implementations. Use this guide when
extending or debugging the package.
## Adding a New Memo Field
1. **Update the schema**
- Add the field entry in `schema.go`.
- Define the backing column (`Column`), JSON path (if applicable), type, and
allowed operators.
- Include the CEL variable in `EnvOptions`.
2. **Adjust parser or renderer (if needed)**
- For non-scalar fields (JSON booleans, lists), add handling in
`parser.go` or extend the renderer helpers.
- Keep validation in the parser (e.g., reject unsupported operators).
3. **Write a golden test**
- Extend the dialect-specific memo filter tests under
`store/db/{sqlite,mysql,postgres}/memo_filter_test.go` with a case that
exercises the new field.
4. **Run `go test ./...`** to ensure the SQL output matches expectations across
all dialects.
## Supporting Dialect Nuances
- Centralize differences inside `render.go`. If a new dialect-specific behavior
emerges (e.g., JSON operators), add the logic there rather than leaking it
into store code.
- Use the renderer helpers (`jsonExtractExpr`, `jsonArrayExpr`, etc.) rather than
sprinkling ad-hoc SQL strings.
- When placeholders change, adjust `addArg` so that argument numbering stays in
sync with store queries.
## Debugging Tips
- **Parser errors** Most originate in `buildCondition` or schema validation.
Enable logging around `parser.go` when diagnosing unknown identifier/operator
messages.
- **Renderer output** Temporary printf/log statements in `renderCondition` help
identify which IR node produced unexpected SQL.
- **Store integration** Ensure drivers call `filter.DefaultEngine()` exactly once
per process; the singleton caches the parsed CEL environment.
## Testing Checklist
- `go test ./store/...` ensures all dialect tests consume the engine correctly.
- Add targeted unit tests whenever new IR nodes or renderer paths are introduced.
- When changing boolean or JSON handling, verify all three dialect test suites
(SQLite, MySQL, Postgres) to avoid regression.

63
plugin/filter/README.md Normal file
View File

@@ -0,0 +1,63 @@
# Memo Filter Engine
This package houses the memo-only filter engine that turns CEL expressions into
SQL fragments. The engine follows a three phase pipeline inspired by systems
such as Calcite or Prisma:
1. **Parsing** CEL expressions are parsed with `cel-go` and validated against
the memo-specific environment declared in `schema.go`. Only fields that
exist in the schema can surface in the filter.
2. **Normalization** the raw CEL AST is converted into an intermediate
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
conditions (logical operators, comparisons, list membership, etc.). This
step enforces schema rules (e.g. operator compatibility, type checks).
3. **Rendering** the renderer in `render.go` walks the IR and produces a SQL
fragment plus placeholder arguments tailored to a target dialect
(`sqlite`, `mysql`, or `postgres`). Dialect differences such as JSON access,
boolean semantics, placeholders, and `LIKE` vs `ILIKE` are encapsulated in
renderer helpers.
The entry point is `filter.DefaultEngine()` from `engine.go`. It lazily constructs
an `Engine` configured with the memo schema and exposes:
```go
engine, _ := filter.DefaultEngine()
stmt, _ := engine.CompileToStatement(ctx, `has_task_list && visibility == "PUBLIC"`, filter.RenderOptions{
Dialect: filter.DialectPostgres,
})
// stmt.SQL -> "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.visibility = $1)"
// stmt.Args -> ["PUBLIC"]
```
## Core Files
| File | Responsibility |
| ------------- | ------------------------------------------------------------------------------- |
| `schema.go` | Declares memo fields, their types, backing columns, CEL environment options |
| `ir.go` | IR node definitions used across the pipeline |
| `parser.go` | Converts CEL `Expr` into IR while applying schema validation |
| `render.go` | Translates IR into SQL, handling dialect-specific behavior |
| `engine.go` | Glue between the phases; exposes `Compile`, `CompileToStatement`, and `DefaultEngine` |
| `helpers.go` | Convenience helpers for store integration (appending conditions) |
## SQL Generation Notes
- **Placeholders** — `?` is used for SQLite/MySQL, `$n` for Postgres. The renderer
tracks offsets to compose queries with pre-existing arguments.
- **JSON Fields** — Memo metadata lives in `memo.payload`. The renderer handles
`JSON_EXTRACT`/`json_extract`/`->`/`->>` variations and boolean coercion.
- **Tag Operations** — `tag in [...]` and `"tag" in tags` become JSON array
predicates. SQLite uses `LIKE` patterns, MySQL uses `JSON_CONTAINS`, and
Postgres uses `@>`.
- **Boolean Flags** — Fields such as `has_task_list` render as `IS TRUE` equality
checks, or comparisons against `CAST('true' AS JSON)` depending on the dialect.
## Typical Integration
1. Fetch the engine with `filter.DefaultEngine()`.
2. Call `CompileToStatement` using the appropriate dialect enum.
3. Append the emitted SQL fragment/args to the existing `WHERE` clause.
4. Execute the resulting query through the store driver.
The `helpers.AppendConditions` helper encapsulates steps 23 when a driver needs
to process an array of filters.

191
plugin/filter/engine.go Normal file
View File

@@ -0,0 +1,191 @@
package filter
import (
"context"
"fmt"
"strings"
"sync"
"github.com/google/cel-go/cel"
"github.com/pkg/errors"
)
// Engine parses CEL filters into a dialect-agnostic condition tree.
type Engine struct {
schema Schema
env *cel.Env
}
// NewEngine builds a new Engine for the provided schema.
func NewEngine(schema Schema) (*Engine, error) {
env, err := cel.NewEnv(schema.EnvOptions...)
if err != nil {
return nil, errors.Wrap(err, "failed to create CEL environment")
}
return &Engine{
schema: schema,
env: env,
}, nil
}
// Program stores a compiled filter condition.
type Program struct {
schema Schema
condition Condition
}
// ConditionTree exposes the underlying condition tree.
func (p *Program) ConditionTree() Condition {
return p.condition
}
// Compile parses the filter string into an executable program.
func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
if strings.TrimSpace(filter) == "" {
return nil, errors.New("filter expression is empty")
}
filter = normalizeLegacyFilter(filter)
ast, issues := e.env.Compile(filter)
if issues != nil && issues.Err() != nil {
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
}
parsed, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, errors.Wrap(err, "failed to convert AST")
}
cond, err := buildCondition(parsed.GetExpr(), e.schema)
if err != nil {
return nil, err
}
return &Program{
schema: e.schema,
condition: cond,
}, nil
}
// CompileToStatement compiles and renders the filter in a single step.
func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) {
program, err := e.Compile(ctx, filter)
if err != nil {
return Statement{}, err
}
return program.Render(opts)
}
// RenderOptions configure SQL rendering.
type RenderOptions struct {
Dialect DialectName
PlaceholderOffset int
DisableNullChecks bool
}
// Statement contains the rendered SQL fragment and its args.
type Statement struct {
SQL string
Args []any
}
// Render converts the program into a dialect-specific SQL fragment.
func (p *Program) Render(opts RenderOptions) (Statement, error) {
renderer := newRenderer(p.schema, opts)
return renderer.Render(p.condition)
}
var (
defaultOnce sync.Once
defaultInst *Engine
defaultErr error
defaultAttachmentOnce sync.Once
defaultAttachmentInst *Engine
defaultAttachmentErr error
)
// DefaultEngine returns the process-wide memo filter engine.
func DefaultEngine() (*Engine, error) {
defaultOnce.Do(func() {
defaultInst, defaultErr = NewEngine(NewSchema())
})
return defaultInst, defaultErr
}
// DefaultAttachmentEngine returns the process-wide attachment filter engine.
func DefaultAttachmentEngine() (*Engine, error) {
defaultAttachmentOnce.Do(func() {
defaultAttachmentInst, defaultAttachmentErr = NewEngine(NewAttachmentSchema())
})
return defaultAttachmentInst, defaultAttachmentErr
}
func normalizeLegacyFilter(expr string) string {
expr = rewriteNumericLogicalOperand(expr, "&&")
expr = rewriteNumericLogicalOperand(expr, "||")
return expr
}
func rewriteNumericLogicalOperand(expr, op string) string {
var builder strings.Builder
n := len(expr)
i := 0
var inQuote rune
for i < n {
ch := expr[i]
if inQuote != 0 {
builder.WriteByte(ch)
if ch == '\\' && i+1 < n {
builder.WriteByte(expr[i+1])
i += 2
continue
}
if ch == byte(inQuote) {
inQuote = 0
}
i++
continue
}
if ch == '\'' || ch == '"' {
inQuote = rune(ch)
builder.WriteByte(ch)
i++
continue
}
if strings.HasPrefix(expr[i:], op) {
builder.WriteString(op)
i += len(op)
// Preserve whitespace following the operator.
wsStart := i
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
i++
}
builder.WriteString(expr[wsStart:i])
signStart := i
if i < n && (expr[i] == '+' || expr[i] == '-') {
i++
}
for i < n && expr[i] >= '0' && expr[i] <= '9' {
i++
}
if i > signStart {
numLiteral := expr[signStart:i]
builder.WriteString(fmt.Sprintf("(%s != 0)", numLiteral))
} else {
builder.WriteString(expr[signStart:i])
}
continue
}
builder.WriteByte(ch)
i++
}
return builder.String()
}

25
plugin/filter/helpers.go Normal file
View File

@@ -0,0 +1,25 @@
package filter
import (
"context"
"fmt"
)
// AppendConditions compiles the provided filters and appends the resulting SQL fragments and args.
func AppendConditions(ctx context.Context, engine *Engine, filters []string, dialect DialectName, where *[]string, args *[]any) error {
for _, filterStr := range filters {
stmt, err := engine.CompileToStatement(ctx, filterStr, RenderOptions{
Dialect: dialect,
PlaceholderOffset: len(*args),
})
if err != nil {
return err
}
if stmt.SQL == "" {
continue
}
*where = append(*where, fmt.Sprintf("(%s)", stmt.SQL))
*args = append(*args, stmt.Args...)
}
return nil
}

159
plugin/filter/ir.go Normal file
View File

@@ -0,0 +1,159 @@
package filter
// Condition represents a boolean expression derived from the CEL filter.
type Condition interface {
isCondition()
}
// LogicalOperator enumerates the supported logical operators.
type LogicalOperator string
const (
LogicalAnd LogicalOperator = "AND"
LogicalOr LogicalOperator = "OR"
)
// LogicalCondition composes two conditions with a logical operator.
type LogicalCondition struct {
Operator LogicalOperator
Left Condition
Right Condition
}
func (*LogicalCondition) isCondition() {}
// NotCondition negates a child condition.
type NotCondition struct {
Expr Condition
}
func (*NotCondition) isCondition() {}
// FieldPredicateCondition asserts that a field evaluates to true.
type FieldPredicateCondition struct {
Field string
}
func (*FieldPredicateCondition) isCondition() {}
// ComparisonOperator lists supported comparison operators.
type ComparisonOperator string
const (
CompareEq ComparisonOperator = "="
CompareNeq ComparisonOperator = "!="
CompareLt ComparisonOperator = "<"
CompareLte ComparisonOperator = "<="
CompareGt ComparisonOperator = ">"
CompareGte ComparisonOperator = ">="
)
// ComparisonCondition represents a binary comparison.
type ComparisonCondition struct {
Left ValueExpr
Operator ComparisonOperator
Right ValueExpr
}
func (*ComparisonCondition) isCondition() {}
// InCondition represents an IN predicate with literal list values.
type InCondition struct {
Left ValueExpr
Values []ValueExpr
}
func (*InCondition) isCondition() {}
// ElementInCondition represents the CEL syntax `"value" in field`.
type ElementInCondition struct {
Element ValueExpr
Field string
}
func (*ElementInCondition) isCondition() {}
// ContainsCondition models the <field>.contains(<value>) call.
type ContainsCondition struct {
Field string
Value string
}
func (*ContainsCondition) isCondition() {}
// ConstantCondition captures a literal boolean outcome.
type ConstantCondition struct {
Value bool
}
func (*ConstantCondition) isCondition() {}
// ValueExpr models arithmetic or scalar expressions whose result feeds a comparison.
type ValueExpr interface {
isValueExpr()
}
// FieldRef references a named schema field.
type FieldRef struct {
Name string
}
func (*FieldRef) isValueExpr() {}
// LiteralValue holds a literal scalar.
type LiteralValue struct {
Value interface{}
}
func (*LiteralValue) isValueExpr() {}
// FunctionValue captures simple function calls like size(tags).
type FunctionValue struct {
Name string
Args []ValueExpr
}
func (*FunctionValue) isValueExpr() {}
// ListComprehensionCondition represents CEL macros like exists(), all(), filter().
type ListComprehensionCondition struct {
Kind ComprehensionKind
Field string // The list field to iterate over (e.g., "tags")
IterVar string // The iteration variable name (e.g., "t")
Predicate PredicateExpr // The predicate to evaluate for each element
}
func (*ListComprehensionCondition) isCondition() {}
// ComprehensionKind enumerates the types of list comprehensions.
type ComprehensionKind string
const (
ComprehensionExists ComprehensionKind = "exists"
)
// PredicateExpr represents predicates used in comprehensions.
type PredicateExpr interface {
isPredicateExpr()
}
// StartsWithPredicate represents t.startsWith("prefix").
type StartsWithPredicate struct {
Prefix string
}
func (*StartsWithPredicate) isPredicateExpr() {}
// EndsWithPredicate represents t.endsWith("suffix").
type EndsWithPredicate struct {
Suffix string
}
func (*EndsWithPredicate) isPredicateExpr() {}
// ContainsPredicate represents t.contains("substring").
type ContainsPredicate struct {
Substring string
}
func (*ContainsPredicate) isPredicateExpr() {}

586
plugin/filter/parser.go Normal file
View File

@@ -0,0 +1,586 @@
package filter
import (
"time"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
switch v := expr.ExprKind.(type) {
case *exprv1.Expr_CallExpr:
return buildCallCondition(v.CallExpr, schema)
case *exprv1.Expr_ConstExpr:
val, err := getConstValue(expr)
if err != nil {
return nil, err
}
switch v := val.(type) {
case bool:
return &ConstantCondition{Value: v}, nil
case int64:
return &ConstantCondition{Value: v != 0}, nil
case float64:
return &ConstantCondition{Value: v != 0}, nil
default:
return nil, errors.New("filter must evaluate to a boolean value")
}
case *exprv1.Expr_IdentExpr:
name := v.IdentExpr.GetName()
field, ok := schema.Field(name)
if !ok {
return nil, errors.Errorf("unknown identifier %q", name)
}
if field.Type != FieldTypeBool {
return nil, errors.Errorf("identifier %q is not boolean", name)
}
return &FieldPredicateCondition{Field: name}, nil
case *exprv1.Expr_ComprehensionExpr:
return buildComprehensionCondition(v.ComprehensionExpr, schema)
default:
return nil, errors.New("unsupported top-level expression")
}
}
func buildCallCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
switch call.Function {
case "_&&_":
if len(call.Args) != 2 {
return nil, errors.New("logical AND expects two arguments")
}
left, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildCondition(call.Args[1], schema)
if err != nil {
return nil, err
}
return &LogicalCondition{
Operator: LogicalAnd,
Left: left,
Right: right,
}, nil
case "_||_":
if len(call.Args) != 2 {
return nil, errors.New("logical OR expects two arguments")
}
left, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildCondition(call.Args[1], schema)
if err != nil {
return nil, err
}
return &LogicalCondition{
Operator: LogicalOr,
Left: left,
Right: right,
}, nil
case "!_":
if len(call.Args) != 1 {
return nil, errors.New("logical NOT expects one argument")
}
child, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
return &NotCondition{Expr: child}, nil
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
return buildComparisonCondition(call, schema)
case "@in":
return buildInCondition(call, schema)
case "contains":
return buildContainsCondition(call, schema)
default:
val, ok, err := evaluateBool(call)
if err != nil {
return nil, err
}
if ok {
return &ConstantCondition{Value: val}, nil
}
return nil, errors.Errorf("unsupported call expression %q", call.Function)
}
}
func buildComparisonCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if len(call.Args) != 2 {
return nil, errors.New("comparison expects two arguments")
}
op, err := toComparisonOperator(call.Function)
if err != nil {
return nil, err
}
left, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildValueExpr(call.Args[1], schema)
if err != nil {
return nil, err
}
// If the left side is a field, validate allowed operators.
if field, ok := left.(*FieldRef); ok {
def, exists := schema.Field(field.Name)
if !exists {
return nil, errors.Errorf("unknown identifier %q", field.Name)
}
if def.Kind == FieldKindVirtualAlias {
def, exists = schema.ResolveAlias(field.Name)
if !exists {
return nil, errors.Errorf("invalid alias %q", field.Name)
}
}
if def.AllowedComparisonOps != nil {
if _, allowed := def.AllowedComparisonOps[op]; !allowed {
return nil, errors.Errorf("operator %s not allowed for field %q", op, field.Name)
}
}
}
return &ComparisonCondition{
Left: left,
Operator: op,
Right: right,
}, nil
}
func buildInCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if len(call.Args) != 2 {
return nil, errors.New("in operator expects two arguments")
}
// Handle identifier in list syntax.
if identName, err := getIdentName(call.Args[0]); err == nil {
if field, ok := schema.Field(identName); ok && field.Kind == FieldKindVirtualAlias {
if _, aliasOk := schema.ResolveAlias(identName); !aliasOk {
return nil, errors.Errorf("invalid alias %q", identName)
}
} else if !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
if listExpr := call.Args[1].GetListExpr(); listExpr != nil {
values := make([]ValueExpr, 0, len(listExpr.Elements))
for _, element := range listExpr.Elements {
value, err := buildValueExpr(element, schema)
if err != nil {
return nil, err
}
values = append(values, value)
}
return &InCondition{
Left: &FieldRef{Name: identName},
Values: values,
}, nil
}
}
// Handle "value in identifier" syntax.
if identName, err := getIdentName(call.Args[1]); err == nil {
if _, ok := schema.Field(identName); !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
element, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
return &ElementInCondition{
Element: element,
Field: identName,
}, nil
}
return nil, errors.New("invalid use of in operator")
}
func buildContainsCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if call.Target == nil {
return nil, errors.New("contains requires a target")
}
targetName, err := getIdentName(call.Target)
if err != nil {
return nil, err
}
field, ok := schema.Field(targetName)
if !ok {
return nil, errors.Errorf("unknown identifier %q", targetName)
}
if !field.SupportsContains {
return nil, errors.Errorf("identifier %q does not support contains()", targetName)
}
if len(call.Args) != 1 {
return nil, errors.New("contains expects exactly one argument")
}
value, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "contains only supports literal arguments")
}
str, ok := value.(string)
if !ok {
return nil, errors.New("contains argument must be a string")
}
return &ContainsCondition{
Field: targetName,
Value: str,
}, nil
}
func buildValueExpr(expr *exprv1.Expr, schema Schema) (ValueExpr, error) {
if identName, err := getIdentName(expr); err == nil {
if _, ok := schema.Field(identName); !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
return &FieldRef{Name: identName}, nil
}
if literal, err := getConstValue(expr); err == nil {
return &LiteralValue{Value: literal}, nil
}
if value, ok, err := evaluateNumeric(expr); err != nil {
return nil, err
} else if ok {
return &LiteralValue{Value: value}, nil
}
if boolVal, ok, err := evaluateBoolExpr(expr); err != nil {
return nil, err
} else if ok {
return &LiteralValue{Value: boolVal}, nil
}
if call := expr.GetCallExpr(); call != nil {
switch call.Function {
case "size":
if len(call.Args) != 1 {
return nil, errors.New("size() expects one argument")
}
arg, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
return &FunctionValue{
Name: "size",
Args: []ValueExpr{arg},
}, nil
case "now":
return &LiteralValue{Value: timeNowUnix()}, nil
case "_+_", "_-_", "_*_":
value, ok, err := evaluateNumeric(expr)
if err != nil {
return nil, err
}
if ok {
return &LiteralValue{Value: value}, nil
}
default:
// Fall through to error return below
}
}
return nil, errors.New("unsupported value expression")
}
func toComparisonOperator(fn string) (ComparisonOperator, error) {
switch fn {
case "_==_":
return CompareEq, nil
case "_!=_":
return CompareNeq, nil
case "_<_":
return CompareLt, nil
case "_>_":
return CompareGt, nil
case "_<=_":
return CompareLte, nil
case "_>=_":
return CompareGte, nil
default:
return "", errors.Errorf("unsupported comparison operator %q", fn)
}
}
func getIdentName(expr *exprv1.Expr) (string, error) {
if ident := expr.GetIdentExpr(); ident != nil {
return ident.GetName(), nil
}
return "", errors.New("expression is not an identifier")
}
func getConstValue(expr *exprv1.Expr) (interface{}, error) {
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
if !ok {
return nil, errors.New("expression is not a literal")
}
switch x := v.ConstExpr.ConstantKind.(type) {
case *exprv1.Constant_StringValue:
return v.ConstExpr.GetStringValue(), nil
case *exprv1.Constant_Int64Value:
return v.ConstExpr.GetInt64Value(), nil
case *exprv1.Constant_Uint64Value:
return int64(v.ConstExpr.GetUint64Value()), nil
case *exprv1.Constant_DoubleValue:
return v.ConstExpr.GetDoubleValue(), nil
case *exprv1.Constant_BoolValue:
return v.ConstExpr.GetBoolValue(), nil
case *exprv1.Constant_NullValue:
return nil, nil
default:
return nil, errors.Errorf("unsupported constant %T", x)
}
}
func evaluateBool(call *exprv1.Expr_Call) (bool, bool, error) {
val, ok, err := evaluateBoolExpr(&exprv1.Expr{ExprKind: &exprv1.Expr_CallExpr{CallExpr: call}})
return val, ok, err
}
func evaluateBoolExpr(expr *exprv1.Expr) (bool, bool, error) {
if literal, err := getConstValue(expr); err == nil {
if b, ok := literal.(bool); ok {
return b, true, nil
}
return false, false, nil
}
if call := expr.GetCallExpr(); call != nil && call.Function == "!_" {
if len(call.Args) != 1 {
return false, false, errors.New("NOT expects exactly one argument")
}
val, ok, err := evaluateBoolExpr(call.Args[0])
if err != nil || !ok {
return false, false, err
}
return !val, true, nil
}
return false, false, nil
}
func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
if literal, err := getConstValue(expr); err == nil {
switch v := literal.(type) {
case int64:
return v, true, nil
case float64:
return int64(v), true, nil
}
return 0, false, nil
}
call := expr.GetCallExpr()
if call == nil {
return 0, false, nil
}
switch call.Function {
case "now":
return timeNowUnix(), true, nil
case "_+_", "_-_", "_*_":
if len(call.Args) != 2 {
return 0, false, errors.New("arithmetic requires two arguments")
}
left, ok, err := evaluateNumeric(call.Args[0])
if err != nil {
return 0, false, err
}
if !ok {
return 0, false, nil
}
right, ok, err := evaluateNumeric(call.Args[1])
if err != nil {
return 0, false, err
}
if !ok {
return 0, false, nil
}
switch call.Function {
case "_+_":
return left + right, true, nil
case "_-_":
return left - right, true, nil
case "_*_":
return left * right, true, nil
default:
return 0, false, errors.Errorf("unsupported arithmetic operator %q", call.Function)
}
default:
return 0, false, nil
}
}
func timeNowUnix() int64 {
return time.Now().Unix()
}
// buildComprehensionCondition handles CEL comprehension expressions (exists, all, etc.).
func buildComprehensionCondition(comp *exprv1.Expr_Comprehension, schema Schema) (Condition, error) {
// Determine the comprehension kind by examining the loop initialization and step
kind, err := detectComprehensionKind(comp)
if err != nil {
return nil, err
}
// Get the field being iterated over
iterRangeIdent := comp.IterRange.GetIdentExpr()
if iterRangeIdent == nil {
return nil, errors.New("comprehension range must be a field identifier")
}
fieldName := iterRangeIdent.GetName()
// Validate the field
field, ok := schema.Field(fieldName)
if !ok {
return nil, errors.Errorf("unknown field %q in comprehension", fieldName)
}
if field.Kind != FieldKindJSONList {
return nil, errors.Errorf("field %q does not support comprehension (must be a list)", fieldName)
}
// Extract the predicate from the loop step
predicate, err := extractPredicate(comp, schema)
if err != nil {
return nil, err
}
return &ListComprehensionCondition{
Kind: kind,
Field: fieldName,
IterVar: comp.IterVar,
Predicate: predicate,
}, nil
}
// detectComprehensionKind determines if this is an exists() macro.
// Only exists() is currently supported.
func detectComprehensionKind(comp *exprv1.Expr_Comprehension) (ComprehensionKind, error) {
// Check the accumulator initialization
accuInit := comp.AccuInit.GetConstExpr()
if accuInit == nil {
return "", errors.New("comprehension accumulator must be initialized with a constant")
}
// exists() starts with false and uses OR (||) in loop step
if !accuInit.GetBoolValue() {
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_||_" {
return ComprehensionExists, nil
}
}
// all() starts with true and uses AND (&&) - not supported
if accuInit.GetBoolValue() {
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_&&_" {
return "", errors.New("all() comprehension is not supported; use exists() instead")
}
}
return "", errors.New("unsupported comprehension type; only exists() is supported")
}
// extractPredicate extracts the predicate expression from the comprehension loop step.
func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, error) {
// The loop step is: @result || predicate(t) for exists
// or: @result && predicate(t) for all
step := comp.LoopStep.GetCallExpr()
if step == nil {
return nil, errors.New("comprehension loop step must be a call expression")
}
if len(step.Args) != 2 {
return nil, errors.New("comprehension loop step must have two arguments")
}
// The predicate is the second argument
predicateExpr := step.Args[1]
predicateCall := predicateExpr.GetCallExpr()
if predicateCall == nil {
return nil, errors.New("comprehension predicate must be a function call")
}
// Handle different predicate functions
switch predicateCall.Function {
case "startsWith":
return buildStartsWithPredicate(predicateCall, comp.IterVar)
case "endsWith":
return buildEndsWithPredicate(predicateCall, comp.IterVar)
case "contains":
return buildContainsPredicate(predicateCall, comp.IterVar)
default:
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
}
}
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
// Verify the target is the iteration variable
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("startsWith target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("startsWith expects exactly one argument")
}
prefix, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "startsWith argument must be a constant string")
}
prefixStr, ok := prefix.(string)
if !ok {
return nil, errors.New("startsWith argument must be a string")
}
return &StartsWithPredicate{Prefix: prefixStr}, nil
}
// buildEndsWithPredicate extracts the pattern from t.endsWith("suffix").
func buildEndsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("endsWith target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("endsWith expects exactly one argument")
}
suffix, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "endsWith argument must be a constant string")
}
suffixStr, ok := suffix.(string)
if !ok {
return nil, errors.New("endsWith argument must be a string")
}
return &EndsWithPredicate{Suffix: suffixStr}, nil
}
// buildContainsPredicate extracts the pattern from t.contains("substring").
func buildContainsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("contains target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("contains expects exactly one argument")
}
substring, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "contains argument must be a constant string")
}
substringStr, ok := substring.(string)
if !ok {
return nil, errors.New("contains argument must be a string")
}
return &ContainsPredicate{Substring: substringStr}, nil
}

748
plugin/filter/render.go Normal file
View File

@@ -0,0 +1,748 @@
package filter
import (
"fmt"
"strings"
"github.com/pkg/errors"
)
type renderer struct {
schema Schema
dialect DialectName
placeholderOffset int
placeholderCounter int
args []any
}
type renderResult struct {
sql string
trivial bool
unsatisfiable bool
}
func newRenderer(schema Schema, opts RenderOptions) *renderer {
return &renderer{
schema: schema,
dialect: opts.Dialect,
placeholderOffset: opts.PlaceholderOffset,
}
}
func (r *renderer) Render(cond Condition) (Statement, error) {
result, err := r.renderCondition(cond)
if err != nil {
return Statement{}, err
}
args := r.args
if args == nil {
args = []any{}
}
switch {
case result.unsatisfiable:
return Statement{
SQL: "1 = 0",
Args: args,
}, nil
case result.trivial:
return Statement{
SQL: "",
Args: args,
}, nil
default:
return Statement{
SQL: result.sql,
Args: args,
}, nil
}
}
func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
switch c := cond.(type) {
case *LogicalCondition:
return r.renderLogicalCondition(c)
case *NotCondition:
return r.renderNotCondition(c)
case *FieldPredicateCondition:
return r.renderFieldPredicate(c)
case *ComparisonCondition:
return r.renderComparison(c)
case *InCondition:
return r.renderInCondition(c)
case *ElementInCondition:
return r.renderElementInCondition(c)
case *ContainsCondition:
return r.renderContainsCondition(c)
case *ListComprehensionCondition:
return r.renderListComprehension(c)
case *ConstantCondition:
if c.Value {
return renderResult{trivial: true}, nil
}
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
default:
return renderResult{}, errors.Errorf("unsupported condition type %T", c)
}
}
func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) {
left, err := r.renderCondition(cond.Left)
if err != nil {
return renderResult{}, err
}
right, err := r.renderCondition(cond.Right)
if err != nil {
return renderResult{}, err
}
switch cond.Operator {
case LogicalAnd:
return combineAnd(left, right), nil
case LogicalOr:
return combineOr(left, right), nil
default:
return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator)
}
}
func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) {
child, err := r.renderCondition(cond.Expr)
if err != nil {
return renderResult{}, err
}
if child.trivial {
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
}
if child.unsatisfiable {
return renderResult{trivial: true}, nil
}
return renderResult{
sql: fmt.Sprintf("NOT (%s)", child.sql),
}, nil
}
func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
switch field.Kind {
case FieldKindBoolColumn:
column := qualifyColumn(r.dialect, field.Column)
return renderResult{
sql: fmt.Sprintf("%s IS TRUE", column),
}, nil
case FieldKindJSONBool:
sql, err := r.jsonBoolPredicate(field)
if err != nil {
return renderResult{}, err
}
return renderResult{sql: sql}, nil
default:
return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field)
}
}
func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) {
switch left := cond.Left.(type) {
case *FieldRef:
field, ok := r.schema.Field(left.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", left.Name)
}
switch field.Kind {
case FieldKindBoolColumn:
return r.renderBoolColumnComparison(field, cond.Operator, cond.Right)
case FieldKindJSONBool:
return r.renderJSONBoolComparison(field, cond.Operator, cond.Right)
case FieldKindScalar:
return r.renderScalarComparison(field, cond.Operator, cond.Right)
default:
return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name)
}
case *FunctionValue:
return r.renderFunctionComparison(left, cond.Operator, cond.Right)
default:
return renderResult{}, errors.New("comparison must start with a field reference or supported function")
}
}
func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) {
if fn.Name != "size" {
return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name)
}
if len(fn.Args) != 1 {
return renderResult{}, errors.New("size() expects one argument")
}
fieldArg, ok := fn.Args[0].(*FieldRef)
if !ok {
return renderResult{}, errors.New("size() argument must be a field")
}
field, ok := r.schema.Field(fieldArg.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", fieldArg.Name)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name)
}
value, err := expectNumericLiteral(right)
if err != nil {
return renderResult{}, err
}
expr := jsonArrayLengthExpr(r.dialect, field)
placeholder := r.addArg(value)
return renderResult{
sql: fmt.Sprintf("%s %s %s", expr, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
lit, err := expectLiteral(right)
if err != nil {
return renderResult{}, err
}
columnExpr := field.columnExpr(r.dialect)
if lit == nil {
switch op {
case CompareEq:
return renderResult{sql: fmt.Sprintf("%s IS NULL", columnExpr)}, nil
case CompareNeq:
return renderResult{sql: fmt.Sprintf("%s IS NOT NULL", columnExpr)}, nil
default:
return renderResult{}, errors.Errorf("operator %s not supported for null comparison", op)
}
}
placeholder := ""
switch field.Type {
case FieldTypeString:
value, ok := lit.(string)
if !ok {
return renderResult{}, errors.Errorf("field %q expects string value", field.Name)
}
placeholder = r.addArg(value)
case FieldTypeInt, FieldTypeTimestamp:
num, err := toInt64(lit)
if err != nil {
return renderResult{}, errors.Wrapf(err, "field %q expects integer value", field.Name)
}
placeholder = r.addArg(num)
default:
return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name)
}
return renderResult{
sql: fmt.Sprintf("%s %s %s", columnExpr, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
value, err := expectBool(right)
if err != nil {
return renderResult{}, err
}
placeholder := r.addBoolArg(value)
column := qualifyColumn(r.dialect, field.Column)
return renderResult{
sql: fmt.Sprintf("%s %s %s", column, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
value, err := expectBool(right)
if err != nil {
return renderResult{}, err
}
jsonExpr := jsonExtractExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite:
switch op {
case CompareEq:
if field.Name == "has_task_list" {
target := "0"
if value {
target = "1"
}
return renderResult{sql: fmt.Sprintf("%s = %s", jsonExpr, target)}, nil
}
if value {
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
}
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
case CompareNeq:
if field.Name == "has_task_list" {
target := "0"
if value {
target = "1"
}
return renderResult{sql: fmt.Sprintf("%s != %s", jsonExpr, target)}, nil
}
if value {
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
}
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
default:
return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op)
}
case DialectMySQL:
boolStr := "false"
if value {
boolStr = "true"
}
return renderResult{
sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", jsonExpr, sqlOperator(op), boolStr),
}, nil
case DialectPostgres:
placeholder := r.addArg(value)
return renderResult{
sql: fmt.Sprintf("(%s)::boolean %s %s", jsonExpr, sqlOperator(op), placeholder),
}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) {
fieldRef, ok := cond.Left.(*FieldRef)
if !ok {
return renderResult{}, errors.New("IN operator requires a field on the left-hand side")
}
if fieldRef.Name == "tag" {
return r.renderTagInList(cond.Values)
}
field, ok := r.schema.Field(fieldRef.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", fieldRef.Name)
}
if field.Kind != FieldKindScalar {
return renderResult{}, errors.Errorf("field %q does not support IN()", fieldRef.Name)
}
return r.renderScalarInCondition(field, cond.Values)
}
func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) {
field, ok := r.schema.ResolveAlias("tag")
if !ok {
return renderResult{}, errors.New("tag attribute is not configured")
}
conditions := make([]string, 0, len(values))
for _, v := range values {
lit, err := expectLiteral(v)
if err != nil {
return renderResult{}, err
}
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.New("tags must be compared with string literals")
}
switch r.dialect {
case DialectSQLite:
// Support hierarchical tags: match exact tag OR tags with this prefix (e.g., "book" matches "book" and "book/something")
exactMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
prefixMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
conditions = append(conditions, expr)
case DialectMySQL:
// Support hierarchical tags: match exact tag OR tags with this prefix
exactMatch := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
prefixMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
conditions = append(conditions, expr)
case DialectPostgres:
// Support hierarchical tags: match exact tag OR tags with this prefix
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
conditions = append(conditions, expr)
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
if len(conditions) == 1 {
return renderResult{sql: conditions[0]}, nil
}
return renderResult{
sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")),
}, nil
}
func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field)
}
lit, err := expectLiteral(cond.Element)
if err != nil {
return renderResult{}, err
}
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.New("tags membership requires string literal")
}
switch r.dialect {
case DialectSQLite:
sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
return renderResult{sql: sql}, nil
case DialectMySQL:
sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
return renderResult{sql: sql}, nil
case DialectPostgres:
sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
return renderResult{sql: sql}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) {
placeholders := make([]string, 0, len(values))
for _, v := range values {
lit, err := expectLiteral(v)
if err != nil {
return renderResult{}, err
}
switch field.Type {
case FieldTypeString:
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.Errorf("field %q expects string values", field.Name)
}
placeholders = append(placeholders, r.addArg(str))
case FieldTypeInt:
num, err := toInt64(lit)
if err != nil {
return renderResult{}, err
}
placeholders = append(placeholders, r.addArg(num))
default:
return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name)
}
}
column := field.columnExpr(r.dialect)
return renderResult{
sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")),
}, nil
}
func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
column := field.columnExpr(r.dialect)
arg := fmt.Sprintf("%%%s%%", cond.Value)
switch r.dialect {
case DialectSQLite:
// Use custom Unicode-aware case folding function for case-insensitive comparison.
// This overcomes SQLite's ASCII-only LOWER() limitation.
sql := fmt.Sprintf("memos_unicode_lower(%s) LIKE memos_unicode_lower(%s)", column, r.addArg(arg))
return renderResult{sql: sql}, nil
case DialectPostgres:
sql := fmt.Sprintf("%s ILIKE %s", column, r.addArg(arg))
return renderResult{sql: sql}, nil
default:
sql := fmt.Sprintf("%s LIKE %s", column, r.addArg(arg))
return renderResult{sql: sql}, nil
}
}
func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("field %q is not a JSON list", cond.Field)
}
// Render based on predicate type
switch pred := cond.Predicate.(type) {
case *StartsWithPredicate:
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
case *EndsWithPredicate:
return r.renderTagEndsWith(field, pred.Suffix, cond.Kind)
case *ContainsPredicate:
return r.renderTagContains(field, pred.Substring, cond.Kind)
default:
return renderResult{}, errors.Errorf("unsupported predicate type %T in comprehension", pred)
}
}
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite, DialectMySQL:
// Match exact tag or tags with this prefix (hierarchical support)
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, prefix))
prefixMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s%%`, prefix))
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
case DialectPostgres:
// Use PostgreSQL's powerful JSON operators
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, prefix)))
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(fmt.Sprintf(`%%"%s%%`, prefix)))
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
// renderTagEndsWith generates SQL for tags.exists(t, t.endsWith("suffix")).
func (r *renderer) renderTagEndsWith(field Field, suffix string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
pattern := fmt.Sprintf(`%%%s"%%`, suffix)
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
}
// renderTagContains generates SQL for tags.exists(t, t.contains("substring")).
func (r *renderer) renderTagContains(field Field, substring string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
pattern := fmt.Sprintf(`%%%s%%`, substring)
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
}
// buildJSONArrayLike builds a LIKE expression for matching within a JSON array.
// Returns the LIKE clause without NULL/empty checks.
func (r *renderer) buildJSONArrayLike(arrayExpr, pattern string) string {
switch r.dialect {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("%s LIKE %s", arrayExpr, r.addArg(pattern))
case DialectPostgres:
return fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(pattern))
default:
return ""
}
}
// wrapWithNullCheck wraps a condition with NULL and empty array checks.
// This ensures we don't match against NULL or empty JSON arrays.
func (r *renderer) wrapWithNullCheck(arrayExpr, condition string) string {
var nullCheck string
switch r.dialect {
case DialectSQLite:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND %s != '[]'", arrayExpr, arrayExpr)
case DialectMySQL:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND JSON_LENGTH(%s) > 0", arrayExpr, arrayExpr)
case DialectPostgres:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND jsonb_array_length(%s) > 0", arrayExpr, arrayExpr)
default:
return condition
}
return fmt.Sprintf("(%s AND %s)", condition, nullCheck)
}
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
expr := jsonExtractExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite:
return fmt.Sprintf("%s IS TRUE", expr), nil
case DialectMySQL:
return fmt.Sprintf("COALESCE(%s, CAST('false' AS JSON)) = CAST('true' AS JSON)", expr), nil
case DialectPostgres:
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
default:
return "", errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func combineAnd(left, right renderResult) renderResult {
if left.unsatisfiable || right.unsatisfiable {
return renderResult{sql: "1 = 0", unsatisfiable: true}
}
if left.trivial {
return right
}
if right.trivial {
return left
}
return renderResult{
sql: fmt.Sprintf("(%s AND %s)", left.sql, right.sql),
}
}
func combineOr(left, right renderResult) renderResult {
if left.trivial || right.trivial {
return renderResult{trivial: true}
}
if left.unsatisfiable {
return right
}
if right.unsatisfiable {
return left
}
return renderResult{
sql: fmt.Sprintf("(%s OR %s)", left.sql, right.sql),
}
}
func (r *renderer) addArg(value any) string {
r.placeholderCounter++
r.args = append(r.args, value)
if r.dialect == DialectPostgres {
return fmt.Sprintf("$%d", r.placeholderOffset+r.placeholderCounter)
}
return "?"
}
func (r *renderer) addBoolArg(value bool) string {
var v any
switch r.dialect {
case DialectSQLite:
if value {
v = 1
} else {
v = 0
}
default:
v = value
}
return r.addArg(v)
}
func expectLiteral(expr ValueExpr) (any, error) {
lit, ok := expr.(*LiteralValue)
if !ok {
return nil, errors.New("expression must be a literal")
}
return lit.Value, nil
}
func expectBool(expr ValueExpr) (bool, error) {
lit, err := expectLiteral(expr)
if err != nil {
return false, err
}
value, ok := lit.(bool)
if !ok {
return false, errors.New("boolean literal required")
}
return value, nil
}
func expectNumericLiteral(expr ValueExpr) (int64, error) {
lit, err := expectLiteral(expr)
if err != nil {
return 0, err
}
return toInt64(lit)
}
func toInt64(value any) (int64, error) {
switch v := value.(type) {
case int:
return int64(v), nil
case int32:
return int64(v), nil
case int64:
return v, nil
case uint32:
return int64(v), nil
case uint64:
return int64(v), nil
case float32:
return int64(v), nil
case float64:
return int64(v), nil
default:
return 0, errors.Errorf("cannot convert %T to int64", value)
}
}
func sqlOperator(op ComparisonOperator) string {
return string(op)
}
func qualifyColumn(d DialectName, col Column) string {
switch d {
case DialectPostgres:
return fmt.Sprintf("%s.%s", col.Table, col.Name)
default:
return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name)
}
}
func jsonPath(field Field) string {
return "$." + strings.Join(field.JSONPath, ".")
}
func jsonExtractExpr(d DialectName, field Field) string {
column := qualifyColumn(d, field.Column)
switch d {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
case DialectPostgres:
return buildPostgresJSONAccessor(column, field.JSONPath, true)
default:
return ""
}
}
func jsonArrayExpr(d DialectName, field Field) string {
column := qualifyColumn(d, field.Column)
switch d {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
case DialectPostgres:
return buildPostgresJSONAccessor(column, field.JSONPath, false)
default:
return ""
}
}
func jsonArrayLengthExpr(d DialectName, field Field) string {
arrayExpr := jsonArrayExpr(d, field)
switch d {
case DialectSQLite:
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
case DialectMySQL:
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
case DialectPostgres:
return fmt.Sprintf("jsonb_array_length(COALESCE(%s, '[]'::jsonb))", arrayExpr)
default:
return ""
}
}
func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string {
expr := base
for idx, part := range path {
if idx == len(path)-1 && terminalText {
expr = fmt.Sprintf("%s->>'%s'", expr, part)
} else {
expr = fmt.Sprintf("%s->'%s'", expr, part)
}
}
return expr
}

319
plugin/filter/schema.go Normal file
View File

@@ -0,0 +1,319 @@
package filter
import (
"fmt"
"time"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// DialectName enumerates supported SQL dialects.
type DialectName string
const (
DialectSQLite DialectName = "sqlite"
DialectMySQL DialectName = "mysql"
DialectPostgres DialectName = "postgres"
)
// FieldType represents the logical type of a field.
type FieldType string
const (
FieldTypeString FieldType = "string"
FieldTypeInt FieldType = "int"
FieldTypeBool FieldType = "bool"
FieldTypeTimestamp FieldType = "timestamp"
)
// FieldKind describes how a field is stored.
type FieldKind string
const (
FieldKindScalar FieldKind = "scalar"
FieldKindBoolColumn FieldKind = "bool_column"
FieldKindJSONBool FieldKind = "json_bool"
FieldKindJSONList FieldKind = "json_list"
FieldKindVirtualAlias FieldKind = "virtual_alias"
)
// Column identifies the backing table column.
type Column struct {
Table string
Name string
}
// Field captures the schema metadata for an exposed CEL identifier.
type Field struct {
Name string
Kind FieldKind
Type FieldType
Column Column
JSONPath []string
AliasFor string
SupportsContains bool
Expressions map[DialectName]string
AllowedComparisonOps map[ComparisonOperator]bool
}
// Schema collects CEL environment options and field metadata.
type Schema struct {
Name string
Fields map[string]Field
EnvOptions []cel.EnvOption
}
// Field returns the field metadata if present.
func (s Schema) Field(name string) (Field, bool) {
f, ok := s.Fields[name]
return f, ok
}
// ResolveAlias resolves a virtual alias to its target field.
func (s Schema) ResolveAlias(name string) (Field, bool) {
field, ok := s.Fields[name]
if !ok {
return Field{}, false
}
if field.Kind == FieldKindVirtualAlias {
target, ok := s.Fields[field.AliasFor]
if !ok {
return Field{}, false
}
return target, true
}
return field, true
}
var nowFunction = cel.Function("now",
cel.Overload("now",
[]*cel.Type{},
cel.IntType,
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
return types.Int(time.Now().Unix())
}),
),
)
// NewSchema constructs the memo filter schema and CEL environment.
func NewSchema() Schema {
fields := map[string]Field{
"content": {
Name: "content",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "content"},
SupportsContains: true,
Expressions: map[DialectName]string{},
},
"creator_id": {
Name: "creator_id",
Kind: FieldKindScalar,
Type: FieldTypeInt,
Column: Column{Table: "memo", Name: "creator_id"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"created_ts": {
Name: "created_ts",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "memo", Name: "created_ts"},
Expressions: map[DialectName]string{
// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
DialectMySQL: "UNIX_TIMESTAMP(%s)",
// PostgreSQL and SQLite store created_ts as BIGINT (epoch), no conversion needed
DialectPostgres: "%s",
DialectSQLite: "%s",
},
},
"updated_ts": {
Name: "updated_ts",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "memo", Name: "updated_ts"},
Expressions: map[DialectName]string{
// MySQL stores updated_ts as TIMESTAMP, needs conversion to epoch
DialectMySQL: "UNIX_TIMESTAMP(%s)",
// PostgreSQL and SQLite store updated_ts as BIGINT (epoch), no conversion needed
DialectPostgres: "%s",
DialectSQLite: "%s",
},
},
"pinned": {
Name: "pinned",
Kind: FieldKindBoolColumn,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "pinned"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"visibility": {
Name: "visibility",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "visibility"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"tags": {
Name: "tags",
Kind: FieldKindJSONList,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"tags"},
},
"tag": {
Name: "tag",
Kind: FieldKindVirtualAlias,
Type: FieldTypeString,
AliasFor: "tags",
},
"has_task_list": {
Name: "has_task_list",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasTaskList"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_link": {
Name: "has_link",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasLink"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_code": {
Name: "has_code",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasCode"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_incomplete_tasks": {
Name: "has_incomplete_tasks",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasIncompleteTasks"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
}
envOptions := []cel.EnvOption{
cel.Variable("content", cel.StringType),
cel.Variable("creator_id", cel.IntType),
cel.Variable("created_ts", cel.IntType),
cel.Variable("updated_ts", cel.IntType),
cel.Variable("pinned", cel.BoolType),
cel.Variable("tag", cel.StringType),
cel.Variable("tags", cel.ListType(cel.StringType)),
cel.Variable("visibility", cel.StringType),
cel.Variable("has_task_list", cel.BoolType),
cel.Variable("has_link", cel.BoolType),
cel.Variable("has_code", cel.BoolType),
cel.Variable("has_incomplete_tasks", cel.BoolType),
nowFunction,
}
return Schema{
Name: "memo",
Fields: fields,
EnvOptions: envOptions,
}
}
// NewAttachmentSchema constructs the attachment filter schema and CEL environment.
func NewAttachmentSchema() Schema {
fields := map[string]Field{
"filename": {
Name: "filename",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "attachment", Name: "filename"},
SupportsContains: true,
Expressions: map[DialectName]string{},
},
"mime_type": {
Name: "mime_type",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "attachment", Name: "type"},
Expressions: map[DialectName]string{},
},
"create_time": {
Name: "create_time",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "attachment", Name: "created_ts"},
Expressions: map[DialectName]string{
// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
DialectMySQL: "UNIX_TIMESTAMP(%s)",
// PostgreSQL and SQLite store created_ts as BIGINT (epoch), no conversion needed
DialectPostgres: "%s",
DialectSQLite: "%s",
},
},
"memo_id": {
Name: "memo_id",
Kind: FieldKindScalar,
Type: FieldTypeInt,
Column: Column{Table: "attachment", Name: "memo_id"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
}
envOptions := []cel.EnvOption{
cel.Variable("filename", cel.StringType),
cel.Variable("mime_type", cel.StringType),
cel.Variable("create_time", cel.IntType),
cel.Variable("memo_id", cel.AnyType),
nowFunction,
}
return Schema{
Name: "attachment",
Fields: fields,
EnvOptions: envOptions,
}
}
// columnExpr returns the field expression for the given dialect, applying
// any schema-specific overrides (e.g. UNIX timestamp conversions).
func (f Field) columnExpr(d DialectName) string {
base := qualifyColumn(d, f.Column)
if expr, ok := f.Expressions[d]; ok && expr != "" {
return fmt.Sprintf(expr, base)
}
return base
}

View File

@@ -0,0 +1,166 @@
package httpgetter
import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"github.com/pkg/errors"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
)
var ErrInternalIP = errors.New("internal IP addresses are not allowed")
var httpClient = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if err := validateURL(req.URL.String()); err != nil {
return errors.Wrap(err, "redirect to internal IP")
}
if len(via) >= 10 {
return errors.New("too many redirects")
}
return nil
},
}
type HTMLMeta struct {
Title string `json:"title"`
Description string `json:"description"`
Image string `json:"image"`
}
func GetHTMLMeta(urlStr string) (*HTMLMeta, error) {
if err := validateURL(urlStr); err != nil {
return nil, err
}
response, err := httpClient.Get(urlStr)
if err != nil {
return nil, err
}
defer response.Body.Close()
mediatype, err := getMediatype(response)
if err != nil {
return nil, err
}
if mediatype != "text/html" {
return nil, errors.New("not a HTML page")
}
// TODO: limit the size of the response body
htmlMeta := extractHTMLMeta(response.Body)
enrichSiteMeta(response.Request.URL, htmlMeta)
return htmlMeta, nil
}
func extractHTMLMeta(resp io.Reader) *HTMLMeta {
tokenizer := html.NewTokenizer(resp)
htmlMeta := new(HTMLMeta)
for {
tokenType := tokenizer.Next()
if tokenType == html.ErrorToken {
break
} else if tokenType == html.StartTagToken || tokenType == html.SelfClosingTagToken {
token := tokenizer.Token()
if token.DataAtom == atom.Body {
break
}
if token.DataAtom == atom.Title {
tokenizer.Next()
token := tokenizer.Token()
htmlMeta.Title = token.Data
} else if token.DataAtom == atom.Meta {
description, ok := extractMetaProperty(token, "description")
if ok {
htmlMeta.Description = description
}
ogTitle, ok := extractMetaProperty(token, "og:title")
if ok {
htmlMeta.Title = ogTitle
}
ogDescription, ok := extractMetaProperty(token, "og:description")
if ok {
htmlMeta.Description = ogDescription
}
ogImage, ok := extractMetaProperty(token, "og:image")
if ok {
htmlMeta.Image = ogImage
}
}
}
}
return htmlMeta
}
func extractMetaProperty(token html.Token, prop string) (content string, ok bool) {
content, ok = "", false
for _, attr := range token.Attr {
if attr.Key == "property" && attr.Val == prop {
ok = true
}
if attr.Key == "content" {
content = attr.Val
}
}
return content, ok
}
func validateURL(urlStr string) error {
u, err := url.Parse(urlStr)
if err != nil {
return errors.New("invalid URL format")
}
if u.Scheme != "http" && u.Scheme != "https" {
return errors.New("only http/https protocols are allowed")
}
host := u.Hostname()
if host == "" {
return errors.New("empty hostname")
}
// check if the hostname is an IP
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
return errors.Wrap(ErrInternalIP, ip.String())
}
return nil
}
// check if it's a hostname, resolve it and check all returned IPs
ips, err := net.LookupIP(host)
if err != nil {
return errors.Errorf("failed to resolve hostname: %v", err)
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
return errors.Wrapf(ErrInternalIP, "host=%s, ip=%s", host, ip.String())
}
}
return nil
}
func enrichSiteMeta(url *url.URL, meta *HTMLMeta) {
if url.Hostname() == "www.youtube.com" {
if url.Path == "/watch" {
vid := url.Query().Get("v")
if vid != "" {
meta.Image = fmt.Sprintf("https://img.youtube.com/vi/%s/mqdefault.jpg", vid)
}
}
}
}

View File

@@ -0,0 +1,32 @@
package httpgetter
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetHTMLMeta(t *testing.T) {
tests := []struct {
urlStr string
htmlMeta HTMLMeta
}{}
for _, test := range tests {
metadata, err := GetHTMLMeta(test.urlStr)
require.NoError(t, err)
require.Equal(t, test.htmlMeta, *metadata)
}
}
func TestGetHTMLMetaForInternal(t *testing.T) {
// test for internal IP
if _, err := GetHTMLMeta("http://192.168.0.1"); !errors.Is(err, ErrInternalIP) {
t.Errorf("Expected error for internal IP, got %v", err)
}
// test for resolved internal IP
if _, err := GetHTMLMeta("http://localhost"); !errors.Is(err, ErrInternalIP) {
t.Errorf("Expected error for resolved internal IP, got %v", err)
}
}

View File

@@ -0,0 +1 @@
package httpgetter

View File

@@ -0,0 +1,45 @@
package httpgetter
import (
"errors"
"io"
"net/http"
"net/url"
"strings"
)
type Image struct {
Blob []byte
Mediatype string
}
func GetImage(urlStr string) (*Image, error) {
if _, err := url.Parse(urlStr); err != nil {
return nil, err
}
response, err := http.Get(urlStr)
if err != nil {
return nil, err
}
defer response.Body.Close()
mediatype, err := getMediatype(response)
if err != nil {
return nil, err
}
if !strings.HasPrefix(mediatype, "image/") {
return nil, errors.New("wrong image mediatype")
}
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
image := &Image{
Blob: bodyBytes,
Mediatype: mediatype,
}
return image, nil
}

15
plugin/httpgetter/util.go Normal file
View File

@@ -0,0 +1,15 @@
package httpgetter
import (
"mime"
"net/http"
)
func getMediatype(response *http.Response) (string, error) {
contentType := response.Header.Get("content-type")
mediatype, _, err := mime.ParseMediaType(contentType)
if err != nil {
return "", err
}
return mediatype, nil
}

8
plugin/idp/idp.go Normal file
View File

@@ -0,0 +1,8 @@
package idp
type IdentityProviderUserInfo struct {
Identifier string
DisplayName string
Email string
AvatarURL string
}

134
plugin/idp/oauth2/oauth2.go Normal file
View File

@@ -0,0 +1,134 @@
// Package oauth2 is the plugin for OAuth2 Identity Provider.
package oauth2
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"github.com/usememos/memos/plugin/idp"
storepb "github.com/usememos/memos/proto/gen/store"
)
// IdentityProvider represents an OAuth2 Identity Provider.
type IdentityProvider struct {
config *storepb.OAuth2Config
}
// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration.
func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error) {
for v, field := range map[string]string{
config.ClientId: "clientId",
config.ClientSecret: "clientSecret",
config.TokenUrl: "tokenUrl",
config.UserInfoUrl: "userInfoUrl",
config.FieldMapping.Identifier: "fieldMapping.identifier",
} {
if v == "" {
return nil, errors.Errorf(`the field "%s" is empty but required`, field)
}
}
return &IdentityProvider{
config: config,
}, nil
}
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
// If codeVerifier is provided, it will be used for PKCE (Proof Key for Code Exchange) validation.
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code, codeVerifier string) (string, error) {
conf := &oauth2.Config{
ClientID: p.config.ClientId,
ClientSecret: p.config.ClientSecret,
RedirectURL: redirectURL,
Scopes: p.config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: p.config.AuthUrl,
TokenURL: p.config.TokenUrl,
AuthStyle: oauth2.AuthStyleInParams,
},
}
// Prepare token exchange options
opts := []oauth2.AuthCodeOption{}
// Add PKCE code_verifier if provided
if codeVerifier != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
}
token, err := conf.Exchange(ctx, code, opts...)
if err != nil {
return "", errors.Wrap(err, "failed to exchange access token")
}
// Use the standard AccessToken field instead of Extra()
// This is more reliable across different OAuth providers
if token.AccessToken == "" {
return "", errors.New("missing access token from authorization response")
}
return token.AccessToken, nil
}
// UserInfo returns the parsed user information using the given OAuth2 token.
func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) {
client := &http.Client{}
req, err := http.NewRequest(http.MethodGet, p.config.UserInfoUrl, nil)
if err != nil {
return nil, errors.Wrap(err, "failed to new http request")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
resp, err := client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to get user information")
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read response body")
}
var claims map[string]any
if err := json.Unmarshal(body, &claims); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
slog.Info("user info claims", "claims", claims)
userInfo := &idp.IdentityProviderUserInfo{}
if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok {
userInfo.Identifier = v
}
if userInfo.Identifier == "" {
return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier)
}
// Best effort to map optional fields
if p.config.FieldMapping.DisplayName != "" {
if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok {
userInfo.DisplayName = v
}
}
if userInfo.DisplayName == "" {
userInfo.DisplayName = userInfo.Identifier
}
if p.config.FieldMapping.Email != "" {
if v, ok := claims[p.config.FieldMapping.Email].(string); ok {
userInfo.Email = v
}
}
if p.config.FieldMapping.AvatarUrl != "" {
if v, ok := claims[p.config.FieldMapping.AvatarUrl].(string); ok {
userInfo.AvatarURL = v
}
}
slog.Info("user info", "userInfo", userInfo)
return userInfo, nil
}

View File

@@ -0,0 +1,164 @@
package oauth2
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/idp"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestNewIdentityProvider(t *testing.T) {
tests := []struct {
name string
config *storepb.OAuth2Config
containsErr string
}{
{
name: "no tokenUrl",
config: &storepb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "",
TokenUrl: "",
UserInfoUrl: "https://example.com/api/user",
FieldMapping: &storepb.FieldMapping{
Identifier: "login",
},
},
containsErr: `the field "tokenUrl" is empty but required`,
},
{
name: "no userInfoUrl",
config: &storepb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "",
TokenUrl: "https://example.com/token",
UserInfoUrl: "",
FieldMapping: &storepb.FieldMapping{
Identifier: "login",
},
},
containsErr: `the field "userInfoUrl" is empty but required`,
},
{
name: "no field mapping identifier",
config: &storepb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "",
TokenUrl: "https://example.com/token",
UserInfoUrl: "https://example.com/api/user",
FieldMapping: &storepb.FieldMapping{
Identifier: "",
},
},
containsErr: `the field "fieldMapping.identifier" is empty but required`,
},
}
for _, test := range tests {
t.Run(test.name, func(*testing.T) {
_, err := NewIdentityProvider(test.config)
assert.ErrorContains(t, err, test.containsErr)
})
}
}
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
mux := http.NewServeMux()
var rawIDToken string
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
vals, err := url.ParseQuery(string(body))
require.NoError(t, err)
require.Equal(t, code, vals.Get("code"))
require.Equal(t, "authorization_code", vals.Get("grant_type"))
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(map[string]any{
"access_token": accessToken,
"token_type": "Bearer",
"expires_in": 3600,
"id_token": rawIDToken,
})
require.NoError(t, err)
})
mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err := w.Write(userinfo)
require.NoError(t, err)
})
s := httptest.NewServer(mux)
return s
}
func TestIdentityProvider(t *testing.T) {
ctx := context.Background()
const (
testClientID = "test-client-id"
testCode = "test-code"
testAccessToken = "test-access-token"
testSubject = "123456789"
testName = "John Doe"
testEmail = "john.doe@example.com"
)
userInfo, err := json.Marshal(
map[string]any{
"sub": testSubject,
"name": testName,
"email": testEmail,
},
)
require.NoError(t, err)
s := newMockServer(t, testCode, testAccessToken, userInfo)
oauth2, err := NewIdentityProvider(
&storepb.OAuth2Config{
ClientId: testClientID,
ClientSecret: "test-client-secret",
TokenUrl: fmt.Sprintf("%s/oauth2/token", s.URL),
UserInfoUrl: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
DisplayName: "name",
Email: "email",
},
},
)
require.NoError(t, err)
redirectURL := "https://example.com/oauth/callback"
// Test without PKCE (backward compatibility)
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode, "")
require.NoError(t, err)
require.Equal(t, testAccessToken, oauthToken)
userInfoResult, err := oauth2.UserInfo(oauthToken)
require.NoError(t, err)
wantUserInfo := &idp.IdentityProviderUserInfo{
Identifier: testSubject,
DisplayName: testName,
Email: testEmail,
}
assert.Equal(t, wantUserInfo, userInfoResult)
}

View File

@@ -0,0 +1,28 @@
package ast
import (
gast "github.com/yuin/goldmark/ast"
)
// TagNode represents a #tag in the markdown AST.
type TagNode struct {
gast.BaseInline
// Tag name without the # prefix
Tag []byte
}
// KindTag is the NodeKind for TagNode.
var KindTag = gast.NewNodeKind("Tag")
// Kind returns KindTag.
func (*TagNode) Kind() gast.NodeKind {
return KindTag
}
// Dump implements Node.Dump for debugging.
func (n *TagNode) Dump(source []byte, level int) {
gast.DumpHelper(n, source, level, map[string]string{
"Tag": string(n.Tag),
}, nil)
}

View File

@@ -0,0 +1,24 @@
package extensions
import (
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/parser"
"github.com/yuin/goldmark/util"
mparser "github.com/usememos/memos/plugin/markdown/parser"
)
type tagExtension struct{}
// TagExtension is a goldmark extension for #tag syntax.
var TagExtension = &tagExtension{}
// Extend extends the goldmark parser with tag support.
func (*tagExtension) Extend(m goldmark.Markdown) {
m.Parser().AddOptions(
parser.WithInlineParsers(
// Priority 200 - run before standard link parser (500)
util.Prioritized(mparser.NewTagParser(), 200),
),
)
}

409
plugin/markdown/markdown.go Normal file
View File

@@ -0,0 +1,409 @@
package markdown
import (
"bytes"
"strings"
"github.com/yuin/goldmark"
gast "github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/extension"
east "github.com/yuin/goldmark/extension/ast"
"github.com/yuin/goldmark/parser"
"github.com/yuin/goldmark/text"
mast "github.com/usememos/memos/plugin/markdown/ast"
"github.com/usememos/memos/plugin/markdown/extensions"
"github.com/usememos/memos/plugin/markdown/renderer"
storepb "github.com/usememos/memos/proto/gen/store"
)
// ExtractedData contains all metadata extracted from markdown in a single pass.
type ExtractedData struct {
Tags []string
Property *storepb.MemoPayload_Property
}
// Service handles markdown metadata extraction.
// It uses goldmark to parse markdown and extract tags, properties, and snippets.
// HTML rendering is primarily done on frontend using markdown-it, but backend provides
// RenderHTML for RSS feeds and other server-side rendering needs.
type Service interface {
// ExtractAll extracts tags, properties, and references in a single parse (most efficient)
ExtractAll(content []byte) (*ExtractedData, error)
// ExtractTags returns all #tags found in content
ExtractTags(content []byte) ([]string, error)
// ExtractProperties computes boolean properties
ExtractProperties(content []byte) (*storepb.MemoPayload_Property, error)
// RenderMarkdown renders goldmark AST back to markdown text
RenderMarkdown(content []byte) (string, error)
// RenderHTML renders markdown content to HTML
RenderHTML(content []byte) (string, error)
// GenerateSnippet creates plain text summary
GenerateSnippet(content []byte, maxLength int) (string, error)
// ValidateContent checks for syntax errors
ValidateContent(content []byte) error
// RenameTag renames all occurrences of oldTag to newTag in content
RenameTag(content []byte, oldTag, newTag string) (string, error)
}
// service implements the Service interface.
type service struct {
md goldmark.Markdown
}
// Option configures the markdown service.
type Option func(*config)
type config struct {
enableTags bool
}
// WithTagExtension enables #tag parsing.
func WithTagExtension() Option {
return func(c *config) {
c.enableTags = true
}
}
// NewService creates a new markdown service with the given options.
func NewService(opts ...Option) Service {
cfg := &config{}
for _, opt := range opts {
opt(cfg)
}
exts := []goldmark.Extender{
extension.GFM, // GitHub Flavored Markdown (tables, strikethrough, task lists, autolinks)
}
// Add custom extensions based on config
if cfg.enableTags {
exts = append(exts, extensions.TagExtension)
}
md := goldmark.New(
goldmark.WithExtensions(exts...),
goldmark.WithParserOptions(
parser.WithAutoHeadingID(), // Generate heading IDs
),
)
return &service{
md: md,
}
}
// parse is an internal helper to parse content into AST.
func (s *service) parse(content []byte) (gast.Node, error) {
reader := text.NewReader(content)
doc := s.md.Parser().Parse(reader)
return doc, nil
}
// ExtractTags returns all #tags found in content.
func (s *service) ExtractTags(content []byte) ([]string, error) {
root, err := s.parse(content)
if err != nil {
return nil, err
}
var tags []string
// Walk the AST to find tag nodes
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
if !entering {
return gast.WalkContinue, nil
}
// Check for custom TagNode
if tagNode, ok := n.(*mast.TagNode); ok {
tags = append(tags, string(tagNode.Tag))
}
return gast.WalkContinue, nil
})
if err != nil {
return nil, err
}
// Deduplicate tags while preserving original case
return uniquePreserveCase(tags), nil
}
// ExtractProperties computes boolean properties about the content.
func (s *service) ExtractProperties(content []byte) (*storepb.MemoPayload_Property, error) {
root, err := s.parse(content)
if err != nil {
return nil, err
}
prop := &storepb.MemoPayload_Property{}
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
if !entering {
return gast.WalkContinue, nil
}
switch n.Kind() {
case gast.KindLink:
prop.HasLink = true
case gast.KindCodeBlock, gast.KindFencedCodeBlock, gast.KindCodeSpan:
prop.HasCode = true
case east.KindTaskCheckBox:
prop.HasTaskList = true
if checkBox, ok := n.(*east.TaskCheckBox); ok {
if !checkBox.IsChecked {
prop.HasIncompleteTasks = true
}
}
default:
// No special handling for other node types
}
return gast.WalkContinue, nil
})
if err != nil {
return nil, err
}
return prop, nil
}
// RenderMarkdown renders goldmark AST back to markdown text.
func (s *service) RenderMarkdown(content []byte) (string, error) {
root, err := s.parse(content)
if err != nil {
return "", err
}
mdRenderer := renderer.NewMarkdownRenderer()
return mdRenderer.Render(root, content), nil
}
// RenderHTML renders markdown content to HTML using goldmark's built-in HTML renderer.
func (s *service) RenderHTML(content []byte) (string, error) {
var buf bytes.Buffer
if err := s.md.Convert(content, &buf); err != nil {
return "", err
}
return buf.String(), nil
}
// GenerateSnippet creates a plain text summary from markdown content.
func (s *service) GenerateSnippet(content []byte, maxLength int) (string, error) {
root, err := s.parse(content)
if err != nil {
return "", err
}
var buf strings.Builder
var lastNodeWasBlock bool
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
if entering {
// Skip code blocks and code spans entirely
switch n.Kind() {
case gast.KindCodeBlock, gast.KindFencedCodeBlock, gast.KindCodeSpan:
return gast.WalkSkipChildren, nil
default:
// Continue walking for other node types
}
// Add space before block elements (except first)
switch n.Kind() {
case gast.KindParagraph, gast.KindHeading, gast.KindListItem:
if buf.Len() > 0 && lastNodeWasBlock {
buf.WriteByte(' ')
}
default:
// No space needed for other node types
}
}
if !entering {
// Mark that we just exited a block element
switch n.Kind() {
case gast.KindParagraph, gast.KindHeading, gast.KindListItem:
lastNodeWasBlock = true
default:
// Not a block element
}
return gast.WalkContinue, nil
}
lastNodeWasBlock = false
// Only extract plain text nodes
if textNode, ok := n.(*gast.Text); ok {
segment := textNode.Segment
buf.Write(segment.Value(content))
// Add space if this is a soft line break
if textNode.SoftLineBreak() {
buf.WriteByte(' ')
}
}
// Stop walking if we've exceeded double the max length
// (we'll truncate precisely later)
if buf.Len() > maxLength*2 {
return gast.WalkStop, nil
}
return gast.WalkContinue, nil
})
if err != nil {
return "", err
}
snippet := buf.String()
// Truncate at word boundary if needed
if len(snippet) > maxLength {
snippet = truncateAtWord(snippet, maxLength)
}
return strings.TrimSpace(snippet), nil
}
// ValidateContent checks if the markdown content is valid.
func (s *service) ValidateContent(content []byte) error {
// Try to parse the content
_, err := s.parse(content)
return err
}
// ExtractAll extracts tags, properties, and references in a single parse for efficiency.
func (s *service) ExtractAll(content []byte) (*ExtractedData, error) {
root, err := s.parse(content)
if err != nil {
return nil, err
}
data := &ExtractedData{
Tags: []string{},
Property: &storepb.MemoPayload_Property{},
}
// Single walk to collect all data
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
if !entering {
return gast.WalkContinue, nil
}
// Extract tags
if tagNode, ok := n.(*mast.TagNode); ok {
data.Tags = append(data.Tags, string(tagNode.Tag))
}
// Extract properties based on node kind
switch n.Kind() {
case gast.KindLink:
data.Property.HasLink = true
case gast.KindCodeBlock, gast.KindFencedCodeBlock, gast.KindCodeSpan:
data.Property.HasCode = true
case east.KindTaskCheckBox:
data.Property.HasTaskList = true
if checkBox, ok := n.(*east.TaskCheckBox); ok {
if !checkBox.IsChecked {
data.Property.HasIncompleteTasks = true
}
}
default:
// No special handling for other node types
}
return gast.WalkContinue, nil
})
if err != nil {
return nil, err
}
// Deduplicate tags while preserving original case
data.Tags = uniquePreserveCase(data.Tags)
return data, nil
}
// RenameTag renames all occurrences of oldTag to newTag in content.
func (s *service) RenameTag(content []byte, oldTag, newTag string) (string, error) {
root, err := s.parse(content)
if err != nil {
return "", err
}
// Walk the AST to find and rename tag nodes
err = gast.Walk(root, func(n gast.Node, entering bool) (gast.WalkStatus, error) {
if !entering {
return gast.WalkContinue, nil
}
// Check for custom TagNode and rename if it matches
if tagNode, ok := n.(*mast.TagNode); ok {
if string(tagNode.Tag) == oldTag {
tagNode.Tag = []byte(newTag)
}
}
return gast.WalkContinue, nil
})
if err != nil {
return "", err
}
// Render back to markdown using the already-parsed AST
mdRenderer := renderer.NewMarkdownRenderer()
return mdRenderer.Render(root, content), nil
}
// uniquePreserveCase returns unique strings from input while preserving case.
func uniquePreserveCase(strs []string) []string {
seen := make(map[string]struct{})
var result []string
for _, s := range strs {
if _, exists := seen[s]; !exists {
seen[s] = struct{}{}
result = append(result, s)
}
}
return result
}
// truncateAtWord truncates a string at the last word boundary before maxLength.
// maxLength is treated as a rune (character) count to properly handle UTF-8 multi-byte characters.
func truncateAtWord(s string, maxLength int) string {
// Convert to runes to properly handle multi-byte UTF-8 characters
runes := []rune(s)
if len(runes) <= maxLength {
return s
}
// Truncate to max length (by character count, not byte count)
truncated := string(runes[:maxLength])
// Find last space to avoid cutting in the middle of a word
lastSpace := strings.LastIndexAny(truncated, " \t\n\r")
if lastSpace > 0 {
truncated = truncated[:lastSpace]
}
return truncated + " ..."
}

View File

@@ -0,0 +1,448 @@
package markdown
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewService(t *testing.T) {
svc := NewService()
assert.NotNil(t, svc)
}
func TestValidateContent(t *testing.T) {
svc := NewService()
tests := []struct {
name string
content string
wantErr bool
}{
{
name: "valid markdown",
content: "# Hello\n\nThis is **bold** text.",
wantErr: false,
},
{
name: "empty content",
content: "",
wantErr: false,
},
{
name: "complex markdown",
content: "# Title\n\n- List item 1\n- List item 2\n\n```go\ncode block\n```",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := svc.ValidateContent([]byte(tt.content))
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestGenerateSnippet(t *testing.T) {
svc := NewService()
tests := []struct {
name string
content string
maxLength int
expected string
}{
{
name: "simple text",
content: "Hello world",
maxLength: 100,
expected: "Hello world",
},
{
name: "text with formatting",
content: "This is **bold** and *italic* text.",
maxLength: 100,
expected: "This is bold and italic text.",
},
{
name: "truncate long text",
content: "This is a very long piece of text that should be truncated at a word boundary.",
maxLength: 30,
expected: "This is a very long piece of ...",
},
{
name: "heading and paragraph",
content: "# My Title\n\nThis is the first paragraph.",
maxLength: 100,
expected: "My Title This is the first paragraph.",
},
{
name: "code block removed",
content: "Text before\n\n```go\ncode\n```\n\nText after",
maxLength: 100,
expected: "Text before Text after",
},
{
name: "list items",
content: "- Item 1\n- Item 2\n- Item 3",
maxLength: 100,
expected: "Item 1 Item 2 Item 3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
snippet, err := svc.GenerateSnippet([]byte(tt.content), tt.maxLength)
require.NoError(t, err)
assert.Equal(t, tt.expected, snippet)
})
}
}
func TestExtractProperties(t *testing.T) {
tests := []struct {
name string
content string
hasLink bool
hasCode bool
hasTasks bool
hasInc bool
}{
{
name: "plain text",
content: "Just plain text",
hasLink: false,
hasCode: false,
hasTasks: false,
hasInc: false,
},
{
name: "with link",
content: "Check out [this link](https://example.com)",
hasLink: true,
hasCode: false,
hasTasks: false,
hasInc: false,
},
{
name: "with inline code",
content: "Use `console.log()` to debug",
hasLink: false,
hasCode: true,
hasTasks: false,
hasInc: false,
},
{
name: "with code block",
content: "```go\nfunc main() {}\n```",
hasLink: false,
hasCode: true,
hasTasks: false,
hasInc: false,
},
{
name: "with completed task",
content: "- [x] Completed task",
hasLink: false,
hasCode: false,
hasTasks: true,
hasInc: false,
},
{
name: "with incomplete task",
content: "- [ ] Todo item",
hasLink: false,
hasCode: false,
hasTasks: true,
hasInc: true,
},
{
name: "mixed tasks",
content: "- [x] Done\n- [ ] Not done",
hasLink: false,
hasCode: false,
hasTasks: true,
hasInc: true,
},
{
name: "everything",
content: "# Title\n\n[Link](url)\n\n`code`\n\n- [ ] Task",
hasLink: true,
hasCode: true,
hasTasks: true,
hasInc: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := NewService()
props, err := svc.ExtractProperties([]byte(tt.content))
require.NoError(t, err)
assert.Equal(t, tt.hasLink, props.HasLink, "HasLink")
assert.Equal(t, tt.hasCode, props.HasCode, "HasCode")
assert.Equal(t, tt.hasTasks, props.HasTaskList, "HasTaskList")
assert.Equal(t, tt.hasInc, props.HasIncompleteTasks, "HasIncompleteTasks")
})
}
}
func TestExtractTags(t *testing.T) {
tests := []struct {
name string
content string
withExt bool
expected []string
}{
{
name: "no tags",
content: "Just plain text",
withExt: false,
expected: []string{},
},
{
name: "single tag",
content: "Text with #tag",
withExt: true,
expected: []string{"tag"},
},
{
name: "multiple tags",
content: "Text with #tag1 and #tag2",
withExt: true,
expected: []string{"tag1", "tag2"},
},
{
name: "duplicate tags",
content: "#work is important. #Work #WORK",
withExt: true,
expected: []string{"work", "Work", "WORK"},
},
{
name: "tags with hyphens and underscores",
content: "Tags: #work-notes #2024_plans",
withExt: true,
expected: []string{"work-notes", "2024_plans"},
},
{
name: "tags at end of sentence",
content: "This is important #urgent.",
withExt: true,
expected: []string{"urgent"},
},
{
name: "headings not tags",
content: "## Heading\n\n# Title\n\nText with #realtag",
withExt: true,
expected: []string{"realtag"},
},
{
name: "numeric tag",
content: "Issue #123",
withExt: true,
expected: []string{"123"},
},
{
name: "tag in list",
content: "- Item 1 #todo\n- Item 2 #done",
withExt: true,
expected: []string{"todo", "done"},
},
{
name: "no extension enabled",
content: "Text with #tag",
withExt: false,
expected: []string{},
},
{
name: "Chinese tag",
content: "Text with #测试",
withExt: true,
expected: []string{"测试"},
},
{
name: "Chinese tag followed by punctuation",
content: "Text #测试。 More text",
withExt: true,
expected: []string{"测试"},
},
{
name: "mixed Chinese and ASCII tag",
content: "#测试test123 content",
withExt: true,
expected: []string{"测试test123"},
},
{
name: "Japanese tag",
content: "#日本語 content",
withExt: true,
expected: []string{"日本語"},
},
{
name: "Korean tag",
content: "#한국어 content",
withExt: true,
expected: []string{"한국어"},
},
{
name: "hierarchical tag with Chinese",
content: "#work/测试/项目",
withExt: true,
expected: []string{"work/测试/项目"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var svc Service
if tt.withExt {
svc = NewService(WithTagExtension())
} else {
svc = NewService()
}
tags, err := svc.ExtractTags([]byte(tt.content))
require.NoError(t, err)
assert.ElementsMatch(t, tt.expected, tags)
})
}
}
func TestUniquePreserveCase(t *testing.T) {
tests := []struct {
name string
input []string
expected []string
}{
{
name: "empty",
input: []string{},
expected: []string{},
},
{
name: "unique items",
input: []string{"tag1", "tag2", "tag3"},
expected: []string{"tag1", "tag2", "tag3"},
},
{
name: "duplicates",
input: []string{"tag", "TAG", "Tag"},
expected: []string{"tag", "TAG", "Tag"},
},
{
name: "mixed",
input: []string{"Work", "work", "Important", "work"},
expected: []string{"Work", "work", "Important"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := uniquePreserveCase(tt.input)
assert.ElementsMatch(t, tt.expected, result)
})
}
}
func TestTruncateAtWord(t *testing.T) {
tests := []struct {
name string
input string
maxLength int
expected string
}{
{
name: "no truncation needed",
input: "short",
maxLength: 10,
expected: "short",
},
{
name: "exact length",
input: "exactly ten",
maxLength: 11,
expected: "exactly ten",
},
{
name: "truncate at word",
input: "this is a long sentence",
maxLength: 10,
expected: "this is a ...",
},
{
name: "truncate very long word",
input: "supercalifragilisticexpialidocious",
maxLength: 10,
expected: "supercalif ...",
},
{
name: "CJK characters without spaces",
input: "这是一个很长的中文句子没有空格的情况下也要正确处理",
maxLength: 15,
expected: "这是一个很长的中文句子没有空格 ...",
},
{
name: "mixed CJK and Latin",
input: "这是中文mixed with English文字",
maxLength: 10,
expected: "这是中文mixed ...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncateAtWord(tt.input, tt.maxLength)
assert.Equal(t, tt.expected, result)
})
}
}
// Benchmark tests.
func BenchmarkGenerateSnippet(b *testing.B) {
svc := NewService()
content := []byte(`# Large Document
This is a large document with multiple paragraphs and formatting.
## Section 1
Here is some **bold** text and *italic* text with [links](https://example.com).
- List item 1
- List item 2
- List item 3
## Section 2
More content here with ` + "`inline code`" + ` and other elements.
` + "```go\nfunc example() {\n return true\n}\n```")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := svc.GenerateSnippet(content, 200)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkExtractProperties(b *testing.B) {
svc := NewService()
content := []byte("# Title\n\n[Link](url)\n\n`code`\n\n- [ ] Task\n- [x] Done")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := svc.ExtractProperties(content)
if err != nil {
b.Fatal(err)
}
}
}

View File

@@ -0,0 +1,139 @@
package parser
import (
"unicode"
"unicode/utf8"
gast "github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/parser"
"github.com/yuin/goldmark/text"
mast "github.com/usememos/memos/plugin/markdown/ast"
)
const (
// MaxTagLength defines the maximum number of runes allowed in a tag.
MaxTagLength = 100
)
type tagParser struct{}
// NewTagParser creates a new inline parser for #tag syntax.
func NewTagParser() parser.InlineParser {
return &tagParser{}
}
// Trigger returns the characters that trigger this parser.
func (*tagParser) Trigger() []byte {
return []byte{'#'}
}
// isValidTagRune checks if a Unicode rune is valid in a tag.
// Uses Unicode categories for proper international character support.
func isValidTagRune(r rune) bool {
// Allow Unicode letters (any script: Latin, CJK, Arabic, Cyrillic, etc.)
if unicode.IsLetter(r) {
return true
}
// Allow Unicode digits
if unicode.IsNumber(r) {
return true
}
// Allow emoji and symbols (So category: Symbol, Other)
// This includes emoji, which are essential for social media-style tagging
if unicode.IsSymbol(r) {
return true
}
// Allow specific ASCII symbols for tag structure
// Underscore: word separation (snake_case)
// Hyphen: word separation (kebab-case)
// Forward slash: hierarchical tags (category/subcategory)
// Ampersand: compound tags (science&tech)
if r == '_' || r == '-' || r == '/' || r == '&' {
return true
}
return false
}
// Parse parses #tag syntax using Unicode-aware validation.
// Tags support international characters and follow these rules:
// - Must start with # followed by valid tag characters
// - Valid characters: Unicode letters, Unicode digits, underscore (_), hyphen (-), forward slash (/)
// - Maximum length: 100 runes (Unicode characters)
// - Stops at: whitespace, punctuation, or other invalid characters
func (*tagParser) Parse(_ gast.Node, block text.Reader, _ parser.Context) gast.Node {
line, _ := block.PeekLine()
// Must start with #
if len(line) == 0 || line[0] != '#' {
return nil
}
// Check if it's a heading (## or space after #)
if len(line) > 1 {
if line[1] == '#' {
// It's a heading (##), not a tag
return nil
}
if line[1] == ' ' {
// Space after # - heading or just a hash
return nil
}
} else {
// Just a lone #
return nil
}
// Parse tag using UTF-8 aware rune iteration
tagStart := 1
pos := tagStart
runeCount := 0
for pos < len(line) {
r, size := utf8.DecodeRune(line[pos:])
// Stop at invalid UTF-8
if r == utf8.RuneError && size == 1 {
break
}
// Validate character using Unicode categories
if !isValidTagRune(r) {
break
}
// Enforce max length (by rune count, not byte count)
runeCount++
if runeCount > MaxTagLength {
break
}
pos += size
}
// Must have at least one character after #
if pos <= tagStart {
return nil
}
// Extract tag (without #)
tagName := line[tagStart:pos]
// Make a copy of the tag name
tagCopy := make([]byte, len(tagName))
copy(tagCopy, tagName)
// Advance reader
block.Advance(pos)
// Create node
node := &mast.TagNode{
Tag: tagCopy,
}
return node
}

View File

@@ -0,0 +1,251 @@
package parser
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/yuin/goldmark/parser"
"github.com/yuin/goldmark/text"
mast "github.com/usememos/memos/plugin/markdown/ast"
)
func TestTagParser(t *testing.T) {
tests := []struct {
name string
input string
expectedTag string
shouldParse bool
}{
{
name: "basic tag",
input: "#tag",
expectedTag: "tag",
shouldParse: true,
},
{
name: "tag with hyphen",
input: "#work-notes",
expectedTag: "work-notes",
shouldParse: true,
},
{
name: "tag with ampersand",
input: "#science&tech",
expectedTag: "science&tech",
shouldParse: true,
},
{
name: "tag with underscore",
input: "#2024_plans",
expectedTag: "2024_plans",
shouldParse: true,
},
{
name: "numeric tag",
input: "#123",
expectedTag: "123",
shouldParse: true,
},
{
name: "tag followed by space",
input: "#tag ",
expectedTag: "tag",
shouldParse: true,
},
{
name: "tag followed by punctuation",
input: "#tag.",
expectedTag: "tag",
shouldParse: true,
},
{
name: "tag in sentence",
input: "#important task",
expectedTag: "important",
shouldParse: true,
},
{
name: "heading (##)",
input: "## Heading",
expectedTag: "",
shouldParse: false,
},
{
name: "space after hash",
input: "# heading",
expectedTag: "",
shouldParse: false,
},
{
name: "lone hash",
input: "#",
expectedTag: "",
shouldParse: false,
},
{
name: "hash with space",
input: "# ",
expectedTag: "",
shouldParse: false,
},
{
name: "special characters",
input: "#tag@special",
expectedTag: "tag",
shouldParse: true,
},
{
name: "mixed case",
input: "#WorkNotes",
expectedTag: "WorkNotes",
shouldParse: true,
},
{
name: "hierarchical tag with slash",
input: "#tag1/subtag",
expectedTag: "tag1/subtag",
shouldParse: true,
},
{
name: "hierarchical tag with multiple levels",
input: "#tag1/subtag/subtag2",
expectedTag: "tag1/subtag/subtag2",
shouldParse: true,
},
{
name: "hierarchical tag followed by space",
input: "#work/notes ",
expectedTag: "work/notes",
shouldParse: true,
},
{
name: "hierarchical tag followed by punctuation",
input: "#project/2024.",
expectedTag: "project/2024",
shouldParse: true,
},
{
name: "hierarchical tag with numbers and dashes",
input: "#work-log/2024/q1",
expectedTag: "work-log/2024/q1",
shouldParse: true,
},
{
name: "Chinese characters",
input: "#测试",
expectedTag: "测试",
shouldParse: true,
},
{
name: "Chinese tag followed by space",
input: "#测试 some text",
expectedTag: "测试",
shouldParse: true,
},
{
name: "Chinese tag followed by punctuation",
input: "#测试。",
expectedTag: "测试",
shouldParse: true,
},
{
name: "mixed Chinese and ASCII",
input: "#测试test123",
expectedTag: "测试test123",
shouldParse: true,
},
{
name: "Japanese characters",
input: "#テスト",
expectedTag: "テスト",
shouldParse: true,
},
{
name: "Korean characters",
input: "#테스트",
expectedTag: "테스트",
shouldParse: true,
},
{
name: "emoji",
input: "#test🚀",
expectedTag: "test🚀",
shouldParse: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewTagParser()
reader := text.NewReader([]byte(tt.input))
ctx := parser.NewContext()
node := p.Parse(nil, reader, ctx)
if tt.shouldParse {
require.NotNil(t, node, "Expected tag to be parsed")
require.IsType(t, &mast.TagNode{}, node)
tagNode, ok := node.(*mast.TagNode)
require.True(t, ok, "Expected node to be *mast.TagNode")
assert.Equal(t, tt.expectedTag, string(tagNode.Tag))
} else {
assert.Nil(t, node, "Expected tag NOT to be parsed")
}
})
}
}
func TestTagParser_Trigger(t *testing.T) {
p := NewTagParser()
triggers := p.Trigger()
assert.Equal(t, []byte{'#'}, triggers)
}
func TestTagParser_MultipleTags(t *testing.T) {
// Test that parser correctly handles multiple tags in sequence
input := "#tag1 #tag2"
p := NewTagParser()
reader := text.NewReader([]byte(input))
ctx := parser.NewContext()
// Parse first tag
node1 := p.Parse(nil, reader, ctx)
require.NotNil(t, node1)
tagNode1, ok := node1.(*mast.TagNode)
require.True(t, ok, "Expected node1 to be *mast.TagNode")
assert.Equal(t, "tag1", string(tagNode1.Tag))
// Advance past the space
reader.Advance(1)
// Parse second tag
node2 := p.Parse(nil, reader, ctx)
require.NotNil(t, node2)
tagNode2, ok := node2.(*mast.TagNode)
require.True(t, ok, "Expected node2 to be *mast.TagNode")
assert.Equal(t, "tag2", string(tagNode2.Tag))
}
func TestTagNode_Kind(t *testing.T) {
node := &mast.TagNode{
Tag: []byte("test"),
}
assert.Equal(t, mast.KindTag, node.Kind())
}
func TestTagNode_Dump(t *testing.T) {
node := &mast.TagNode{
Tag: []byte("test"),
}
// Should not panic
assert.NotPanics(t, func() {
node.Dump([]byte("#test"), 0)
})
}

View File

@@ -0,0 +1,266 @@
package renderer
import (
"bytes"
"fmt"
"strings"
gast "github.com/yuin/goldmark/ast"
east "github.com/yuin/goldmark/extension/ast"
mast "github.com/usememos/memos/plugin/markdown/ast"
)
// MarkdownRenderer renders goldmark AST back to markdown text.
type MarkdownRenderer struct {
buf *bytes.Buffer
}
// NewMarkdownRenderer creates a new markdown renderer.
func NewMarkdownRenderer() *MarkdownRenderer {
return &MarkdownRenderer{
buf: &bytes.Buffer{},
}
}
// Render renders the AST node to markdown and returns the result.
func (r *MarkdownRenderer) Render(node gast.Node, source []byte) string {
r.buf.Reset()
r.renderNode(node, source, 0)
return r.buf.String()
}
// renderNode renders a single node and its children.
func (r *MarkdownRenderer) renderNode(node gast.Node, source []byte, depth int) {
switch n := node.(type) {
case *gast.Document:
r.renderChildren(n, source, depth)
case *gast.Paragraph:
r.renderChildren(n, source, depth)
if node.NextSibling() != nil {
r.buf.WriteString("\n\n")
}
case *gast.Text:
// Text nodes store their content as segments in the source
segment := n.Segment
r.buf.Write(segment.Value(source))
if n.SoftLineBreak() {
r.buf.WriteByte('\n')
} else if n.HardLineBreak() {
r.buf.WriteString(" \n")
}
case *gast.CodeSpan:
r.buf.WriteByte('`')
r.renderChildren(n, source, depth)
r.buf.WriteByte('`')
case *gast.Emphasis:
symbol := "*"
if n.Level == 2 {
symbol = "**"
}
r.buf.WriteString(symbol)
r.renderChildren(n, source, depth)
r.buf.WriteString(symbol)
case *gast.Link:
r.buf.WriteString("[")
r.renderChildren(n, source, depth)
r.buf.WriteString("](")
r.buf.Write(n.Destination)
if len(n.Title) > 0 {
r.buf.WriteString(` "`)
r.buf.Write(n.Title)
r.buf.WriteString(`"`)
}
r.buf.WriteString(")")
case *gast.AutoLink:
url := n.URL(source)
if n.AutoLinkType == gast.AutoLinkEmail {
r.buf.WriteString("<")
r.buf.Write(url)
r.buf.WriteString(">")
} else {
r.buf.Write(url)
}
case *gast.Image:
r.buf.WriteString("![")
r.renderChildren(n, source, depth)
r.buf.WriteString("](")
r.buf.Write(n.Destination)
if len(n.Title) > 0 {
r.buf.WriteString(` "`)
r.buf.Write(n.Title)
r.buf.WriteString(`"`)
}
r.buf.WriteString(")")
case *gast.Heading:
r.buf.WriteString(strings.Repeat("#", n.Level))
r.buf.WriteByte(' ')
r.renderChildren(n, source, depth)
if node.NextSibling() != nil {
r.buf.WriteString("\n\n")
}
case *gast.CodeBlock, *gast.FencedCodeBlock:
r.renderCodeBlock(n, source)
case *gast.Blockquote:
// Render each child line with "> " prefix
r.renderBlockquote(n, source, depth)
if node.NextSibling() != nil {
r.buf.WriteString("\n\n")
}
case *gast.List:
r.renderChildren(n, source, depth)
if node.NextSibling() != nil {
r.buf.WriteString("\n\n")
}
case *gast.ListItem:
r.renderListItem(n, source, depth)
case *gast.ThematicBreak:
r.buf.WriteString("---")
if node.NextSibling() != nil {
r.buf.WriteString("\n\n")
}
case *east.Strikethrough:
r.buf.WriteString("~~")
r.renderChildren(n, source, depth)
r.buf.WriteString("~~")
case *east.TaskCheckBox:
if n.IsChecked {
r.buf.WriteString("[x] ")
} else {
r.buf.WriteString("[ ] ")
}
case *east.Table:
r.renderTable(n, source)
if node.NextSibling() != nil {
r.buf.WriteString("\n\n")
}
// Custom Memos nodes
case *mast.TagNode:
r.buf.WriteByte('#')
r.buf.Write(n.Tag)
default:
// For unknown nodes, try to render children
r.renderChildren(n, source, depth)
}
}
// renderChildren renders all children of a node.
func (r *MarkdownRenderer) renderChildren(node gast.Node, source []byte, depth int) {
child := node.FirstChild()
for child != nil {
r.renderNode(child, source, depth+1)
child = child.NextSibling()
}
}
// renderCodeBlock renders a code block.
func (r *MarkdownRenderer) renderCodeBlock(node gast.Node, source []byte) {
if fenced, ok := node.(*gast.FencedCodeBlock); ok {
// Fenced code block with language
r.buf.WriteString("```")
if lang := fenced.Language(source); len(lang) > 0 {
r.buf.Write(lang)
}
r.buf.WriteByte('\n')
// Write all lines
lines := fenced.Lines()
for i := 0; i < lines.Len(); i++ {
line := lines.At(i)
r.buf.Write(line.Value(source))
}
r.buf.WriteString("```")
if node.NextSibling() != nil {
r.buf.WriteString("\n\n")
}
} else if codeBlock, ok := node.(*gast.CodeBlock); ok {
// Indented code block
lines := codeBlock.Lines()
for i := 0; i < lines.Len(); i++ {
line := lines.At(i)
r.buf.WriteString(" ")
r.buf.Write(line.Value(source))
}
if node.NextSibling() != nil {
r.buf.WriteString("\n\n")
}
}
}
// renderBlockquote renders a blockquote with "> " prefix.
func (r *MarkdownRenderer) renderBlockquote(node *gast.Blockquote, source []byte, depth int) {
// Create a temporary buffer for the blockquote content
tempBuf := &bytes.Buffer{}
tempRenderer := &MarkdownRenderer{buf: tempBuf}
tempRenderer.renderChildren(node, source, depth)
// Add "> " prefix to each line
content := tempBuf.String()
lines := strings.Split(strings.TrimRight(content, "\n"), "\n")
for i, line := range lines {
r.buf.WriteString("> ")
r.buf.WriteString(line)
if i < len(lines)-1 {
r.buf.WriteByte('\n')
}
}
}
// renderListItem renders a list item with proper indentation and markers.
func (r *MarkdownRenderer) renderListItem(node *gast.ListItem, source []byte, depth int) {
parent := node.Parent()
list, ok := parent.(*gast.List)
if !ok {
r.renderChildren(node, source, depth)
return
}
// Add indentation only for nested lists
// Document=0, List=1, ListItem=2 (no indent), nested ListItem=3+ (indent)
if depth > 2 {
indent := strings.Repeat(" ", depth-2)
r.buf.WriteString(indent)
}
// Add list marker
if list.IsOrdered() {
fmt.Fprintf(r.buf, "%d. ", list.Start)
list.Start++ // Increment for next item
} else {
r.buf.WriteString("- ")
}
// Render content
r.renderChildren(node, source, depth)
// Add newline if there's a next sibling
if node.NextSibling() != nil {
r.buf.WriteByte('\n')
}
}
// renderTable renders a table in markdown format.
func (r *MarkdownRenderer) renderTable(table *east.Table, source []byte) {
// This is a simplified table renderer
// A full implementation would need to handle alignment, etc.
r.renderChildren(table, source, 0)
}

View File

@@ -0,0 +1,176 @@
package renderer
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/extension"
"github.com/yuin/goldmark/parser"
"github.com/yuin/goldmark/text"
"github.com/usememos/memos/plugin/markdown/extensions"
)
func TestMarkdownRenderer(t *testing.T) {
// Create goldmark instance with all extensions
md := goldmark.New(
goldmark.WithExtensions(
extension.GFM,
extensions.TagExtension,
),
goldmark.WithParserOptions(
parser.WithAutoHeadingID(),
),
)
tests := []struct {
name string
input string
expected string
}{
{
name: "simple text",
input: "Hello world",
expected: "Hello world",
},
{
name: "paragraph with newlines",
input: "First paragraph\n\nSecond paragraph",
expected: "First paragraph\n\nSecond paragraph",
},
{
name: "emphasis",
input: "This is *italic* and **bold** text",
expected: "This is *italic* and **bold** text",
},
{
name: "headings",
input: "# Heading 1\n\n## Heading 2\n\n### Heading 3",
expected: "# Heading 1\n\n## Heading 2\n\n### Heading 3",
},
{
name: "link",
input: "Check [this link](https://example.com)",
expected: "Check [this link](https://example.com)",
},
{
name: "image",
input: "![alt text](image.png)",
expected: "![alt text](image.png)",
},
{
name: "code inline",
input: "This is `inline code` here",
expected: "This is `inline code` here",
},
{
name: "code block fenced",
input: "```go\nfunc main() {\n}\n```",
expected: "```go\nfunc main() {\n}\n```",
},
{
name: "unordered list",
input: "- Item 1\n- Item 2\n- Item 3",
expected: "- Item 1\n- Item 2\n- Item 3",
},
{
name: "ordered list",
input: "1. First\n2. Second\n3. Third",
expected: "1. First\n2. Second\n3. Third",
},
{
name: "blockquote",
input: "> This is a quote\n> Second line",
expected: "> This is a quote\n> Second line",
},
{
name: "horizontal rule",
input: "Text before\n\n---\n\nText after",
expected: "Text before\n\n---\n\nText after",
},
{
name: "strikethrough",
input: "This is ~~deleted~~ text",
expected: "This is ~~deleted~~ text",
},
{
name: "task list",
input: "- [x] Completed task\n- [ ] Incomplete task",
expected: "- [x] Completed task\n- [ ] Incomplete task",
},
{
name: "tag",
input: "This has #tag in it",
expected: "This has #tag in it",
},
{
name: "multiple tags",
input: "#work #important meeting notes",
expected: "#work #important meeting notes",
},
{
name: "complex mixed content",
input: "# Meeting Notes\n\n**Date**: 2024-01-01\n\n## Attendees\n- Alice\n- Bob\n\n## Discussion\n\nWe discussed #project status.\n\n```python\nprint('hello')\n```",
expected: "# Meeting Notes\n\n**Date**: 2024-01-01\n\n## Attendees\n\n- Alice\n- Bob\n\n## Discussion\n\nWe discussed #project status.\n\n```python\nprint('hello')\n```",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Parse the input
source := []byte(tt.input)
reader := text.NewReader(source)
doc := md.Parser().Parse(reader)
require.NotNil(t, doc)
// Render back to markdown
renderer := NewMarkdownRenderer()
result := renderer.Render(doc, source)
// For debugging
if result != tt.expected {
t.Logf("Input: %q", tt.input)
t.Logf("Expected: %q", tt.expected)
t.Logf("Got: %q", result)
}
assert.Equal(t, tt.expected, result)
})
}
}
func TestMarkdownRendererPreservesStructure(t *testing.T) {
// Test that parsing and rendering preserves structure
md := goldmark.New(
goldmark.WithExtensions(
extension.GFM,
extensions.TagExtension,
),
)
inputs := []string{
"# Title\n\nParagraph",
"**Bold** and *italic*",
"- List\n- Items",
"#tag #another",
"> Quote",
}
renderer := NewMarkdownRenderer()
for _, input := range inputs {
t.Run(input, func(t *testing.T) {
source := []byte(input)
reader := text.NewReader(source)
doc := md.Parser().Parse(reader)
result := renderer.Render(doc, source)
// The result should be structurally similar
// (may have minor formatting differences)
assert.NotEmpty(t, result)
})
}
}

367
plugin/scheduler/README.md Normal file
View File

@@ -0,0 +1,367 @@
# Scheduler Plugin
A production-ready, GitHub Actions-inspired cron job scheduler for Go.
## Features
- **Standard Cron Syntax**: Supports both 5-field and 6-field (with seconds) cron expressions
- **Timezone-Aware**: Explicit timezone handling to avoid DST surprises
- **Middleware Pattern**: Composable job wrappers for logging, metrics, panic recovery, timeouts
- **Graceful Shutdown**: Jobs complete cleanly or cancel when context expires
- **Zero Dependencies**: Core functionality uses only the standard library
- **Type-Safe**: Strong typing with clear error messages
- **Well-Tested**: Comprehensive test coverage
## Installation
This package is included with Memos. No separate installation required.
## Quick Start
```go
package main
import (
"context"
"fmt"
"github.com/usememos/memos/plugin/scheduler"
)
func main() {
s := scheduler.New()
s.Register(&scheduler.Job{
Name: "daily-cleanup",
Schedule: "0 2 * * *", // 2 AM daily
Handler: func(ctx context.Context) error {
fmt.Println("Running cleanup...")
return nil
},
})
s.Start()
defer s.Stop(context.Background())
// Keep running...
select {}
}
```
## Cron Expression Format
### 5-Field Format (Standard)
```
┌───────────── minute (0 - 59)
│ ┌───────────── hour (0 - 23)
│ │ ┌───────────── day of month (1 - 31)
│ │ │ ┌───────────── month (1 - 12)
│ │ │ │ ┌───────────── day of week (0 - 7) (Sunday = 0 or 7)
│ │ │ │ │
* * * * *
```
### 6-Field Format (With Seconds)
```
┌───────────── second (0 - 59)
│ ┌───────────── minute (0 - 59)
│ │ ┌───────────── hour (0 - 23)
│ │ │ ┌───────────── day of month (1 - 31)
│ │ │ │ ┌───────────── month (1 - 12)
│ │ │ │ │ ┌───────────── day of week (0 - 7)
│ │ │ │ │ │
* * * * * *
```
### Special Characters
- `*` - Any value (every minute, every hour, etc.)
- `,` - List of values: `1,15,30` (1st, 15th, and 30th)
- `-` - Range: `9-17` (9 AM through 5 PM)
- `/` - Step: `*/15` (every 15 units)
### Common Examples
| Schedule | Description |
|----------|-------------|
| `* * * * *` | Every minute |
| `0 * * * *` | Every hour |
| `0 0 * * *` | Daily at midnight |
| `0 9 * * 1-5` | Weekdays at 9 AM |
| `*/15 * * * *` | Every 15 minutes |
| `0 0 1 * *` | First day of every month |
| `0 0 * * 0` | Every Sunday at midnight |
| `30 14 * * *` | Every day at 2:30 PM |
## Timezone Support
```go
// Global timezone for all jobs
s := scheduler.New(
scheduler.WithTimezone("America/New_York"),
)
// Per-job timezone (overrides global)
s.Register(&scheduler.Job{
Name: "tokyo-report",
Schedule: "0 9 * * *", // 9 AM Tokyo time
Timezone: "Asia/Tokyo",
Handler: func(ctx context.Context) error {
// Runs at 9 AM in Tokyo
return nil
},
})
```
**Important**: Always use IANA timezone names (`America/New_York`, not `EST`).
## Middleware
Middleware wraps job handlers to add cross-cutting behavior. Multiple middleware can be chained together.
### Built-in Middleware
#### Recovery (Panic Handling)
```go
s := scheduler.New(
scheduler.WithMiddleware(
scheduler.Recovery(func(jobName string, r interface{}) {
log.Printf("Job %s panicked: %v", jobName, r)
}),
),
)
```
#### Logging
```go
type Logger interface {
Info(msg string, args ...interface{})
Error(msg string, args ...interface{})
}
s := scheduler.New(
scheduler.WithMiddleware(
scheduler.Logging(myLogger),
),
)
```
#### Timeout
```go
s := scheduler.New(
scheduler.WithMiddleware(
scheduler.Timeout(5 * time.Minute),
),
)
```
### Combining Middleware
```go
s := scheduler.New(
scheduler.WithMiddleware(
scheduler.Recovery(panicHandler),
scheduler.Logging(logger),
scheduler.Timeout(10 * time.Minute),
),
)
```
**Order matters**: Middleware are applied left-to-right. In the example above:
1. Recovery (outermost) catches panics from everything
2. Logging logs the execution
3. Timeout (innermost) wraps the actual handler
### Custom Middleware
```go
func Metrics(recorder MetricsRecorder) scheduler.Middleware {
return func(next scheduler.JobHandler) scheduler.JobHandler {
return func(ctx context.Context) error {
start := time.Now()
err := next(ctx)
duration := time.Since(start)
jobName := scheduler.GetJobName(ctx)
recorder.Record(jobName, duration, err)
return err
}
}
}
```
## Graceful Shutdown
Always use `Stop()` with a context to allow jobs to finish cleanly:
```go
// Give jobs up to 30 seconds to complete
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.Stop(ctx); err != nil {
log.Printf("Shutdown error: %v", err)
}
```
Jobs should respect context cancellation:
```go
Handler: func(ctx context.Context) error {
for i := 0; i < 100; i++ {
select {
case <-ctx.Done():
return ctx.Err() // Canceled
default:
// Do work
}
}
return nil
}
```
## Best Practices
### 1. Always Name Your Jobs
Names are used for logging, metrics, and debugging:
```go
Name: "user-cleanup-job" // Good
Name: "job1" // Bad
```
### 2. Add Descriptions and Tags
```go
s.Register(&scheduler.Job{
Name: "stale-session-cleanup",
Description: "Removes user sessions older than 30 days",
Tags: []string{"maintenance", "security"},
Schedule: "0 3 * * *",
Handler: cleanupSessions,
})
```
### 3. Use Appropriate Middleware
Always include Recovery and Logging in production:
```go
scheduler.New(
scheduler.WithMiddleware(
scheduler.Recovery(logPanic),
scheduler.Logging(logger),
),
)
```
### 4. Avoid Scheduling Exactly on the Hour
Many systems schedule jobs at `:00`, causing load spikes. Stagger your jobs:
```go
"5 2 * * *" // 2:05 AM (good)
"0 2 * * *" // 2:00 AM (often overloaded)
```
### 5. Make Jobs Idempotent
Jobs may run multiple times (crash recovery, etc.). Design them to be safely re-runnable:
```go
Handler: func(ctx context.Context) error {
// Use unique constraint or check-before-insert
db.Exec("INSERT IGNORE INTO processed_items ...")
return nil
}
```
### 6. Handle Timezones Explicitly
Always specify timezone for business-hour jobs:
```go
Timezone: "America/New_York" // Good
// Timezone: "" // Bad (defaults to UTC)
```
### 7. Test Your Cron Expressions
Use a cron expression calculator before deploying:
- [crontab.guru](https://crontab.guru/)
- Write unit tests with the parser
## Testing Jobs
Test job handlers independently of the scheduler:
```go
func TestCleanupJob(t *testing.T) {
ctx := context.Background()
err := cleanupHandler(ctx)
if err != nil {
t.Fatalf("cleanup failed: %v", err)
}
// Verify cleanup occurred
}
```
Test schedule parsing:
```go
func TestScheduleParsing(t *testing.T) {
job := &scheduler.Job{
Name: "test",
Schedule: "0 2 * * *",
Handler: func(ctx context.Context) error { return nil },
}
if err := job.Validate(); err != nil {
t.Fatalf("invalid job: %v", err)
}
}
```
## Comparison to Other Solutions
| Feature | scheduler | robfig/cron | github.com/go-co-op/gocron |
|---------|-----------|-------------|----------------------------|
| Standard cron syntax | ✅ | ✅ | ✅ |
| Seconds support | ✅ | ✅ | ✅ |
| Timezone support | ✅ | ✅ | ✅ |
| Middleware pattern | ✅ | ⚠️ (basic) | ❌ |
| Graceful shutdown | ✅ | ⚠️ (basic) | ✅ |
| Zero dependencies | ✅ | ❌ | ❌ |
| Job metadata | ✅ | ❌ | ⚠️ (limited) |
## API Reference
See [example_test.go](./example_test.go) for comprehensive examples.
### Core Types
- `Scheduler` - Manages scheduled jobs
- `Job` - Job definition with schedule and handler
- `Middleware` - Function that wraps job handlers
### Functions
- `New(opts ...Option) *Scheduler` - Create new scheduler
- `WithTimezone(tz string) Option` - Set default timezone
- `WithMiddleware(mw ...Middleware) Option` - Add middleware
### Methods
- `Register(job *Job) error` - Add job to scheduler
- `Start() error` - Begin executing jobs
- `Stop(ctx context.Context) error` - Graceful shutdown
## License
This package is part of the Memos project and shares its license.

35
plugin/scheduler/doc.go Normal file
View File

@@ -0,0 +1,35 @@
// Package scheduler provides a GitHub Actions-inspired cron job scheduler.
//
// Features:
// - Standard cron expression syntax (5-field and 6-field formats)
// - Timezone-aware scheduling
// - Middleware pattern for cross-cutting concerns (logging, metrics, recovery)
// - Graceful shutdown with context cancellation
// - Zero external dependencies
//
// Basic usage:
//
// s := scheduler.New()
//
// s.Register(&scheduler.Job{
// Name: "daily-cleanup",
// Schedule: "0 2 * * *", // 2 AM daily
// Handler: func(ctx context.Context) error {
// // Your cleanup logic here
// return nil
// },
// })
//
// s.Start()
// defer s.Stop(context.Background())
//
// With middleware:
//
// s := scheduler.New(
// scheduler.WithTimezone("America/New_York"),
// scheduler.WithMiddleware(
// scheduler.Recovery(),
// scheduler.Logging(),
// ),
// )
package scheduler

View File

@@ -0,0 +1,165 @@
package scheduler_test
import (
"context"
"fmt"
"log/slog"
"os"
"time"
"github.com/usememos/memos/plugin/scheduler"
)
// Example demonstrates basic scheduler usage.
func Example_basic() {
s := scheduler.New()
s.Register(&scheduler.Job{
Name: "hello",
Schedule: "*/5 * * * *", // Every 5 minutes
Description: "Say hello",
Handler: func(_ context.Context) error {
fmt.Println("Hello from scheduler!")
return nil
},
})
s.Start()
defer s.Stop(context.Background())
// Scheduler runs in background
time.Sleep(100 * time.Millisecond)
}
// Example demonstrates timezone-aware scheduling.
func Example_timezone() {
s := scheduler.New(
scheduler.WithTimezone("America/New_York"),
)
s.Register(&scheduler.Job{
Name: "daily-report",
Schedule: "0 9 * * *", // 9 AM in New York
Handler: func(_ context.Context) error {
fmt.Println("Generating daily report...")
return nil
},
})
s.Start()
defer s.Stop(context.Background())
}
// Example demonstrates middleware usage.
func Example_middleware() {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
s := scheduler.New(
scheduler.WithMiddleware(
scheduler.Recovery(func(jobName string, r interface{}) {
logger.Error("Job panicked", "job", jobName, "panic", r)
}),
scheduler.Logging(&slogAdapter{logger}),
scheduler.Timeout(5*time.Minute),
),
)
s.Register(&scheduler.Job{
Name: "data-sync",
Schedule: "0 */2 * * *", // Every 2 hours
Handler: func(_ context.Context) error {
// Your sync logic here
return nil
},
})
s.Start()
defer s.Stop(context.Background())
}
// slogAdapter adapts slog.Logger to scheduler.Logger interface.
type slogAdapter struct {
logger *slog.Logger
}
func (a *slogAdapter) Info(msg string, args ...interface{}) {
a.logger.Info(msg, args...)
}
func (a *slogAdapter) Error(msg string, args ...interface{}) {
a.logger.Error(msg, args...)
}
// Example demonstrates multiple jobs with different schedules.
func Example_multipleJobs() {
s := scheduler.New()
// Cleanup old data every night at 2 AM
s.Register(&scheduler.Job{
Name: "cleanup",
Schedule: "0 2 * * *",
Tags: []string{"maintenance"},
Handler: func(_ context.Context) error {
fmt.Println("Cleaning up old data...")
return nil
},
})
// Health check every 5 minutes
s.Register(&scheduler.Job{
Name: "health-check",
Schedule: "*/5 * * * *",
Tags: []string{"monitoring"},
Handler: func(_ context.Context) error {
fmt.Println("Running health check...")
return nil
},
})
// Weekly backup on Sundays at 1 AM
s.Register(&scheduler.Job{
Name: "weekly-backup",
Schedule: "0 1 * * 0",
Tags: []string{"backup"},
Handler: func(_ context.Context) error {
fmt.Println("Creating weekly backup...")
return nil
},
})
s.Start()
defer s.Stop(context.Background())
}
// Example demonstrates graceful shutdown with timeout.
func Example_gracefulShutdown() {
s := scheduler.New()
s.Register(&scheduler.Job{
Name: "long-running",
Schedule: "* * * * *",
Handler: func(ctx context.Context) error {
select {
case <-time.After(30 * time.Second):
fmt.Println("Job completed")
case <-ctx.Done():
fmt.Println("Job canceled, cleaning up...")
return ctx.Err()
}
return nil
},
})
s.Start()
// Simulate shutdown signal
time.Sleep(5 * time.Second)
// Give jobs 10 seconds to finish
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.Stop(shutdownCtx); err != nil {
fmt.Printf("Shutdown error: %v\n", err)
}
}

View File

@@ -0,0 +1,393 @@
package scheduler_test
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/usememos/memos/plugin/scheduler"
)
// TestRealWorldScenario tests a realistic multi-job scenario.
func TestRealWorldScenario(t *testing.T) {
var (
quickJobCount atomic.Int32
hourlyJobCount atomic.Int32
logEntries []string
logMu sync.Mutex
)
logger := &testLogger{
onInfo: func(msg string, _ ...interface{}) {
logMu.Lock()
logEntries = append(logEntries, fmt.Sprintf("INFO: %s", msg))
logMu.Unlock()
},
onError: func(msg string, _ ...interface{}) {
logMu.Lock()
logEntries = append(logEntries, fmt.Sprintf("ERROR: %s", msg))
logMu.Unlock()
},
}
s := scheduler.New(
scheduler.WithTimezone("UTC"),
scheduler.WithMiddleware(
scheduler.Recovery(func(jobName string, r interface{}) {
t.Logf("Job %s panicked: %v", jobName, r)
}),
scheduler.Logging(logger),
scheduler.Timeout(5*time.Second),
),
)
// Quick job (every second)
s.Register(&scheduler.Job{
Name: "quick-check",
Schedule: "* * * * * *",
Handler: func(_ context.Context) error {
quickJobCount.Add(1)
time.Sleep(100 * time.Millisecond)
return nil
},
})
// Slower job (every 2 seconds)
s.Register(&scheduler.Job{
Name: "slow-process",
Schedule: "*/2 * * * * *",
Handler: func(_ context.Context) error {
hourlyJobCount.Add(1)
time.Sleep(500 * time.Millisecond)
return nil
},
})
// Start scheduler
if err := s.Start(); err != nil {
t.Fatalf("failed to start scheduler: %v", err)
}
// Let it run for 5 seconds
time.Sleep(5 * time.Second)
// Graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.Stop(ctx); err != nil {
t.Fatalf("failed to stop scheduler: %v", err)
}
// Verify execution counts
quick := quickJobCount.Load()
slow := hourlyJobCount.Load()
t.Logf("Quick job ran %d times", quick)
t.Logf("Slow job ran %d times", slow)
if quick < 4 {
t.Errorf("expected quick job to run at least 4 times, ran %d", quick)
}
if slow < 2 {
t.Errorf("expected slow job to run at least 2 times, ran %d", slow)
}
// Verify logging
logMu.Lock()
defer logMu.Unlock()
hasStartLog := false
hasCompleteLog := false
for _, entry := range logEntries {
if contains(entry, "Job started") {
hasStartLog = true
}
if contains(entry, "Job completed") {
hasCompleteLog = true
}
}
if !hasStartLog {
t.Error("expected job start logs")
}
if !hasCompleteLog {
t.Error("expected job completion logs")
}
}
// TestCancellationDuringExecution verifies jobs can be canceled mid-execution.
func TestCancellationDuringExecution(t *testing.T) {
var canceled atomic.Bool
var started atomic.Bool
s := scheduler.New()
s.Register(&scheduler.Job{
Name: "long-job",
Schedule: "* * * * * *",
Handler: func(ctx context.Context) error {
started.Store(true)
// Simulate long-running work
for i := 0; i < 100; i++ {
select {
case <-ctx.Done():
canceled.Store(true)
return ctx.Err()
case <-time.After(100 * time.Millisecond):
// Keep working
}
}
return nil
},
})
if err := s.Start(); err != nil {
t.Fatalf("failed to start: %v", err)
}
// Wait until job starts
for i := 0; i < 30; i++ {
if started.Load() {
break
}
time.Sleep(100 * time.Millisecond)
}
if !started.Load() {
t.Fatal("job did not start within timeout")
}
// Stop with reasonable timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.Stop(ctx); err != nil {
t.Logf("stop returned error (may be expected): %v", err)
}
if !canceled.Load() {
t.Error("expected job to detect cancellation")
}
}
// TestTimezoneHandling verifies timezone-aware scheduling.
func TestTimezoneHandling(t *testing.T) {
// Parse a schedule in a specific timezone
schedule, err := scheduler.ParseCronExpression("0 9 * * *") // 9 AM
if err != nil {
t.Fatalf("failed to parse schedule: %v", err)
}
// Test in New York timezone
nyc, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatalf("failed to load timezone: %v", err)
}
// Current time: 8:30 AM in New York
now := time.Date(2025, 1, 15, 8, 30, 0, 0, nyc)
// Next run should be 9:00 AM same day
next := schedule.Next(now)
expected := time.Date(2025, 1, 15, 9, 0, 0, 0, nyc)
if !next.Equal(expected) {
t.Errorf("next = %v, expected %v", next, expected)
}
// If it's already past 9 AM
now = time.Date(2025, 1, 15, 9, 30, 0, 0, nyc)
next = schedule.Next(now)
expected = time.Date(2025, 1, 16, 9, 0, 0, 0, nyc)
if !next.Equal(expected) {
t.Errorf("next = %v, expected %v", next, expected)
}
}
// TestErrorPropagation verifies error handling.
func TestErrorPropagation(t *testing.T) {
var errorLogged atomic.Bool
logger := &testLogger{
onError: func(msg string, _ ...interface{}) {
if msg == "Job failed" {
errorLogged.Store(true)
}
},
}
s := scheduler.New(
scheduler.WithMiddleware(
scheduler.Logging(logger),
),
)
s.Register(&scheduler.Job{
Name: "failing-job",
Schedule: "* * * * * *",
Handler: func(_ context.Context) error {
return errors.New("intentional error")
},
})
if err := s.Start(); err != nil {
t.Fatalf("failed to start: %v", err)
}
// Let it run once
time.Sleep(1500 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.Stop(ctx); err != nil {
t.Fatalf("failed to stop: %v", err)
}
if !errorLogged.Load() {
t.Error("expected error to be logged")
}
}
// TestPanicRecovery verifies panic recovery middleware.
func TestPanicRecovery(t *testing.T) {
var panicRecovered atomic.Bool
s := scheduler.New(
scheduler.WithMiddleware(
scheduler.Recovery(func(jobName string, r interface{}) {
panicRecovered.Store(true)
t.Logf("Recovered from panic in job %s: %v", jobName, r)
}),
),
)
s.Register(&scheduler.Job{
Name: "panicking-job",
Schedule: "* * * * * *",
Handler: func(_ context.Context) error {
panic("intentional panic for testing")
},
})
if err := s.Start(); err != nil {
t.Fatalf("failed to start: %v", err)
}
// Let it run once
time.Sleep(1500 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.Stop(ctx); err != nil {
t.Fatalf("failed to stop: %v", err)
}
if !panicRecovered.Load() {
t.Error("expected panic to be recovered")
}
}
// TestMultipleJobsWithDifferentSchedules verifies concurrent job execution.
func TestMultipleJobsWithDifferentSchedules(t *testing.T) {
var (
job1Count atomic.Int32
job2Count atomic.Int32
job3Count atomic.Int32
)
s := scheduler.New()
// Job 1: Every second
s.Register(&scheduler.Job{
Name: "job-1sec",
Schedule: "* * * * * *",
Handler: func(_ context.Context) error {
job1Count.Add(1)
return nil
},
})
// Job 2: Every 2 seconds
s.Register(&scheduler.Job{
Name: "job-2sec",
Schedule: "*/2 * * * * *",
Handler: func(_ context.Context) error {
job2Count.Add(1)
return nil
},
})
// Job 3: Every 3 seconds
s.Register(&scheduler.Job{
Name: "job-3sec",
Schedule: "*/3 * * * * *",
Handler: func(_ context.Context) error {
job3Count.Add(1)
return nil
},
})
if err := s.Start(); err != nil {
t.Fatalf("failed to start: %v", err)
}
// Let them run for 6 seconds
time.Sleep(6 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.Stop(ctx); err != nil {
t.Fatalf("failed to stop: %v", err)
}
// Verify counts (allowing for timing variance)
c1 := job1Count.Load()
c2 := job2Count.Load()
c3 := job3Count.Load()
t.Logf("Job 1 ran %d times, Job 2 ran %d times, Job 3 ran %d times", c1, c2, c3)
if c1 < 5 {
t.Errorf("expected job1 to run at least 5 times, ran %d", c1)
}
if c2 < 2 {
t.Errorf("expected job2 to run at least 2 times, ran %d", c2)
}
if c3 < 1 {
t.Errorf("expected job3 to run at least 1 time, ran %d", c3)
}
}
// Helpers
type testLogger struct {
onInfo func(msg string, args ...interface{})
onError func(msg string, args ...interface{})
}
func (l *testLogger) Info(msg string, args ...interface{}) {
if l.onInfo != nil {
l.onInfo(msg, args...)
}
}
func (l *testLogger) Error(msg string, args ...interface{}) {
if l.onError != nil {
l.onError(msg, args...)
}
}
func contains(s, substr string) bool {
return strings.Contains(s, substr)
}

58
plugin/scheduler/job.go Normal file
View File

@@ -0,0 +1,58 @@
package scheduler
import (
"context"
"github.com/pkg/errors"
)
// JobHandler is the function signature for scheduled job handlers.
// The context passed to the handler will be canceled if the scheduler is shutting down.
type JobHandler func(ctx context.Context) error
// Job represents a scheduled task.
type Job struct {
// Name is a unique identifier for this job (required).
// Used for logging and metrics.
Name string
// Schedule is a cron expression defining when this job runs (required).
// Supports standard 5-field format: "minute hour day month weekday"
// Examples: "0 * * * *" (hourly), "0 0 * * *" (daily at midnight)
Schedule string
// Timezone for schedule evaluation (optional, defaults to UTC).
// Use IANA timezone names: "America/New_York", "Europe/London", etc.
Timezone string
// Handler is the function to execute when the job triggers (required).
Handler JobHandler
// Description provides human-readable context about what this job does (optional).
Description string
// Tags allow categorizing jobs for filtering/monitoring (optional).
Tags []string
}
// Validate checks if the job definition is valid.
func (j *Job) Validate() error {
if j.Name == "" {
return errors.New("job name is required")
}
if j.Schedule == "" {
return errors.New("job schedule is required")
}
// Validate cron expression using parser
if _, err := ParseCronExpression(j.Schedule); err != nil {
return errors.Wrap(err, "invalid cron expression")
}
if j.Handler == nil {
return errors.New("job handler is required")
}
return nil
}

View File

@@ -0,0 +1,90 @@
package scheduler
import (
"context"
"testing"
)
func TestJobDefinition(t *testing.T) {
callCount := 0
job := &Job{
Name: "test-job",
Handler: func(_ context.Context) error {
callCount++
return nil
},
}
if job.Name != "test-job" {
t.Errorf("expected name 'test-job', got %s", job.Name)
}
// Test handler execution
if err := job.Handler(context.Background()); err != nil {
t.Fatalf("handler failed: %v", err)
}
if callCount != 1 {
t.Errorf("expected handler to be called once, called %d times", callCount)
}
}
func TestJobValidation(t *testing.T) {
tests := []struct {
name string
job *Job
wantErr bool
}{
{
name: "valid job",
job: &Job{
Name: "valid",
Schedule: "0 * * * *",
Handler: func(_ context.Context) error { return nil },
},
wantErr: false,
},
{
name: "missing name",
job: &Job{
Schedule: "0 * * * *",
Handler: func(_ context.Context) error { return nil },
},
wantErr: true,
},
{
name: "missing schedule",
job: &Job{
Name: "test",
Handler: func(_ context.Context) error { return nil },
},
wantErr: true,
},
{
name: "invalid cron expression",
job: &Job{
Name: "test",
Schedule: "invalid cron",
Handler: func(_ context.Context) error { return nil },
},
wantErr: true,
},
{
name: "missing handler",
job: &Job{
Name: "test",
Schedule: "0 * * * *",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.job.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,120 @@
package scheduler
import (
"context"
"time"
"github.com/pkg/errors"
)
// Middleware wraps a JobHandler to add cross-cutting behavior.
type Middleware func(JobHandler) JobHandler
// Chain combines multiple middleware into a single middleware.
// Middleware are applied in the order they're provided (left to right).
func Chain(middlewares ...Middleware) Middleware {
return func(handler JobHandler) JobHandler {
// Apply middleware in reverse order so first middleware wraps outermost
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i](handler)
}
return handler
}
}
// Recovery recovers from panics in job handlers and converts them to errors.
func Recovery(onPanic func(jobName string, recovered interface{})) Middleware {
return func(next JobHandler) JobHandler {
return func(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
jobName := getJobName(ctx)
if onPanic != nil {
onPanic(jobName, r)
}
err = errors.Errorf("job %q panicked: %v", jobName, r)
}
}()
return next(ctx)
}
}
}
// Logger is a minimal logging interface.
type Logger interface {
Info(msg string, args ...interface{})
Error(msg string, args ...interface{})
}
// Logging adds execution logging to jobs.
func Logging(logger Logger) Middleware {
return func(next JobHandler) JobHandler {
return func(ctx context.Context) error {
jobName := getJobName(ctx)
start := time.Now()
logger.Info("Job started", "job", jobName)
err := next(ctx)
duration := time.Since(start)
if err != nil {
logger.Error("Job failed", "job", jobName, "duration", duration, "error", err)
} else {
logger.Info("Job completed", "job", jobName, "duration", duration)
}
return err
}
}
}
// Timeout wraps a job handler with a timeout.
func Timeout(duration time.Duration) Middleware {
return func(next JobHandler) JobHandler {
return func(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, duration)
defer cancel()
done := make(chan error, 1)
go func() {
done <- next(ctx)
}()
select {
case err := <-done:
return err
case <-ctx.Done():
return errors.Errorf("job %q timed out after %v", getJobName(ctx), duration)
}
}
}
}
// Context keys for job metadata.
type contextKey int
const (
jobNameKey contextKey = iota
)
// withJobName adds the job name to the context.
func withJobName(ctx context.Context, name string) context.Context {
return context.WithValue(ctx, jobNameKey, name)
}
// getJobName retrieves the job name from the context.
func getJobName(ctx context.Context) string {
if name, ok := ctx.Value(jobNameKey).(string); ok {
return name
}
return "unknown"
}
// GetJobName retrieves the job name from the context (public API).
// Returns empty string if not found.
//
//nolint:revive // GetJobName is the public API, getJobName is internal
func GetJobName(ctx context.Context) string {
return getJobName(ctx)
}

View File

@@ -0,0 +1,146 @@
package scheduler
import (
"context"
"errors"
"sync/atomic"
"testing"
)
func TestMiddlewareChaining(t *testing.T) {
var order []string
mw1 := func(next JobHandler) JobHandler {
return func(ctx context.Context) error {
order = append(order, "before-1")
err := next(ctx)
order = append(order, "after-1")
return err
}
}
mw2 := func(next JobHandler) JobHandler {
return func(ctx context.Context) error {
order = append(order, "before-2")
err := next(ctx)
order = append(order, "after-2")
return err
}
}
handler := func(_ context.Context) error {
order = append(order, "handler")
return nil
}
chain := Chain(mw1, mw2)
wrapped := chain(handler)
if err := wrapped(context.Background()); err != nil {
t.Fatalf("wrapped handler failed: %v", err)
}
expected := []string{"before-1", "before-2", "handler", "after-2", "after-1"}
if len(order) != len(expected) {
t.Fatalf("expected %d calls, got %d", len(expected), len(order))
}
for i, want := range expected {
if order[i] != want {
t.Errorf("order[%d] = %q, want %q", i, order[i], want)
}
}
}
func TestRecoveryMiddleware(t *testing.T) {
var panicRecovered atomic.Bool
onPanic := func(_ string, _ interface{}) {
panicRecovered.Store(true)
}
handler := func(_ context.Context) error {
panic("simulated panic")
}
recovery := Recovery(onPanic)
wrapped := recovery(handler)
// Should not panic, error should be returned
err := wrapped(withJobName(context.Background(), "test-job"))
if err == nil {
t.Error("expected error from recovered panic")
}
if !panicRecovered.Load() {
t.Error("panic handler was not called")
}
}
func TestLoggingMiddleware(t *testing.T) {
var loggedStart, loggedEnd atomic.Bool
var loggedError atomic.Bool
logger := &testLogger{
onInfo: func(msg string, _ ...interface{}) {
if msg == "Job started" {
loggedStart.Store(true)
} else if msg == "Job completed" {
loggedEnd.Store(true)
}
},
onError: func(msg string, _ ...interface{}) {
if msg == "Job failed" {
loggedError.Store(true)
}
},
}
// Test successful execution
handler := func(_ context.Context) error {
return nil
}
logging := Logging(logger)
wrapped := logging(handler)
if err := wrapped(withJobName(context.Background(), "test-job")); err != nil {
t.Fatalf("handler failed: %v", err)
}
if !loggedStart.Load() {
t.Error("start was not logged")
}
if !loggedEnd.Load() {
t.Error("end was not logged")
}
// Test error handling
handlerErr := func(_ context.Context) error {
return errors.New("job error")
}
wrappedErr := logging(handlerErr)
_ = wrappedErr(withJobName(context.Background(), "test-job-error"))
if !loggedError.Load() {
t.Error("error was not logged")
}
}
type testLogger struct {
onInfo func(msg string, args ...interface{})
onError func(msg string, args ...interface{})
}
func (l *testLogger) Info(msg string, args ...interface{}) {
if l.onInfo != nil {
l.onInfo(msg, args...)
}
}
func (l *testLogger) Error(msg string, args ...interface{}) {
if l.onError != nil {
l.onError(msg, args...)
}
}

229
plugin/scheduler/parser.go Normal file
View File

@@ -0,0 +1,229 @@
package scheduler
import (
"strconv"
"strings"
"time"
"github.com/pkg/errors"
)
// Schedule represents a parsed cron expression.
type Schedule struct {
seconds fieldMatcher // 0-59 (optional, for 6-field format)
minutes fieldMatcher // 0-59
hours fieldMatcher // 0-23
days fieldMatcher // 1-31
months fieldMatcher // 1-12
weekdays fieldMatcher // 0-7 (0 and 7 are Sunday)
hasSecs bool
}
// fieldMatcher determines if a field value matches.
type fieldMatcher interface {
matches(value int) bool
}
// ParseCronExpression parses a cron expression and returns a Schedule.
// Supports both 5-field (minute hour day month weekday) and 6-field (second minute hour day month weekday) formats.
func ParseCronExpression(expr string) (*Schedule, error) {
if expr == "" {
return nil, errors.New("empty cron expression")
}
fields := strings.Fields(expr)
if len(fields) != 5 && len(fields) != 6 {
return nil, errors.Errorf("invalid cron expression: expected 5 or 6 fields, got %d", len(fields))
}
s := &Schedule{
hasSecs: len(fields) == 6,
}
var err error
offset := 0
// Parse seconds (if 6-field format)
if s.hasSecs {
s.seconds, err = parseField(fields[0], 0, 59)
if err != nil {
return nil, errors.Wrap(err, "invalid seconds field")
}
offset = 1
} else {
s.seconds = &exactMatcher{value: 0} // Default to 0 seconds
}
// Parse minutes
s.minutes, err = parseField(fields[offset], 0, 59)
if err != nil {
return nil, errors.Wrap(err, "invalid minutes field")
}
// Parse hours
s.hours, err = parseField(fields[offset+1], 0, 23)
if err != nil {
return nil, errors.Wrap(err, "invalid hours field")
}
// Parse days
s.days, err = parseField(fields[offset+2], 1, 31)
if err != nil {
return nil, errors.Wrap(err, "invalid days field")
}
// Parse months
s.months, err = parseField(fields[offset+3], 1, 12)
if err != nil {
return nil, errors.Wrap(err, "invalid months field")
}
// Parse weekdays (0-7, where both 0 and 7 represent Sunday)
s.weekdays, err = parseField(fields[offset+4], 0, 7)
if err != nil {
return nil, errors.Wrap(err, "invalid weekdays field")
}
return s, nil
}
// Next returns the next time the schedule should run after the given time.
func (s *Schedule) Next(from time.Time) time.Time {
// Start from the next second/minute
if s.hasSecs {
from = from.Add(1 * time.Second).Truncate(time.Second)
} else {
from = from.Add(1 * time.Minute).Truncate(time.Minute)
}
// Cap search at 4 years to prevent infinite loops
maxTime := from.AddDate(4, 0, 0)
for from.Before(maxTime) {
if s.matches(from) {
return from
}
// Advance to next potential match
if s.hasSecs {
from = from.Add(1 * time.Second)
} else {
from = from.Add(1 * time.Minute)
}
}
// Should never reach here with valid cron expressions
return time.Time{}
}
// matches checks if the given time matches the schedule.
func (s *Schedule) matches(t time.Time) bool {
return s.seconds.matches(t.Second()) &&
s.minutes.matches(t.Minute()) &&
s.hours.matches(t.Hour()) &&
s.months.matches(int(t.Month())) &&
(s.days.matches(t.Day()) || s.weekdays.matches(int(t.Weekday())))
}
// parseField parses a single cron field (supports *, ranges, lists, steps).
func parseField(field string, min, max int) (fieldMatcher, error) {
// Wildcard
if field == "*" {
return &wildcardMatcher{}, nil
}
// Step values (*/N)
if strings.HasPrefix(field, "*/") {
step, err := strconv.Atoi(field[2:])
if err != nil || step < 1 || step > max {
return nil, errors.Errorf("invalid step value: %s", field)
}
return &stepMatcher{step: step, min: min, max: max}, nil
}
// List (1,2,3)
if strings.Contains(field, ",") {
parts := strings.Split(field, ",")
values := make([]int, 0, len(parts))
for _, p := range parts {
val, err := strconv.Atoi(strings.TrimSpace(p))
if err != nil || val < min || val > max {
return nil, errors.Errorf("invalid list value: %s", p)
}
values = append(values, val)
}
return &listMatcher{values: values}, nil
}
// Range (1-5)
if strings.Contains(field, "-") {
parts := strings.Split(field, "-")
if len(parts) != 2 {
return nil, errors.Errorf("invalid range: %s", field)
}
start, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
end, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
if err1 != nil || err2 != nil || start < min || end > max || start > end {
return nil, errors.Errorf("invalid range: %s", field)
}
return &rangeMatcher{start: start, end: end}, nil
}
// Exact value
val, err := strconv.Atoi(field)
if err != nil || val < min || val > max {
return nil, errors.Errorf("invalid value: %s (must be between %d and %d)", field, min, max)
}
return &exactMatcher{value: val}, nil
}
// wildcardMatcher matches any value.
type wildcardMatcher struct{}
func (*wildcardMatcher) matches(_ int) bool {
return true
}
// exactMatcher matches a specific value.
type exactMatcher struct {
value int
}
func (m *exactMatcher) matches(value int) bool {
return value == m.value
}
// rangeMatcher matches values in a range.
type rangeMatcher struct {
start, end int
}
func (m *rangeMatcher) matches(value int) bool {
return value >= m.start && value <= m.end
}
// listMatcher matches any value in a list.
type listMatcher struct {
values []int
}
func (m *listMatcher) matches(value int) bool {
for _, v := range m.values {
if v == value {
return true
}
}
return false
}
// stepMatcher matches values at regular intervals.
type stepMatcher struct {
step, min, max int
}
func (m *stepMatcher) matches(value int) bool {
if value < m.min || value > m.max {
return false
}
return (value-m.min)%m.step == 0
}

View File

@@ -0,0 +1,127 @@
package scheduler
import (
"testing"
"time"
)
func TestParseCronExpression(t *testing.T) {
tests := []struct {
name string
expr string
wantErr bool
}{
// Standard 5-field format
{"every minute", "* * * * *", false},
{"hourly", "0 * * * *", false},
{"daily midnight", "0 0 * * *", false},
{"weekly sunday", "0 0 * * 0", false},
{"monthly", "0 0 1 * *", false},
{"specific time", "30 14 * * *", false}, // 2:30 PM daily
{"range", "0 9-17 * * *", false}, // Every hour 9 AM - 5 PM
{"step", "*/15 * * * *", false}, // Every 15 minutes
{"list", "0 8,12,18 * * *", false}, // 8 AM, 12 PM, 6 PM
// 6-field format with seconds
{"with seconds", "0 * * * * *", false},
{"every 30 seconds", "*/30 * * * * *", false},
// Invalid expressions
{"empty", "", true},
{"too few fields", "* * *", true},
{"too many fields", "* * * * * * *", true},
{"invalid minute", "60 * * * *", true},
{"invalid hour", "0 24 * * *", true},
{"invalid day", "0 0 32 * *", true},
{"invalid month", "0 0 1 13 *", true},
{"invalid weekday", "0 0 * * 8", true},
{"garbage", "not a cron expression", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schedule, err := ParseCronExpression(tt.expr)
if (err != nil) != tt.wantErr {
t.Errorf("ParseCronExpression(%q) error = %v, wantErr %v", tt.expr, err, tt.wantErr)
return
}
if !tt.wantErr && schedule == nil {
t.Errorf("ParseCronExpression(%q) returned nil schedule without error", tt.expr)
}
})
}
}
func TestScheduleNext(t *testing.T) {
tests := []struct {
name string
expr string
from time.Time
expected time.Time
}{
{
name: "every minute from start of hour",
expr: "* * * * *",
from: time.Date(2025, 1, 1, 10, 0, 0, 0, time.UTC),
expected: time.Date(2025, 1, 1, 10, 1, 0, 0, time.UTC),
},
{
name: "hourly at minute 30",
expr: "30 * * * *",
from: time.Date(2025, 1, 1, 10, 0, 0, 0, time.UTC),
expected: time.Date(2025, 1, 1, 10, 30, 0, 0, time.UTC),
},
{
name: "hourly at minute 30 (already past)",
expr: "30 * * * *",
from: time.Date(2025, 1, 1, 10, 45, 0, 0, time.UTC),
expected: time.Date(2025, 1, 1, 11, 30, 0, 0, time.UTC),
},
{
name: "daily at 2 AM",
expr: "0 2 * * *",
from: time.Date(2025, 1, 1, 10, 0, 0, 0, time.UTC),
expected: time.Date(2025, 1, 2, 2, 0, 0, 0, time.UTC),
},
{
name: "every 15 minutes",
expr: "*/15 * * * *",
from: time.Date(2025, 1, 1, 10, 7, 0, 0, time.UTC),
expected: time.Date(2025, 1, 1, 10, 15, 0, 0, time.UTC),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schedule, err := ParseCronExpression(tt.expr)
if err != nil {
t.Fatalf("failed to parse expression: %v", err)
}
next := schedule.Next(tt.from)
if !next.Equal(tt.expected) {
t.Errorf("Next(%v) = %v, expected %v", tt.from, next, tt.expected)
}
})
}
}
func TestScheduleNextWithTimezone(t *testing.T) {
nyc, _ := time.LoadLocation("America/New_York")
// Schedule for 9 AM in New York
schedule, err := ParseCronExpression("0 9 * * *")
if err != nil {
t.Fatalf("failed to parse expression: %v", err)
}
// Current time: 8 AM in New York
from := time.Date(2025, 1, 1, 8, 0, 0, 0, nyc)
next := schedule.Next(from)
// Should be 9 AM same day in New York
expected := time.Date(2025, 1, 1, 9, 0, 0, 0, nyc)
if !next.Equal(expected) {
t.Errorf("Next(%v) = %v, expected %v", from, next, expected)
}
}

View File

@@ -0,0 +1,202 @@
package scheduler
import (
"context"
"sync"
"time"
"github.com/pkg/errors"
)
// Scheduler manages scheduled jobs.
type Scheduler struct {
jobs map[string]*registeredJob
jobsMu sync.RWMutex
timezone *time.Location
middleware Middleware
running bool
runningMu sync.RWMutex
stopCh chan struct{}
wg sync.WaitGroup
}
// registeredJob wraps a Job with runtime state.
type registeredJob struct {
job *Job
cancelFn context.CancelFunc
}
// Option configures a Scheduler.
type Option func(*Scheduler)
// WithTimezone sets the default timezone for all jobs.
func WithTimezone(tz string) Option {
return func(s *Scheduler) {
loc, err := time.LoadLocation(tz)
if err != nil {
// Default to UTC on invalid timezone
loc = time.UTC
}
s.timezone = loc
}
}
// WithMiddleware sets middleware to wrap all job handlers.
func WithMiddleware(mw ...Middleware) Option {
return func(s *Scheduler) {
if len(mw) > 0 {
s.middleware = Chain(mw...)
}
}
}
// New creates a new Scheduler with optional configuration.
func New(opts ...Option) *Scheduler {
s := &Scheduler{
jobs: make(map[string]*registeredJob),
timezone: time.UTC,
stopCh: make(chan struct{}),
}
for _, opt := range opts {
opt(s)
}
return s
}
// Register adds a job to the scheduler.
// Jobs must be registered before calling Start().
func (s *Scheduler) Register(job *Job) error {
if job == nil {
return errors.New("job cannot be nil")
}
if err := job.Validate(); err != nil {
return errors.Wrap(err, "invalid job")
}
s.jobsMu.Lock()
defer s.jobsMu.Unlock()
if _, exists := s.jobs[job.Name]; exists {
return errors.Errorf("job with name %q already registered", job.Name)
}
s.jobs[job.Name] = &registeredJob{job: job}
return nil
}
// Start begins executing scheduled jobs.
func (s *Scheduler) Start() error {
s.runningMu.Lock()
defer s.runningMu.Unlock()
if s.running {
return errors.New("scheduler already running")
}
s.jobsMu.RLock()
defer s.jobsMu.RUnlock()
// Parse and schedule all jobs
for _, rj := range s.jobs {
schedule, err := ParseCronExpression(rj.job.Schedule)
if err != nil {
return errors.Wrapf(err, "failed to parse schedule for job %q", rj.job.Name)
}
ctx, cancel := context.WithCancel(context.Background())
rj.cancelFn = cancel
s.wg.Add(1)
go s.runJobWithSchedule(ctx, rj, schedule)
}
s.running = true
return nil
}
// runJobWithSchedule executes a job according to its cron schedule.
func (s *Scheduler) runJobWithSchedule(ctx context.Context, rj *registeredJob, schedule *Schedule) {
defer s.wg.Done()
// Apply middleware to handler
handler := rj.job.Handler
if s.middleware != nil {
handler = s.middleware(handler)
}
for {
// Calculate next run time
now := time.Now()
if rj.job.Timezone != "" {
loc, err := time.LoadLocation(rj.job.Timezone)
if err == nil {
now = now.In(loc)
}
} else if s.timezone != nil {
now = now.In(s.timezone)
}
next := schedule.Next(now)
duration := time.Until(next)
timer := time.NewTimer(duration)
select {
case <-timer.C:
// Add job name to context and execute
jobCtx := withJobName(ctx, rj.job.Name)
if err := handler(jobCtx); err != nil {
// Error already handled by middleware (if any)
_ = err
}
case <-ctx.Done():
// Stop the timer to prevent it from firing. The timer will be garbage collected.
timer.Stop()
return
case <-s.stopCh:
// Stop the timer to prevent it from firing. The timer will be garbage collected.
timer.Stop()
return
}
}
}
// Stop gracefully shuts down the scheduler.
// It waits for all running jobs to complete or until the context is canceled.
func (s *Scheduler) Stop(ctx context.Context) error {
s.runningMu.Lock()
if !s.running {
s.runningMu.Unlock()
return errors.New("scheduler not running")
}
s.running = false
s.runningMu.Unlock()
// Cancel all job contexts
s.jobsMu.RLock()
for _, rj := range s.jobs {
if rj.cancelFn != nil {
rj.cancelFn()
}
}
s.jobsMu.RUnlock()
// Signal stop and wait for jobs to finish
close(s.stopCh)
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

View File

@@ -0,0 +1,165 @@
package scheduler
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestSchedulerCreation(t *testing.T) {
s := New()
if s == nil {
t.Fatal("New() returned nil")
}
}
func TestSchedulerWithTimezone(t *testing.T) {
s := New(WithTimezone("America/New_York"))
if s == nil {
t.Fatal("New() with timezone returned nil")
}
}
func TestJobRegistration(t *testing.T) {
s := New()
job := &Job{
Name: "test-registration",
Schedule: "0 * * * *",
Handler: func(_ context.Context) error { return nil },
}
if err := s.Register(job); err != nil {
t.Fatalf("failed to register valid job: %v", err)
}
// Registering duplicate name should fail
if err := s.Register(job); err == nil {
t.Error("expected error when registering duplicate job name")
}
}
func TestSchedulerStartStop(t *testing.T) {
s := New()
var runCount atomic.Int32
job := &Job{
Name: "test-start-stop",
Schedule: "* * * * * *", // Every second (6-field format)
Handler: func(_ context.Context) error {
runCount.Add(1)
return nil
},
}
if err := s.Register(job); err != nil {
t.Fatalf("failed to register job: %v", err)
}
// Start scheduler
if err := s.Start(); err != nil {
t.Fatalf("failed to start scheduler: %v", err)
}
// Let it run for 2.5 seconds
time.Sleep(2500 * time.Millisecond)
// Stop scheduler
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.Stop(ctx); err != nil {
t.Fatalf("failed to stop scheduler: %v", err)
}
count := runCount.Load()
// Should have run at least twice (at 0s and 1s, maybe 2s)
if count < 2 {
t.Errorf("expected job to run at least 2 times, ran %d times", count)
}
// Verify it stopped (count shouldn't increase)
finalCount := count
time.Sleep(1500 * time.Millisecond)
if runCount.Load() != finalCount {
t.Error("scheduler did not stop - job continued running")
}
}
func TestSchedulerWithMiddleware(t *testing.T) {
var executionLog []string
var logMu sync.Mutex
logger := &testLogger{
onInfo: func(msg string, _ ...interface{}) {
logMu.Lock()
executionLog = append(executionLog, fmt.Sprintf("INFO: %s", msg))
logMu.Unlock()
},
onError: func(msg string, _ ...interface{}) {
logMu.Lock()
executionLog = append(executionLog, fmt.Sprintf("ERROR: %s", msg))
logMu.Unlock()
},
}
s := New(WithMiddleware(
Recovery(func(jobName string, r interface{}) {
logMu.Lock()
executionLog = append(executionLog, fmt.Sprintf("PANIC: %s - %v", jobName, r))
logMu.Unlock()
}),
Logging(logger),
))
job := &Job{
Name: "test-middleware",
Schedule: "* * * * * *", // Every second
Handler: func(_ context.Context) error {
return nil
},
}
if err := s.Register(job); err != nil {
t.Fatalf("failed to register job: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("failed to start: %v", err)
}
time.Sleep(1500 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.Stop(ctx); err != nil {
t.Fatalf("failed to stop: %v", err)
}
logMu.Lock()
defer logMu.Unlock()
// Should have at least one start and one completion log
hasStart := false
hasCompletion := false
for _, log := range executionLog {
if strings.Contains(log, "Job started") {
hasStart = true
}
if strings.Contains(log, "Job completed") {
hasCompletion = true
}
}
if !hasStart {
t.Error("expected job start log")
}
if !hasCompletion {
t.Error("expected job completion log")
}
}

118
plugin/storage/s3/s3.go Normal file
View File

@@ -0,0 +1,118 @@
package s3
import (
"context"
"io"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
)
type Client struct {
Client *s3.Client
Bucket *string
}
func NewClient(ctx context.Context, s3Config *storepb.StorageS3Config) (*Client, error) {
cfg, err := config.LoadDefaultConfig(ctx,
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(s3Config.AccessKeyId, s3Config.AccessKeySecret, "")),
config.WithRegion(s3Config.Region),
)
if err != nil {
return nil, errors.Wrap(err, "failed to load s3 config")
}
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.BaseEndpoint = aws.String(s3Config.Endpoint)
o.UsePathStyle = s3Config.UsePathStyle
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
})
return &Client{
Client: client,
Bucket: aws.String(s3Config.Bucket),
}, nil
}
// UploadObject uploads an object to S3.
func (c *Client) UploadObject(ctx context.Context, key string, fileType string, content io.Reader) (string, error) {
uploader := manager.NewUploader(c.Client)
putInput := s3.PutObjectInput{
Bucket: c.Bucket,
Key: aws.String(key),
ContentType: aws.String(fileType),
Body: content,
}
result, err := uploader.Upload(ctx, &putInput)
if err != nil {
return "", err
}
resultKey := result.Key
if resultKey == nil || *resultKey == "" {
return "", errors.New("failed to get file key")
}
return *resultKey, nil
}
// PresignGetObject presigns an object in S3.
func (c *Client) PresignGetObject(ctx context.Context, key string) (string, error) {
presignClient := s3.NewPresignClient(c.Client)
presignResult, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(*c.Bucket),
Key: aws.String(key),
}, func(opts *s3.PresignOptions) {
// Set the expiration time of the presigned URL to 5 days.
// Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html
opts.Expires = time.Duration(5 * 24 * time.Hour)
})
if err != nil {
return "", errors.Wrap(err, "failed to presign put object")
}
return presignResult.URL, nil
}
// GetObject retrieves an object from S3.
func (c *Client) GetObject(ctx context.Context, key string) ([]byte, error) {
downloader := manager.NewDownloader(c.Client)
buffer := manager.NewWriteAtBuffer([]byte{})
_, err := downloader.Download(ctx, buffer, &s3.GetObjectInput{
Bucket: c.Bucket,
Key: aws.String(key),
})
if err != nil {
return nil, errors.Wrap(err, "failed to download object")
}
return buffer.Bytes(), nil
}
// GetObjectStream retrieves an object from S3 as a stream.
func (c *Client) GetObjectStream(ctx context.Context, key string) (io.ReadCloser, error) {
output, err := c.Client.GetObject(ctx, &s3.GetObjectInput{
Bucket: c.Bucket,
Key: aws.String(key),
})
if err != nil {
return nil, errors.Wrap(err, "failed to get object")
}
return output.Body, nil
}
// DeleteObject deletes an object in S3.
func (c *Client) DeleteObject(ctx context.Context, key string) error {
_, err := c.Client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: c.Bucket,
Key: aws.String(key),
})
if err != nil {
return errors.Wrap(err, "failed to delete object")
}
return nil
}

View File

@@ -0,0 +1,75 @@
package webhook
import (
"net"
"net/url"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// reservedCIDRs lists IP ranges that must never be targeted by outbound webhook requests.
// Covers loopback, RFC-1918 private, link-local (including cloud IMDS at 169.254.169.254),
// and their IPv6 equivalents.
var reservedCIDRs = []string{
"127.0.0.0/8", // IPv4 loopback
"10.0.0.0/8", // RFC-1918 class A
"172.16.0.0/12", // RFC-1918 class B
"192.168.0.0/16", // RFC-1918 class C
"169.254.0.0/16", // Link-local / cloud IMDS
"::1/128", // IPv6 loopback
"fc00::/7", // IPv6 unique local
"fe80::/10", // IPv6 link-local
}
// reservedNetworks is the parsed form of reservedCIDRs, built once at startup.
var reservedNetworks []*net.IPNet
func init() {
for _, cidr := range reservedCIDRs {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
panic("webhook: invalid reserved CIDR " + cidr + ": " + err.Error())
}
reservedNetworks = append(reservedNetworks, network)
}
}
// isReservedIP reports whether ip falls within any reserved/private range.
func isReservedIP(ip net.IP) bool {
for _, network := range reservedNetworks {
if network.Contains(ip) {
return true
}
}
return false
}
// ValidateURL checks that rawURL:
// 1. Parses as a valid absolute URL.
// 2. Uses the http or https scheme.
// 3. Does not resolve to a reserved/private IP address.
//
// It returns a gRPC InvalidArgument status error so callers can return it directly.
func ValidateURL(rawURL string) error {
u, err := url.ParseRequestURI(rawURL)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid webhook URL: %v", err)
}
if u.Scheme != "http" && u.Scheme != "https" {
return status.Errorf(codes.InvalidArgument, "webhook URL must use http or https scheme, got %q", u.Scheme)
}
ips, err := net.LookupHost(u.Hostname())
if err != nil {
return status.Errorf(codes.InvalidArgument, "webhook URL hostname could not be resolved: %v", err)
}
for _, ipStr := range ips {
ip := net.ParseIP(ipStr)
if ip != nil && isReservedIP(ip) {
return status.Errorf(codes.InvalidArgument, "webhook URL must not resolve to a reserved or private IP address")
}
}
return nil
}

120
plugin/webhook/webhook.go Normal file
View File

@@ -0,0 +1,120 @@
package webhook
import (
"bytes"
"context"
"encoding/json"
"io"
"log/slog"
"net"
"net/http"
"time"
"github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
var (
// timeout is the timeout for webhook request. Default to 30 seconds.
timeout = 30 * time.Second
// safeClient is the shared HTTP client used for all webhook dispatches.
// Its Transport guards against SSRF by blocking connections to reserved/private
// IP addresses at dial time, which also defeats DNS rebinding attacks.
safeClient = &http.Client{
Timeout: timeout,
Transport: &http.Transport{
DialContext: safeDialContext,
},
}
)
// safeDialContext is a net.Dialer.DialContext replacement that resolves the target
// hostname and rejects any address that falls within a reserved/private IP range.
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, errors.Errorf("webhook: invalid address %q", addr)
}
ips, err := net.DefaultResolver.LookupHost(ctx, host)
if err != nil {
return nil, errors.Wrapf(err, "webhook: failed to resolve host %q", host)
}
for _, ipStr := range ips {
if ip := net.ParseIP(ipStr); ip != nil && isReservedIP(ip) {
return nil, errors.Errorf("webhook: connection to reserved/private IP address is not allowed")
}
}
return (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(host, port))
}
type WebhookRequestPayload struct {
// The target URL for the webhook request.
URL string `json:"url"`
// The type of activity that triggered this webhook.
ActivityType string `json:"activityType"`
// The resource name of the creator. Format: users/{user}
Creator string `json:"creator"`
// The memo that triggered this webhook (if applicable).
Memo *v1pb.Memo `json:"memo"`
}
// Post posts the message to webhook endpoint.
func Post(requestPayload *WebhookRequestPayload) error {
body, err := json.Marshal(requestPayload)
if err != nil {
return errors.Wrapf(err, "failed to marshal webhook request to %s", requestPayload.URL)
}
req, err := http.NewRequest("POST", requestPayload.URL, bytes.NewBuffer(body))
if err != nil {
return errors.Wrapf(err, "failed to construct webhook request to %s", requestPayload.URL)
}
req.Header.Set("Content-Type", "application/json")
resp, err := safeClient.Do(req)
if err != nil {
return errors.Wrapf(err, "failed to post webhook to %s", requestPayload.URL)
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "failed to read webhook response from %s", requestPayload.URL)
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return errors.Errorf("failed to post webhook %s, status code: %d", requestPayload.URL, resp.StatusCode)
}
response := &struct {
Code int `json:"code"`
Message string `json:"message"`
}{}
if err := json.Unmarshal(b, response); err != nil {
return errors.Wrapf(err, "failed to unmarshal webhook response from %s", requestPayload.URL)
}
if response.Code != 0 {
return errors.Errorf("receive error code sent by webhook server, code %d, msg: %s", response.Code, response.Message)
}
return nil
}
// PostAsync posts the message to webhook endpoint asynchronously.
// It spawns a new goroutine to handle the request and does not wait for the response.
func PostAsync(requestPayload *WebhookRequestPayload) {
go func() {
if err := Post(requestPayload); err != nil {
slog.Warn("Failed to dispatch webhook asynchronously",
slog.String("url", requestPayload.URL),
slog.String("activityType", requestPayload.ActivityType),
slog.Any("err", err))
}
}()
}

View File

@@ -0,0 +1 @@
package webhook