first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled

This commit is contained in:
2026-03-04 06:30:47 +00:00
commit bb402d4ccc
777 changed files with 135661 additions and 0 deletions

View File

@@ -0,0 +1,50 @@
# Maintaining the Memo Filter Engine
The engine is memo-specific; any future field or behavior changes must stay
consistent with the memo schema and store implementations. Use this guide when
extending or debugging the package.
## Adding a New Memo Field
1. **Update the schema**
- Add the field entry in `schema.go`.
- Define the backing column (`Column`), JSON path (if applicable), type, and
allowed operators.
- Include the CEL variable in `EnvOptions`.
2. **Adjust parser or renderer (if needed)**
- For non-scalar fields (JSON booleans, lists), add handling in
`parser.go` or extend the renderer helpers.
- Keep validation in the parser (e.g., reject unsupported operators).
3. **Write a golden test**
- Extend the dialect-specific memo filter tests under
`store/db/{sqlite,mysql,postgres}/memo_filter_test.go` with a case that
exercises the new field.
4. **Run `go test ./...`** to ensure the SQL output matches expectations across
all dialects.
## Supporting Dialect Nuances
- Centralize differences inside `render.go`. If a new dialect-specific behavior
emerges (e.g., JSON operators), add the logic there rather than leaking it
into store code.
- Use the renderer helpers (`jsonExtractExpr`, `jsonArrayExpr`, etc.) rather than
sprinkling ad-hoc SQL strings.
- When placeholders change, adjust `addArg` so that argument numbering stays in
sync with store queries.
## Debugging Tips
- **Parser errors** Most originate in `buildCondition` or schema validation.
Enable logging around `parser.go` when diagnosing unknown identifier/operator
messages.
- **Renderer output** Temporary printf/log statements in `renderCondition` help
identify which IR node produced unexpected SQL.
- **Store integration** Ensure drivers call `filter.DefaultEngine()` exactly once
per process; the singleton caches the parsed CEL environment.
## Testing Checklist
- `go test ./store/...` ensures all dialect tests consume the engine correctly.
- Add targeted unit tests whenever new IR nodes or renderer paths are introduced.
- When changing boolean or JSON handling, verify all three dialect test suites
(SQLite, MySQL, Postgres) to avoid regression.

63
plugin/filter/README.md Normal file
View File

@@ -0,0 +1,63 @@
# Memo Filter Engine
This package houses the memo-only filter engine that turns CEL expressions into
SQL fragments. The engine follows a three phase pipeline inspired by systems
such as Calcite or Prisma:
1. **Parsing** CEL expressions are parsed with `cel-go` and validated against
the memo-specific environment declared in `schema.go`. Only fields that
exist in the schema can surface in the filter.
2. **Normalization** the raw CEL AST is converted into an intermediate
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
conditions (logical operators, comparisons, list membership, etc.). This
step enforces schema rules (e.g. operator compatibility, type checks).
3. **Rendering** the renderer in `render.go` walks the IR and produces a SQL
fragment plus placeholder arguments tailored to a target dialect
(`sqlite`, `mysql`, or `postgres`). Dialect differences such as JSON access,
boolean semantics, placeholders, and `LIKE` vs `ILIKE` are encapsulated in
renderer helpers.
The entry point is `filter.DefaultEngine()` from `engine.go`. It lazily constructs
an `Engine` configured with the memo schema and exposes:
```go
engine, _ := filter.DefaultEngine()
stmt, _ := engine.CompileToStatement(ctx, `has_task_list && visibility == "PUBLIC"`, filter.RenderOptions{
Dialect: filter.DialectPostgres,
})
// stmt.SQL -> "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.visibility = $1)"
// stmt.Args -> ["PUBLIC"]
```
## Core Files
| File | Responsibility |
| ------------- | ------------------------------------------------------------------------------- |
| `schema.go` | Declares memo fields, their types, backing columns, CEL environment options |
| `ir.go` | IR node definitions used across the pipeline |
| `parser.go` | Converts CEL `Expr` into IR while applying schema validation |
| `render.go` | Translates IR into SQL, handling dialect-specific behavior |
| `engine.go` | Glue between the phases; exposes `Compile`, `CompileToStatement`, and `DefaultEngine` |
| `helpers.go` | Convenience helpers for store integration (appending conditions) |
## SQL Generation Notes
- **Placeholders** — `?` is used for SQLite/MySQL, `$n` for Postgres. The renderer
tracks offsets to compose queries with pre-existing arguments.
- **JSON Fields** — Memo metadata lives in `memo.payload`. The renderer handles
`JSON_EXTRACT`/`json_extract`/`->`/`->>` variations and boolean coercion.
- **Tag Operations** — `tag in [...]` and `"tag" in tags` become JSON array
predicates. SQLite uses `LIKE` patterns, MySQL uses `JSON_CONTAINS`, and
Postgres uses `@>`.
- **Boolean Flags** — Fields such as `has_task_list` render as `IS TRUE` equality
checks, or comparisons against `CAST('true' AS JSON)` depending on the dialect.
## Typical Integration
1. Fetch the engine with `filter.DefaultEngine()`.
2. Call `CompileToStatement` using the appropriate dialect enum.
3. Append the emitted SQL fragment/args to the existing `WHERE` clause.
4. Execute the resulting query through the store driver.
The `helpers.AppendConditions` helper encapsulates steps 23 when a driver needs
to process an array of filters.

191
plugin/filter/engine.go Normal file
View File

@@ -0,0 +1,191 @@
package filter
import (
"context"
"fmt"
"strings"
"sync"
"github.com/google/cel-go/cel"
"github.com/pkg/errors"
)
// Engine parses CEL filters into a dialect-agnostic condition tree.
type Engine struct {
schema Schema
env *cel.Env
}
// NewEngine builds a new Engine for the provided schema.
func NewEngine(schema Schema) (*Engine, error) {
env, err := cel.NewEnv(schema.EnvOptions...)
if err != nil {
return nil, errors.Wrap(err, "failed to create CEL environment")
}
return &Engine{
schema: schema,
env: env,
}, nil
}
// Program stores a compiled filter condition.
type Program struct {
schema Schema
condition Condition
}
// ConditionTree exposes the underlying condition tree.
func (p *Program) ConditionTree() Condition {
return p.condition
}
// Compile parses the filter string into an executable program.
func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
if strings.TrimSpace(filter) == "" {
return nil, errors.New("filter expression is empty")
}
filter = normalizeLegacyFilter(filter)
ast, issues := e.env.Compile(filter)
if issues != nil && issues.Err() != nil {
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
}
parsed, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, errors.Wrap(err, "failed to convert AST")
}
cond, err := buildCondition(parsed.GetExpr(), e.schema)
if err != nil {
return nil, err
}
return &Program{
schema: e.schema,
condition: cond,
}, nil
}
// CompileToStatement compiles and renders the filter in a single step.
func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) {
program, err := e.Compile(ctx, filter)
if err != nil {
return Statement{}, err
}
return program.Render(opts)
}
// RenderOptions configure SQL rendering.
type RenderOptions struct {
Dialect DialectName
PlaceholderOffset int
DisableNullChecks bool
}
// Statement contains the rendered SQL fragment and its args.
type Statement struct {
SQL string
Args []any
}
// Render converts the program into a dialect-specific SQL fragment.
func (p *Program) Render(opts RenderOptions) (Statement, error) {
renderer := newRenderer(p.schema, opts)
return renderer.Render(p.condition)
}
var (
defaultOnce sync.Once
defaultInst *Engine
defaultErr error
defaultAttachmentOnce sync.Once
defaultAttachmentInst *Engine
defaultAttachmentErr error
)
// DefaultEngine returns the process-wide memo filter engine.
func DefaultEngine() (*Engine, error) {
defaultOnce.Do(func() {
defaultInst, defaultErr = NewEngine(NewSchema())
})
return defaultInst, defaultErr
}
// DefaultAttachmentEngine returns the process-wide attachment filter engine.
func DefaultAttachmentEngine() (*Engine, error) {
defaultAttachmentOnce.Do(func() {
defaultAttachmentInst, defaultAttachmentErr = NewEngine(NewAttachmentSchema())
})
return defaultAttachmentInst, defaultAttachmentErr
}
func normalizeLegacyFilter(expr string) string {
expr = rewriteNumericLogicalOperand(expr, "&&")
expr = rewriteNumericLogicalOperand(expr, "||")
return expr
}
func rewriteNumericLogicalOperand(expr, op string) string {
var builder strings.Builder
n := len(expr)
i := 0
var inQuote rune
for i < n {
ch := expr[i]
if inQuote != 0 {
builder.WriteByte(ch)
if ch == '\\' && i+1 < n {
builder.WriteByte(expr[i+1])
i += 2
continue
}
if ch == byte(inQuote) {
inQuote = 0
}
i++
continue
}
if ch == '\'' || ch == '"' {
inQuote = rune(ch)
builder.WriteByte(ch)
i++
continue
}
if strings.HasPrefix(expr[i:], op) {
builder.WriteString(op)
i += len(op)
// Preserve whitespace following the operator.
wsStart := i
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
i++
}
builder.WriteString(expr[wsStart:i])
signStart := i
if i < n && (expr[i] == '+' || expr[i] == '-') {
i++
}
for i < n && expr[i] >= '0' && expr[i] <= '9' {
i++
}
if i > signStart {
numLiteral := expr[signStart:i]
builder.WriteString(fmt.Sprintf("(%s != 0)", numLiteral))
} else {
builder.WriteString(expr[signStart:i])
}
continue
}
builder.WriteByte(ch)
i++
}
return builder.String()
}

25
plugin/filter/helpers.go Normal file
View File

@@ -0,0 +1,25 @@
package filter
import (
"context"
"fmt"
)
// AppendConditions compiles the provided filters and appends the resulting SQL fragments and args.
func AppendConditions(ctx context.Context, engine *Engine, filters []string, dialect DialectName, where *[]string, args *[]any) error {
for _, filterStr := range filters {
stmt, err := engine.CompileToStatement(ctx, filterStr, RenderOptions{
Dialect: dialect,
PlaceholderOffset: len(*args),
})
if err != nil {
return err
}
if stmt.SQL == "" {
continue
}
*where = append(*where, fmt.Sprintf("(%s)", stmt.SQL))
*args = append(*args, stmt.Args...)
}
return nil
}

159
plugin/filter/ir.go Normal file
View File

@@ -0,0 +1,159 @@
package filter
// Condition represents a boolean expression derived from the CEL filter.
type Condition interface {
isCondition()
}
// LogicalOperator enumerates the supported logical operators.
type LogicalOperator string
const (
LogicalAnd LogicalOperator = "AND"
LogicalOr LogicalOperator = "OR"
)
// LogicalCondition composes two conditions with a logical operator.
type LogicalCondition struct {
Operator LogicalOperator
Left Condition
Right Condition
}
func (*LogicalCondition) isCondition() {}
// NotCondition negates a child condition.
type NotCondition struct {
Expr Condition
}
func (*NotCondition) isCondition() {}
// FieldPredicateCondition asserts that a field evaluates to true.
type FieldPredicateCondition struct {
Field string
}
func (*FieldPredicateCondition) isCondition() {}
// ComparisonOperator lists supported comparison operators.
type ComparisonOperator string
const (
CompareEq ComparisonOperator = "="
CompareNeq ComparisonOperator = "!="
CompareLt ComparisonOperator = "<"
CompareLte ComparisonOperator = "<="
CompareGt ComparisonOperator = ">"
CompareGte ComparisonOperator = ">="
)
// ComparisonCondition represents a binary comparison.
type ComparisonCondition struct {
Left ValueExpr
Operator ComparisonOperator
Right ValueExpr
}
func (*ComparisonCondition) isCondition() {}
// InCondition represents an IN predicate with literal list values.
type InCondition struct {
Left ValueExpr
Values []ValueExpr
}
func (*InCondition) isCondition() {}
// ElementInCondition represents the CEL syntax `"value" in field`.
type ElementInCondition struct {
Element ValueExpr
Field string
}
func (*ElementInCondition) isCondition() {}
// ContainsCondition models the <field>.contains(<value>) call.
type ContainsCondition struct {
Field string
Value string
}
func (*ContainsCondition) isCondition() {}
// ConstantCondition captures a literal boolean outcome.
type ConstantCondition struct {
Value bool
}
func (*ConstantCondition) isCondition() {}
// ValueExpr models arithmetic or scalar expressions whose result feeds a comparison.
type ValueExpr interface {
isValueExpr()
}
// FieldRef references a named schema field.
type FieldRef struct {
Name string
}
func (*FieldRef) isValueExpr() {}
// LiteralValue holds a literal scalar.
type LiteralValue struct {
Value interface{}
}
func (*LiteralValue) isValueExpr() {}
// FunctionValue captures simple function calls like size(tags).
type FunctionValue struct {
Name string
Args []ValueExpr
}
func (*FunctionValue) isValueExpr() {}
// ListComprehensionCondition represents CEL macros like exists(), all(), filter().
type ListComprehensionCondition struct {
Kind ComprehensionKind
Field string // The list field to iterate over (e.g., "tags")
IterVar string // The iteration variable name (e.g., "t")
Predicate PredicateExpr // The predicate to evaluate for each element
}
func (*ListComprehensionCondition) isCondition() {}
// ComprehensionKind enumerates the types of list comprehensions.
type ComprehensionKind string
const (
ComprehensionExists ComprehensionKind = "exists"
)
// PredicateExpr represents predicates used in comprehensions.
type PredicateExpr interface {
isPredicateExpr()
}
// StartsWithPredicate represents t.startsWith("prefix").
type StartsWithPredicate struct {
Prefix string
}
func (*StartsWithPredicate) isPredicateExpr() {}
// EndsWithPredicate represents t.endsWith("suffix").
type EndsWithPredicate struct {
Suffix string
}
func (*EndsWithPredicate) isPredicateExpr() {}
// ContainsPredicate represents t.contains("substring").
type ContainsPredicate struct {
Substring string
}
func (*ContainsPredicate) isPredicateExpr() {}

586
plugin/filter/parser.go Normal file
View File

@@ -0,0 +1,586 @@
package filter
import (
"time"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
switch v := expr.ExprKind.(type) {
case *exprv1.Expr_CallExpr:
return buildCallCondition(v.CallExpr, schema)
case *exprv1.Expr_ConstExpr:
val, err := getConstValue(expr)
if err != nil {
return nil, err
}
switch v := val.(type) {
case bool:
return &ConstantCondition{Value: v}, nil
case int64:
return &ConstantCondition{Value: v != 0}, nil
case float64:
return &ConstantCondition{Value: v != 0}, nil
default:
return nil, errors.New("filter must evaluate to a boolean value")
}
case *exprv1.Expr_IdentExpr:
name := v.IdentExpr.GetName()
field, ok := schema.Field(name)
if !ok {
return nil, errors.Errorf("unknown identifier %q", name)
}
if field.Type != FieldTypeBool {
return nil, errors.Errorf("identifier %q is not boolean", name)
}
return &FieldPredicateCondition{Field: name}, nil
case *exprv1.Expr_ComprehensionExpr:
return buildComprehensionCondition(v.ComprehensionExpr, schema)
default:
return nil, errors.New("unsupported top-level expression")
}
}
func buildCallCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
switch call.Function {
case "_&&_":
if len(call.Args) != 2 {
return nil, errors.New("logical AND expects two arguments")
}
left, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildCondition(call.Args[1], schema)
if err != nil {
return nil, err
}
return &LogicalCondition{
Operator: LogicalAnd,
Left: left,
Right: right,
}, nil
case "_||_":
if len(call.Args) != 2 {
return nil, errors.New("logical OR expects two arguments")
}
left, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildCondition(call.Args[1], schema)
if err != nil {
return nil, err
}
return &LogicalCondition{
Operator: LogicalOr,
Left: left,
Right: right,
}, nil
case "!_":
if len(call.Args) != 1 {
return nil, errors.New("logical NOT expects one argument")
}
child, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
return &NotCondition{Expr: child}, nil
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
return buildComparisonCondition(call, schema)
case "@in":
return buildInCondition(call, schema)
case "contains":
return buildContainsCondition(call, schema)
default:
val, ok, err := evaluateBool(call)
if err != nil {
return nil, err
}
if ok {
return &ConstantCondition{Value: val}, nil
}
return nil, errors.Errorf("unsupported call expression %q", call.Function)
}
}
func buildComparisonCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if len(call.Args) != 2 {
return nil, errors.New("comparison expects two arguments")
}
op, err := toComparisonOperator(call.Function)
if err != nil {
return nil, err
}
left, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildValueExpr(call.Args[1], schema)
if err != nil {
return nil, err
}
// If the left side is a field, validate allowed operators.
if field, ok := left.(*FieldRef); ok {
def, exists := schema.Field(field.Name)
if !exists {
return nil, errors.Errorf("unknown identifier %q", field.Name)
}
if def.Kind == FieldKindVirtualAlias {
def, exists = schema.ResolveAlias(field.Name)
if !exists {
return nil, errors.Errorf("invalid alias %q", field.Name)
}
}
if def.AllowedComparisonOps != nil {
if _, allowed := def.AllowedComparisonOps[op]; !allowed {
return nil, errors.Errorf("operator %s not allowed for field %q", op, field.Name)
}
}
}
return &ComparisonCondition{
Left: left,
Operator: op,
Right: right,
}, nil
}
func buildInCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if len(call.Args) != 2 {
return nil, errors.New("in operator expects two arguments")
}
// Handle identifier in list syntax.
if identName, err := getIdentName(call.Args[0]); err == nil {
if field, ok := schema.Field(identName); ok && field.Kind == FieldKindVirtualAlias {
if _, aliasOk := schema.ResolveAlias(identName); !aliasOk {
return nil, errors.Errorf("invalid alias %q", identName)
}
} else if !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
if listExpr := call.Args[1].GetListExpr(); listExpr != nil {
values := make([]ValueExpr, 0, len(listExpr.Elements))
for _, element := range listExpr.Elements {
value, err := buildValueExpr(element, schema)
if err != nil {
return nil, err
}
values = append(values, value)
}
return &InCondition{
Left: &FieldRef{Name: identName},
Values: values,
}, nil
}
}
// Handle "value in identifier" syntax.
if identName, err := getIdentName(call.Args[1]); err == nil {
if _, ok := schema.Field(identName); !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
element, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
return &ElementInCondition{
Element: element,
Field: identName,
}, nil
}
return nil, errors.New("invalid use of in operator")
}
func buildContainsCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if call.Target == nil {
return nil, errors.New("contains requires a target")
}
targetName, err := getIdentName(call.Target)
if err != nil {
return nil, err
}
field, ok := schema.Field(targetName)
if !ok {
return nil, errors.Errorf("unknown identifier %q", targetName)
}
if !field.SupportsContains {
return nil, errors.Errorf("identifier %q does not support contains()", targetName)
}
if len(call.Args) != 1 {
return nil, errors.New("contains expects exactly one argument")
}
value, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "contains only supports literal arguments")
}
str, ok := value.(string)
if !ok {
return nil, errors.New("contains argument must be a string")
}
return &ContainsCondition{
Field: targetName,
Value: str,
}, nil
}
func buildValueExpr(expr *exprv1.Expr, schema Schema) (ValueExpr, error) {
if identName, err := getIdentName(expr); err == nil {
if _, ok := schema.Field(identName); !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
return &FieldRef{Name: identName}, nil
}
if literal, err := getConstValue(expr); err == nil {
return &LiteralValue{Value: literal}, nil
}
if value, ok, err := evaluateNumeric(expr); err != nil {
return nil, err
} else if ok {
return &LiteralValue{Value: value}, nil
}
if boolVal, ok, err := evaluateBoolExpr(expr); err != nil {
return nil, err
} else if ok {
return &LiteralValue{Value: boolVal}, nil
}
if call := expr.GetCallExpr(); call != nil {
switch call.Function {
case "size":
if len(call.Args) != 1 {
return nil, errors.New("size() expects one argument")
}
arg, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
return &FunctionValue{
Name: "size",
Args: []ValueExpr{arg},
}, nil
case "now":
return &LiteralValue{Value: timeNowUnix()}, nil
case "_+_", "_-_", "_*_":
value, ok, err := evaluateNumeric(expr)
if err != nil {
return nil, err
}
if ok {
return &LiteralValue{Value: value}, nil
}
default:
// Fall through to error return below
}
}
return nil, errors.New("unsupported value expression")
}
func toComparisonOperator(fn string) (ComparisonOperator, error) {
switch fn {
case "_==_":
return CompareEq, nil
case "_!=_":
return CompareNeq, nil
case "_<_":
return CompareLt, nil
case "_>_":
return CompareGt, nil
case "_<=_":
return CompareLte, nil
case "_>=_":
return CompareGte, nil
default:
return "", errors.Errorf("unsupported comparison operator %q", fn)
}
}
func getIdentName(expr *exprv1.Expr) (string, error) {
if ident := expr.GetIdentExpr(); ident != nil {
return ident.GetName(), nil
}
return "", errors.New("expression is not an identifier")
}
func getConstValue(expr *exprv1.Expr) (interface{}, error) {
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
if !ok {
return nil, errors.New("expression is not a literal")
}
switch x := v.ConstExpr.ConstantKind.(type) {
case *exprv1.Constant_StringValue:
return v.ConstExpr.GetStringValue(), nil
case *exprv1.Constant_Int64Value:
return v.ConstExpr.GetInt64Value(), nil
case *exprv1.Constant_Uint64Value:
return int64(v.ConstExpr.GetUint64Value()), nil
case *exprv1.Constant_DoubleValue:
return v.ConstExpr.GetDoubleValue(), nil
case *exprv1.Constant_BoolValue:
return v.ConstExpr.GetBoolValue(), nil
case *exprv1.Constant_NullValue:
return nil, nil
default:
return nil, errors.Errorf("unsupported constant %T", x)
}
}
func evaluateBool(call *exprv1.Expr_Call) (bool, bool, error) {
val, ok, err := evaluateBoolExpr(&exprv1.Expr{ExprKind: &exprv1.Expr_CallExpr{CallExpr: call}})
return val, ok, err
}
func evaluateBoolExpr(expr *exprv1.Expr) (bool, bool, error) {
if literal, err := getConstValue(expr); err == nil {
if b, ok := literal.(bool); ok {
return b, true, nil
}
return false, false, nil
}
if call := expr.GetCallExpr(); call != nil && call.Function == "!_" {
if len(call.Args) != 1 {
return false, false, errors.New("NOT expects exactly one argument")
}
val, ok, err := evaluateBoolExpr(call.Args[0])
if err != nil || !ok {
return false, false, err
}
return !val, true, nil
}
return false, false, nil
}
func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
if literal, err := getConstValue(expr); err == nil {
switch v := literal.(type) {
case int64:
return v, true, nil
case float64:
return int64(v), true, nil
}
return 0, false, nil
}
call := expr.GetCallExpr()
if call == nil {
return 0, false, nil
}
switch call.Function {
case "now":
return timeNowUnix(), true, nil
case "_+_", "_-_", "_*_":
if len(call.Args) != 2 {
return 0, false, errors.New("arithmetic requires two arguments")
}
left, ok, err := evaluateNumeric(call.Args[0])
if err != nil {
return 0, false, err
}
if !ok {
return 0, false, nil
}
right, ok, err := evaluateNumeric(call.Args[1])
if err != nil {
return 0, false, err
}
if !ok {
return 0, false, nil
}
switch call.Function {
case "_+_":
return left + right, true, nil
case "_-_":
return left - right, true, nil
case "_*_":
return left * right, true, nil
default:
return 0, false, errors.Errorf("unsupported arithmetic operator %q", call.Function)
}
default:
return 0, false, nil
}
}
func timeNowUnix() int64 {
return time.Now().Unix()
}
// buildComprehensionCondition handles CEL comprehension expressions (exists, all, etc.).
func buildComprehensionCondition(comp *exprv1.Expr_Comprehension, schema Schema) (Condition, error) {
// Determine the comprehension kind by examining the loop initialization and step
kind, err := detectComprehensionKind(comp)
if err != nil {
return nil, err
}
// Get the field being iterated over
iterRangeIdent := comp.IterRange.GetIdentExpr()
if iterRangeIdent == nil {
return nil, errors.New("comprehension range must be a field identifier")
}
fieldName := iterRangeIdent.GetName()
// Validate the field
field, ok := schema.Field(fieldName)
if !ok {
return nil, errors.Errorf("unknown field %q in comprehension", fieldName)
}
if field.Kind != FieldKindJSONList {
return nil, errors.Errorf("field %q does not support comprehension (must be a list)", fieldName)
}
// Extract the predicate from the loop step
predicate, err := extractPredicate(comp, schema)
if err != nil {
return nil, err
}
return &ListComprehensionCondition{
Kind: kind,
Field: fieldName,
IterVar: comp.IterVar,
Predicate: predicate,
}, nil
}
// detectComprehensionKind determines if this is an exists() macro.
// Only exists() is currently supported.
func detectComprehensionKind(comp *exprv1.Expr_Comprehension) (ComprehensionKind, error) {
// Check the accumulator initialization
accuInit := comp.AccuInit.GetConstExpr()
if accuInit == nil {
return "", errors.New("comprehension accumulator must be initialized with a constant")
}
// exists() starts with false and uses OR (||) in loop step
if !accuInit.GetBoolValue() {
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_||_" {
return ComprehensionExists, nil
}
}
// all() starts with true and uses AND (&&) - not supported
if accuInit.GetBoolValue() {
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_&&_" {
return "", errors.New("all() comprehension is not supported; use exists() instead")
}
}
return "", errors.New("unsupported comprehension type; only exists() is supported")
}
// extractPredicate extracts the predicate expression from the comprehension loop step.
func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, error) {
// The loop step is: @result || predicate(t) for exists
// or: @result && predicate(t) for all
step := comp.LoopStep.GetCallExpr()
if step == nil {
return nil, errors.New("comprehension loop step must be a call expression")
}
if len(step.Args) != 2 {
return nil, errors.New("comprehension loop step must have two arguments")
}
// The predicate is the second argument
predicateExpr := step.Args[1]
predicateCall := predicateExpr.GetCallExpr()
if predicateCall == nil {
return nil, errors.New("comprehension predicate must be a function call")
}
// Handle different predicate functions
switch predicateCall.Function {
case "startsWith":
return buildStartsWithPredicate(predicateCall, comp.IterVar)
case "endsWith":
return buildEndsWithPredicate(predicateCall, comp.IterVar)
case "contains":
return buildContainsPredicate(predicateCall, comp.IterVar)
default:
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
}
}
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
// Verify the target is the iteration variable
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("startsWith target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("startsWith expects exactly one argument")
}
prefix, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "startsWith argument must be a constant string")
}
prefixStr, ok := prefix.(string)
if !ok {
return nil, errors.New("startsWith argument must be a string")
}
return &StartsWithPredicate{Prefix: prefixStr}, nil
}
// buildEndsWithPredicate extracts the pattern from t.endsWith("suffix").
func buildEndsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("endsWith target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("endsWith expects exactly one argument")
}
suffix, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "endsWith argument must be a constant string")
}
suffixStr, ok := suffix.(string)
if !ok {
return nil, errors.New("endsWith argument must be a string")
}
return &EndsWithPredicate{Suffix: suffixStr}, nil
}
// buildContainsPredicate extracts the pattern from t.contains("substring").
func buildContainsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("contains target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("contains expects exactly one argument")
}
substring, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "contains argument must be a constant string")
}
substringStr, ok := substring.(string)
if !ok {
return nil, errors.New("contains argument must be a string")
}
return &ContainsPredicate{Substring: substringStr}, nil
}

748
plugin/filter/render.go Normal file
View File

@@ -0,0 +1,748 @@
package filter
import (
"fmt"
"strings"
"github.com/pkg/errors"
)
type renderer struct {
schema Schema
dialect DialectName
placeholderOffset int
placeholderCounter int
args []any
}
type renderResult struct {
sql string
trivial bool
unsatisfiable bool
}
func newRenderer(schema Schema, opts RenderOptions) *renderer {
return &renderer{
schema: schema,
dialect: opts.Dialect,
placeholderOffset: opts.PlaceholderOffset,
}
}
func (r *renderer) Render(cond Condition) (Statement, error) {
result, err := r.renderCondition(cond)
if err != nil {
return Statement{}, err
}
args := r.args
if args == nil {
args = []any{}
}
switch {
case result.unsatisfiable:
return Statement{
SQL: "1 = 0",
Args: args,
}, nil
case result.trivial:
return Statement{
SQL: "",
Args: args,
}, nil
default:
return Statement{
SQL: result.sql,
Args: args,
}, nil
}
}
func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
switch c := cond.(type) {
case *LogicalCondition:
return r.renderLogicalCondition(c)
case *NotCondition:
return r.renderNotCondition(c)
case *FieldPredicateCondition:
return r.renderFieldPredicate(c)
case *ComparisonCondition:
return r.renderComparison(c)
case *InCondition:
return r.renderInCondition(c)
case *ElementInCondition:
return r.renderElementInCondition(c)
case *ContainsCondition:
return r.renderContainsCondition(c)
case *ListComprehensionCondition:
return r.renderListComprehension(c)
case *ConstantCondition:
if c.Value {
return renderResult{trivial: true}, nil
}
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
default:
return renderResult{}, errors.Errorf("unsupported condition type %T", c)
}
}
func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) {
left, err := r.renderCondition(cond.Left)
if err != nil {
return renderResult{}, err
}
right, err := r.renderCondition(cond.Right)
if err != nil {
return renderResult{}, err
}
switch cond.Operator {
case LogicalAnd:
return combineAnd(left, right), nil
case LogicalOr:
return combineOr(left, right), nil
default:
return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator)
}
}
func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) {
child, err := r.renderCondition(cond.Expr)
if err != nil {
return renderResult{}, err
}
if child.trivial {
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
}
if child.unsatisfiable {
return renderResult{trivial: true}, nil
}
return renderResult{
sql: fmt.Sprintf("NOT (%s)", child.sql),
}, nil
}
func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
switch field.Kind {
case FieldKindBoolColumn:
column := qualifyColumn(r.dialect, field.Column)
return renderResult{
sql: fmt.Sprintf("%s IS TRUE", column),
}, nil
case FieldKindJSONBool:
sql, err := r.jsonBoolPredicate(field)
if err != nil {
return renderResult{}, err
}
return renderResult{sql: sql}, nil
default:
return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field)
}
}
func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) {
switch left := cond.Left.(type) {
case *FieldRef:
field, ok := r.schema.Field(left.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", left.Name)
}
switch field.Kind {
case FieldKindBoolColumn:
return r.renderBoolColumnComparison(field, cond.Operator, cond.Right)
case FieldKindJSONBool:
return r.renderJSONBoolComparison(field, cond.Operator, cond.Right)
case FieldKindScalar:
return r.renderScalarComparison(field, cond.Operator, cond.Right)
default:
return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name)
}
case *FunctionValue:
return r.renderFunctionComparison(left, cond.Operator, cond.Right)
default:
return renderResult{}, errors.New("comparison must start with a field reference or supported function")
}
}
func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) {
if fn.Name != "size" {
return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name)
}
if len(fn.Args) != 1 {
return renderResult{}, errors.New("size() expects one argument")
}
fieldArg, ok := fn.Args[0].(*FieldRef)
if !ok {
return renderResult{}, errors.New("size() argument must be a field")
}
field, ok := r.schema.Field(fieldArg.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", fieldArg.Name)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name)
}
value, err := expectNumericLiteral(right)
if err != nil {
return renderResult{}, err
}
expr := jsonArrayLengthExpr(r.dialect, field)
placeholder := r.addArg(value)
return renderResult{
sql: fmt.Sprintf("%s %s %s", expr, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
lit, err := expectLiteral(right)
if err != nil {
return renderResult{}, err
}
columnExpr := field.columnExpr(r.dialect)
if lit == nil {
switch op {
case CompareEq:
return renderResult{sql: fmt.Sprintf("%s IS NULL", columnExpr)}, nil
case CompareNeq:
return renderResult{sql: fmt.Sprintf("%s IS NOT NULL", columnExpr)}, nil
default:
return renderResult{}, errors.Errorf("operator %s not supported for null comparison", op)
}
}
placeholder := ""
switch field.Type {
case FieldTypeString:
value, ok := lit.(string)
if !ok {
return renderResult{}, errors.Errorf("field %q expects string value", field.Name)
}
placeholder = r.addArg(value)
case FieldTypeInt, FieldTypeTimestamp:
num, err := toInt64(lit)
if err != nil {
return renderResult{}, errors.Wrapf(err, "field %q expects integer value", field.Name)
}
placeholder = r.addArg(num)
default:
return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name)
}
return renderResult{
sql: fmt.Sprintf("%s %s %s", columnExpr, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
value, err := expectBool(right)
if err != nil {
return renderResult{}, err
}
placeholder := r.addBoolArg(value)
column := qualifyColumn(r.dialect, field.Column)
return renderResult{
sql: fmt.Sprintf("%s %s %s", column, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
value, err := expectBool(right)
if err != nil {
return renderResult{}, err
}
jsonExpr := jsonExtractExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite:
switch op {
case CompareEq:
if field.Name == "has_task_list" {
target := "0"
if value {
target = "1"
}
return renderResult{sql: fmt.Sprintf("%s = %s", jsonExpr, target)}, nil
}
if value {
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
}
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
case CompareNeq:
if field.Name == "has_task_list" {
target := "0"
if value {
target = "1"
}
return renderResult{sql: fmt.Sprintf("%s != %s", jsonExpr, target)}, nil
}
if value {
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
}
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
default:
return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op)
}
case DialectMySQL:
boolStr := "false"
if value {
boolStr = "true"
}
return renderResult{
sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", jsonExpr, sqlOperator(op), boolStr),
}, nil
case DialectPostgres:
placeholder := r.addArg(value)
return renderResult{
sql: fmt.Sprintf("(%s)::boolean %s %s", jsonExpr, sqlOperator(op), placeholder),
}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) {
fieldRef, ok := cond.Left.(*FieldRef)
if !ok {
return renderResult{}, errors.New("IN operator requires a field on the left-hand side")
}
if fieldRef.Name == "tag" {
return r.renderTagInList(cond.Values)
}
field, ok := r.schema.Field(fieldRef.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", fieldRef.Name)
}
if field.Kind != FieldKindScalar {
return renderResult{}, errors.Errorf("field %q does not support IN()", fieldRef.Name)
}
return r.renderScalarInCondition(field, cond.Values)
}
func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) {
field, ok := r.schema.ResolveAlias("tag")
if !ok {
return renderResult{}, errors.New("tag attribute is not configured")
}
conditions := make([]string, 0, len(values))
for _, v := range values {
lit, err := expectLiteral(v)
if err != nil {
return renderResult{}, err
}
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.New("tags must be compared with string literals")
}
switch r.dialect {
case DialectSQLite:
// Support hierarchical tags: match exact tag OR tags with this prefix (e.g., "book" matches "book" and "book/something")
exactMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
prefixMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
conditions = append(conditions, expr)
case DialectMySQL:
// Support hierarchical tags: match exact tag OR tags with this prefix
exactMatch := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
prefixMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
conditions = append(conditions, expr)
case DialectPostgres:
// Support hierarchical tags: match exact tag OR tags with this prefix
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
conditions = append(conditions, expr)
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
if len(conditions) == 1 {
return renderResult{sql: conditions[0]}, nil
}
return renderResult{
sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")),
}, nil
}
func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field)
}
lit, err := expectLiteral(cond.Element)
if err != nil {
return renderResult{}, err
}
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.New("tags membership requires string literal")
}
switch r.dialect {
case DialectSQLite:
sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
return renderResult{sql: sql}, nil
case DialectMySQL:
sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
return renderResult{sql: sql}, nil
case DialectPostgres:
sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
return renderResult{sql: sql}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) {
placeholders := make([]string, 0, len(values))
for _, v := range values {
lit, err := expectLiteral(v)
if err != nil {
return renderResult{}, err
}
switch field.Type {
case FieldTypeString:
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.Errorf("field %q expects string values", field.Name)
}
placeholders = append(placeholders, r.addArg(str))
case FieldTypeInt:
num, err := toInt64(lit)
if err != nil {
return renderResult{}, err
}
placeholders = append(placeholders, r.addArg(num))
default:
return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name)
}
}
column := field.columnExpr(r.dialect)
return renderResult{
sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")),
}, nil
}
func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
column := field.columnExpr(r.dialect)
arg := fmt.Sprintf("%%%s%%", cond.Value)
switch r.dialect {
case DialectSQLite:
// Use custom Unicode-aware case folding function for case-insensitive comparison.
// This overcomes SQLite's ASCII-only LOWER() limitation.
sql := fmt.Sprintf("memos_unicode_lower(%s) LIKE memos_unicode_lower(%s)", column, r.addArg(arg))
return renderResult{sql: sql}, nil
case DialectPostgres:
sql := fmt.Sprintf("%s ILIKE %s", column, r.addArg(arg))
return renderResult{sql: sql}, nil
default:
sql := fmt.Sprintf("%s LIKE %s", column, r.addArg(arg))
return renderResult{sql: sql}, nil
}
}
func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("field %q is not a JSON list", cond.Field)
}
// Render based on predicate type
switch pred := cond.Predicate.(type) {
case *StartsWithPredicate:
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
case *EndsWithPredicate:
return r.renderTagEndsWith(field, pred.Suffix, cond.Kind)
case *ContainsPredicate:
return r.renderTagContains(field, pred.Substring, cond.Kind)
default:
return renderResult{}, errors.Errorf("unsupported predicate type %T in comprehension", pred)
}
}
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite, DialectMySQL:
// Match exact tag or tags with this prefix (hierarchical support)
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, prefix))
prefixMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s%%`, prefix))
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
case DialectPostgres:
// Use PostgreSQL's powerful JSON operators
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, prefix)))
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(fmt.Sprintf(`%%"%s%%`, prefix)))
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
// renderTagEndsWith generates SQL for tags.exists(t, t.endsWith("suffix")).
func (r *renderer) renderTagEndsWith(field Field, suffix string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
pattern := fmt.Sprintf(`%%%s"%%`, suffix)
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
}
// renderTagContains generates SQL for tags.exists(t, t.contains("substring")).
func (r *renderer) renderTagContains(field Field, substring string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
pattern := fmt.Sprintf(`%%%s%%`, substring)
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
}
// buildJSONArrayLike builds a LIKE expression for matching within a JSON array.
// Returns the LIKE clause without NULL/empty checks.
func (r *renderer) buildJSONArrayLike(arrayExpr, pattern string) string {
switch r.dialect {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("%s LIKE %s", arrayExpr, r.addArg(pattern))
case DialectPostgres:
return fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(pattern))
default:
return ""
}
}
// wrapWithNullCheck wraps a condition with NULL and empty array checks.
// This ensures we don't match against NULL or empty JSON arrays.
func (r *renderer) wrapWithNullCheck(arrayExpr, condition string) string {
var nullCheck string
switch r.dialect {
case DialectSQLite:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND %s != '[]'", arrayExpr, arrayExpr)
case DialectMySQL:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND JSON_LENGTH(%s) > 0", arrayExpr, arrayExpr)
case DialectPostgres:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND jsonb_array_length(%s) > 0", arrayExpr, arrayExpr)
default:
return condition
}
return fmt.Sprintf("(%s AND %s)", condition, nullCheck)
}
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
expr := jsonExtractExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite:
return fmt.Sprintf("%s IS TRUE", expr), nil
case DialectMySQL:
return fmt.Sprintf("COALESCE(%s, CAST('false' AS JSON)) = CAST('true' AS JSON)", expr), nil
case DialectPostgres:
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
default:
return "", errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func combineAnd(left, right renderResult) renderResult {
if left.unsatisfiable || right.unsatisfiable {
return renderResult{sql: "1 = 0", unsatisfiable: true}
}
if left.trivial {
return right
}
if right.trivial {
return left
}
return renderResult{
sql: fmt.Sprintf("(%s AND %s)", left.sql, right.sql),
}
}
func combineOr(left, right renderResult) renderResult {
if left.trivial || right.trivial {
return renderResult{trivial: true}
}
if left.unsatisfiable {
return right
}
if right.unsatisfiable {
return left
}
return renderResult{
sql: fmt.Sprintf("(%s OR %s)", left.sql, right.sql),
}
}
func (r *renderer) addArg(value any) string {
r.placeholderCounter++
r.args = append(r.args, value)
if r.dialect == DialectPostgres {
return fmt.Sprintf("$%d", r.placeholderOffset+r.placeholderCounter)
}
return "?"
}
func (r *renderer) addBoolArg(value bool) string {
var v any
switch r.dialect {
case DialectSQLite:
if value {
v = 1
} else {
v = 0
}
default:
v = value
}
return r.addArg(v)
}
func expectLiteral(expr ValueExpr) (any, error) {
lit, ok := expr.(*LiteralValue)
if !ok {
return nil, errors.New("expression must be a literal")
}
return lit.Value, nil
}
func expectBool(expr ValueExpr) (bool, error) {
lit, err := expectLiteral(expr)
if err != nil {
return false, err
}
value, ok := lit.(bool)
if !ok {
return false, errors.New("boolean literal required")
}
return value, nil
}
func expectNumericLiteral(expr ValueExpr) (int64, error) {
lit, err := expectLiteral(expr)
if err != nil {
return 0, err
}
return toInt64(lit)
}
func toInt64(value any) (int64, error) {
switch v := value.(type) {
case int:
return int64(v), nil
case int32:
return int64(v), nil
case int64:
return v, nil
case uint32:
return int64(v), nil
case uint64:
return int64(v), nil
case float32:
return int64(v), nil
case float64:
return int64(v), nil
default:
return 0, errors.Errorf("cannot convert %T to int64", value)
}
}
func sqlOperator(op ComparisonOperator) string {
return string(op)
}
func qualifyColumn(d DialectName, col Column) string {
switch d {
case DialectPostgres:
return fmt.Sprintf("%s.%s", col.Table, col.Name)
default:
return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name)
}
}
func jsonPath(field Field) string {
return "$." + strings.Join(field.JSONPath, ".")
}
func jsonExtractExpr(d DialectName, field Field) string {
column := qualifyColumn(d, field.Column)
switch d {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
case DialectPostgres:
return buildPostgresJSONAccessor(column, field.JSONPath, true)
default:
return ""
}
}
func jsonArrayExpr(d DialectName, field Field) string {
column := qualifyColumn(d, field.Column)
switch d {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
case DialectPostgres:
return buildPostgresJSONAccessor(column, field.JSONPath, false)
default:
return ""
}
}
func jsonArrayLengthExpr(d DialectName, field Field) string {
arrayExpr := jsonArrayExpr(d, field)
switch d {
case DialectSQLite:
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
case DialectMySQL:
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
case DialectPostgres:
return fmt.Sprintf("jsonb_array_length(COALESCE(%s, '[]'::jsonb))", arrayExpr)
default:
return ""
}
}
func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string {
expr := base
for idx, part := range path {
if idx == len(path)-1 && terminalText {
expr = fmt.Sprintf("%s->>'%s'", expr, part)
} else {
expr = fmt.Sprintf("%s->'%s'", expr, part)
}
}
return expr
}

319
plugin/filter/schema.go Normal file
View File

@@ -0,0 +1,319 @@
package filter
import (
"fmt"
"time"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// DialectName enumerates supported SQL dialects.
type DialectName string
const (
DialectSQLite DialectName = "sqlite"
DialectMySQL DialectName = "mysql"
DialectPostgres DialectName = "postgres"
)
// FieldType represents the logical type of a field.
type FieldType string
const (
FieldTypeString FieldType = "string"
FieldTypeInt FieldType = "int"
FieldTypeBool FieldType = "bool"
FieldTypeTimestamp FieldType = "timestamp"
)
// FieldKind describes how a field is stored.
type FieldKind string
const (
FieldKindScalar FieldKind = "scalar"
FieldKindBoolColumn FieldKind = "bool_column"
FieldKindJSONBool FieldKind = "json_bool"
FieldKindJSONList FieldKind = "json_list"
FieldKindVirtualAlias FieldKind = "virtual_alias"
)
// Column identifies the backing table column.
type Column struct {
Table string
Name string
}
// Field captures the schema metadata for an exposed CEL identifier.
type Field struct {
Name string
Kind FieldKind
Type FieldType
Column Column
JSONPath []string
AliasFor string
SupportsContains bool
Expressions map[DialectName]string
AllowedComparisonOps map[ComparisonOperator]bool
}
// Schema collects CEL environment options and field metadata.
type Schema struct {
Name string
Fields map[string]Field
EnvOptions []cel.EnvOption
}
// Field returns the field metadata if present.
func (s Schema) Field(name string) (Field, bool) {
f, ok := s.Fields[name]
return f, ok
}
// ResolveAlias resolves a virtual alias to its target field.
func (s Schema) ResolveAlias(name string) (Field, bool) {
field, ok := s.Fields[name]
if !ok {
return Field{}, false
}
if field.Kind == FieldKindVirtualAlias {
target, ok := s.Fields[field.AliasFor]
if !ok {
return Field{}, false
}
return target, true
}
return field, true
}
var nowFunction = cel.Function("now",
cel.Overload("now",
[]*cel.Type{},
cel.IntType,
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
return types.Int(time.Now().Unix())
}),
),
)
// NewSchema constructs the memo filter schema and CEL environment.
func NewSchema() Schema {
fields := map[string]Field{
"content": {
Name: "content",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "content"},
SupportsContains: true,
Expressions: map[DialectName]string{},
},
"creator_id": {
Name: "creator_id",
Kind: FieldKindScalar,
Type: FieldTypeInt,
Column: Column{Table: "memo", Name: "creator_id"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"created_ts": {
Name: "created_ts",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "memo", Name: "created_ts"},
Expressions: map[DialectName]string{
// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
DialectMySQL: "UNIX_TIMESTAMP(%s)",
// PostgreSQL and SQLite store created_ts as BIGINT (epoch), no conversion needed
DialectPostgres: "%s",
DialectSQLite: "%s",
},
},
"updated_ts": {
Name: "updated_ts",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "memo", Name: "updated_ts"},
Expressions: map[DialectName]string{
// MySQL stores updated_ts as TIMESTAMP, needs conversion to epoch
DialectMySQL: "UNIX_TIMESTAMP(%s)",
// PostgreSQL and SQLite store updated_ts as BIGINT (epoch), no conversion needed
DialectPostgres: "%s",
DialectSQLite: "%s",
},
},
"pinned": {
Name: "pinned",
Kind: FieldKindBoolColumn,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "pinned"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"visibility": {
Name: "visibility",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "visibility"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"tags": {
Name: "tags",
Kind: FieldKindJSONList,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"tags"},
},
"tag": {
Name: "tag",
Kind: FieldKindVirtualAlias,
Type: FieldTypeString,
AliasFor: "tags",
},
"has_task_list": {
Name: "has_task_list",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasTaskList"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_link": {
Name: "has_link",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasLink"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_code": {
Name: "has_code",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasCode"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_incomplete_tasks": {
Name: "has_incomplete_tasks",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasIncompleteTasks"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
}
envOptions := []cel.EnvOption{
cel.Variable("content", cel.StringType),
cel.Variable("creator_id", cel.IntType),
cel.Variable("created_ts", cel.IntType),
cel.Variable("updated_ts", cel.IntType),
cel.Variable("pinned", cel.BoolType),
cel.Variable("tag", cel.StringType),
cel.Variable("tags", cel.ListType(cel.StringType)),
cel.Variable("visibility", cel.StringType),
cel.Variable("has_task_list", cel.BoolType),
cel.Variable("has_link", cel.BoolType),
cel.Variable("has_code", cel.BoolType),
cel.Variable("has_incomplete_tasks", cel.BoolType),
nowFunction,
}
return Schema{
Name: "memo",
Fields: fields,
EnvOptions: envOptions,
}
}
// NewAttachmentSchema constructs the attachment filter schema and CEL environment.
func NewAttachmentSchema() Schema {
fields := map[string]Field{
"filename": {
Name: "filename",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "attachment", Name: "filename"},
SupportsContains: true,
Expressions: map[DialectName]string{},
},
"mime_type": {
Name: "mime_type",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "attachment", Name: "type"},
Expressions: map[DialectName]string{},
},
"create_time": {
Name: "create_time",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "attachment", Name: "created_ts"},
Expressions: map[DialectName]string{
// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
DialectMySQL: "UNIX_TIMESTAMP(%s)",
// PostgreSQL and SQLite store created_ts as BIGINT (epoch), no conversion needed
DialectPostgres: "%s",
DialectSQLite: "%s",
},
},
"memo_id": {
Name: "memo_id",
Kind: FieldKindScalar,
Type: FieldTypeInt,
Column: Column{Table: "attachment", Name: "memo_id"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
}
envOptions := []cel.EnvOption{
cel.Variable("filename", cel.StringType),
cel.Variable("mime_type", cel.StringType),
cel.Variable("create_time", cel.IntType),
cel.Variable("memo_id", cel.AnyType),
nowFunction,
}
return Schema{
Name: "attachment",
Fields: fields,
EnvOptions: envOptions,
}
}
// columnExpr returns the field expression for the given dialect, applying
// any schema-specific overrides (e.g. UNIX timestamp conversions).
func (f Field) columnExpr(d DialectName) string {
base := qualifyColumn(d, f.Column)
if expr, ok := f.Expressions[d]; ok && expr != "" {
return fmt.Sprintf(expr, base)
}
return base
}