first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
50
plugin/filter/MAINTENANCE.md
Normal file
50
plugin/filter/MAINTENANCE.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Maintaining the Memo Filter Engine
|
||||
|
||||
The engine is memo-specific; any future field or behavior changes must stay
|
||||
consistent with the memo schema and store implementations. Use this guide when
|
||||
extending or debugging the package.
|
||||
|
||||
## Adding a New Memo Field
|
||||
|
||||
1. **Update the schema**
|
||||
- Add the field entry in `schema.go`.
|
||||
- Define the backing column (`Column`), JSON path (if applicable), type, and
|
||||
allowed operators.
|
||||
- Include the CEL variable in `EnvOptions`.
|
||||
2. **Adjust parser or renderer (if needed)**
|
||||
- For non-scalar fields (JSON booleans, lists), add handling in
|
||||
`parser.go` or extend the renderer helpers.
|
||||
- Keep validation in the parser (e.g., reject unsupported operators).
|
||||
3. **Write a golden test**
|
||||
- Extend the dialect-specific memo filter tests under
|
||||
`store/db/{sqlite,mysql,postgres}/memo_filter_test.go` with a case that
|
||||
exercises the new field.
|
||||
4. **Run `go test ./...`** to ensure the SQL output matches expectations across
|
||||
all dialects.
|
||||
|
||||
## Supporting Dialect Nuances
|
||||
|
||||
- Centralize differences inside `render.go`. If a new dialect-specific behavior
|
||||
emerges (e.g., JSON operators), add the logic there rather than leaking it
|
||||
into store code.
|
||||
- Use the renderer helpers (`jsonExtractExpr`, `jsonArrayExpr`, etc.) rather than
|
||||
sprinkling ad-hoc SQL strings.
|
||||
- When placeholders change, adjust `addArg` so that argument numbering stays in
|
||||
sync with store queries.
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
- **Parser errors** – Most originate in `buildCondition` or schema validation.
|
||||
Enable logging around `parser.go` when diagnosing unknown identifier/operator
|
||||
messages.
|
||||
- **Renderer output** – Temporary printf/log statements in `renderCondition` help
|
||||
identify which IR node produced unexpected SQL.
|
||||
- **Store integration** – Ensure drivers call `filter.DefaultEngine()` exactly once
|
||||
per process; the singleton caches the parsed CEL environment.
|
||||
|
||||
## Testing Checklist
|
||||
|
||||
- `go test ./store/...` ensures all dialect tests consume the engine correctly.
|
||||
- Add targeted unit tests whenever new IR nodes or renderer paths are introduced.
|
||||
- When changing boolean or JSON handling, verify all three dialect test suites
|
||||
(SQLite, MySQL, Postgres) to avoid regression.
|
||||
63
plugin/filter/README.md
Normal file
63
plugin/filter/README.md
Normal file
@@ -0,0 +1,63 @@
|
||||
# Memo Filter Engine
|
||||
|
||||
This package houses the memo-only filter engine that turns CEL expressions into
|
||||
SQL fragments. The engine follows a three phase pipeline inspired by systems
|
||||
such as Calcite or Prisma:
|
||||
|
||||
1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against
|
||||
the memo-specific environment declared in `schema.go`. Only fields that
|
||||
exist in the schema can surface in the filter.
|
||||
2. **Normalization** – the raw CEL AST is converted into an intermediate
|
||||
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
|
||||
conditions (logical operators, comparisons, list membership, etc.). This
|
||||
step enforces schema rules (e.g. operator compatibility, type checks).
|
||||
3. **Rendering** – the renderer in `render.go` walks the IR and produces a SQL
|
||||
fragment plus placeholder arguments tailored to a target dialect
|
||||
(`sqlite`, `mysql`, or `postgres`). Dialect differences such as JSON access,
|
||||
boolean semantics, placeholders, and `LIKE` vs `ILIKE` are encapsulated in
|
||||
renderer helpers.
|
||||
|
||||
The entry point is `filter.DefaultEngine()` from `engine.go`. It lazily constructs
|
||||
an `Engine` configured with the memo schema and exposes:
|
||||
|
||||
```go
|
||||
engine, _ := filter.DefaultEngine()
|
||||
stmt, _ := engine.CompileToStatement(ctx, `has_task_list && visibility == "PUBLIC"`, filter.RenderOptions{
|
||||
Dialect: filter.DialectPostgres,
|
||||
})
|
||||
// stmt.SQL -> "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.visibility = $1)"
|
||||
// stmt.Args -> ["PUBLIC"]
|
||||
```
|
||||
|
||||
## Core Files
|
||||
|
||||
| File | Responsibility |
|
||||
| ------------- | ------------------------------------------------------------------------------- |
|
||||
| `schema.go` | Declares memo fields, their types, backing columns, CEL environment options |
|
||||
| `ir.go` | IR node definitions used across the pipeline |
|
||||
| `parser.go` | Converts CEL `Expr` into IR while applying schema validation |
|
||||
| `render.go` | Translates IR into SQL, handling dialect-specific behavior |
|
||||
| `engine.go` | Glue between the phases; exposes `Compile`, `CompileToStatement`, and `DefaultEngine` |
|
||||
| `helpers.go` | Convenience helpers for store integration (appending conditions) |
|
||||
|
||||
## SQL Generation Notes
|
||||
|
||||
- **Placeholders** — `?` is used for SQLite/MySQL, `$n` for Postgres. The renderer
|
||||
tracks offsets to compose queries with pre-existing arguments.
|
||||
- **JSON Fields** — Memo metadata lives in `memo.payload`. The renderer handles
|
||||
`JSON_EXTRACT`/`json_extract`/`->`/`->>` variations and boolean coercion.
|
||||
- **Tag Operations** — `tag in [...]` and `"tag" in tags` become JSON array
|
||||
predicates. SQLite uses `LIKE` patterns, MySQL uses `JSON_CONTAINS`, and
|
||||
Postgres uses `@>`.
|
||||
- **Boolean Flags** — Fields such as `has_task_list` render as `IS TRUE` equality
|
||||
checks, or comparisons against `CAST('true' AS JSON)` depending on the dialect.
|
||||
|
||||
## Typical Integration
|
||||
|
||||
1. Fetch the engine with `filter.DefaultEngine()`.
|
||||
2. Call `CompileToStatement` using the appropriate dialect enum.
|
||||
3. Append the emitted SQL fragment/args to the existing `WHERE` clause.
|
||||
4. Execute the resulting query through the store driver.
|
||||
|
||||
The `helpers.AppendConditions` helper encapsulates steps 2–3 when a driver needs
|
||||
to process an array of filters.
|
||||
191
plugin/filter/engine.go
Normal file
191
plugin/filter/engine.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Engine parses CEL filters into a dialect-agnostic condition tree.
|
||||
type Engine struct {
|
||||
schema Schema
|
||||
env *cel.Env
|
||||
}
|
||||
|
||||
// NewEngine builds a new Engine for the provided schema.
|
||||
func NewEngine(schema Schema) (*Engine, error) {
|
||||
env, err := cel.NewEnv(schema.EnvOptions...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create CEL environment")
|
||||
}
|
||||
return &Engine{
|
||||
schema: schema,
|
||||
env: env,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Program stores a compiled filter condition.
|
||||
type Program struct {
|
||||
schema Schema
|
||||
condition Condition
|
||||
}
|
||||
|
||||
// ConditionTree exposes the underlying condition tree.
|
||||
func (p *Program) ConditionTree() Condition {
|
||||
return p.condition
|
||||
}
|
||||
|
||||
// Compile parses the filter string into an executable program.
|
||||
func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
|
||||
if strings.TrimSpace(filter) == "" {
|
||||
return nil, errors.New("filter expression is empty")
|
||||
}
|
||||
|
||||
filter = normalizeLegacyFilter(filter)
|
||||
|
||||
ast, issues := e.env.Compile(filter)
|
||||
if issues != nil && issues.Err() != nil {
|
||||
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
|
||||
}
|
||||
parsed, err := cel.AstToParsedExpr(ast)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert AST")
|
||||
}
|
||||
|
||||
cond, err := buildCondition(parsed.GetExpr(), e.schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Program{
|
||||
schema: e.schema,
|
||||
condition: cond,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CompileToStatement compiles and renders the filter in a single step.
|
||||
func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) {
|
||||
program, err := e.Compile(ctx, filter)
|
||||
if err != nil {
|
||||
return Statement{}, err
|
||||
}
|
||||
return program.Render(opts)
|
||||
}
|
||||
|
||||
// RenderOptions configure SQL rendering.
|
||||
type RenderOptions struct {
|
||||
Dialect DialectName
|
||||
PlaceholderOffset int
|
||||
DisableNullChecks bool
|
||||
}
|
||||
|
||||
// Statement contains the rendered SQL fragment and its args.
|
||||
type Statement struct {
|
||||
SQL string
|
||||
Args []any
|
||||
}
|
||||
|
||||
// Render converts the program into a dialect-specific SQL fragment.
|
||||
func (p *Program) Render(opts RenderOptions) (Statement, error) {
|
||||
renderer := newRenderer(p.schema, opts)
|
||||
return renderer.Render(p.condition)
|
||||
}
|
||||
|
||||
var (
|
||||
defaultOnce sync.Once
|
||||
defaultInst *Engine
|
||||
defaultErr error
|
||||
defaultAttachmentOnce sync.Once
|
||||
defaultAttachmentInst *Engine
|
||||
defaultAttachmentErr error
|
||||
)
|
||||
|
||||
// DefaultEngine returns the process-wide memo filter engine.
|
||||
func DefaultEngine() (*Engine, error) {
|
||||
defaultOnce.Do(func() {
|
||||
defaultInst, defaultErr = NewEngine(NewSchema())
|
||||
})
|
||||
return defaultInst, defaultErr
|
||||
}
|
||||
|
||||
// DefaultAttachmentEngine returns the process-wide attachment filter engine.
|
||||
func DefaultAttachmentEngine() (*Engine, error) {
|
||||
defaultAttachmentOnce.Do(func() {
|
||||
defaultAttachmentInst, defaultAttachmentErr = NewEngine(NewAttachmentSchema())
|
||||
})
|
||||
return defaultAttachmentInst, defaultAttachmentErr
|
||||
}
|
||||
|
||||
func normalizeLegacyFilter(expr string) string {
|
||||
expr = rewriteNumericLogicalOperand(expr, "&&")
|
||||
expr = rewriteNumericLogicalOperand(expr, "||")
|
||||
return expr
|
||||
}
|
||||
|
||||
func rewriteNumericLogicalOperand(expr, op string) string {
|
||||
var builder strings.Builder
|
||||
n := len(expr)
|
||||
i := 0
|
||||
var inQuote rune
|
||||
|
||||
for i < n {
|
||||
ch := expr[i]
|
||||
|
||||
if inQuote != 0 {
|
||||
builder.WriteByte(ch)
|
||||
if ch == '\\' && i+1 < n {
|
||||
builder.WriteByte(expr[i+1])
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if ch == byte(inQuote) {
|
||||
inQuote = 0
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\'' || ch == '"' {
|
||||
inQuote = rune(ch)
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(expr[i:], op) {
|
||||
builder.WriteString(op)
|
||||
i += len(op)
|
||||
|
||||
// Preserve whitespace following the operator.
|
||||
wsStart := i
|
||||
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
|
||||
i++
|
||||
}
|
||||
builder.WriteString(expr[wsStart:i])
|
||||
|
||||
signStart := i
|
||||
if i < n && (expr[i] == '+' || expr[i] == '-') {
|
||||
i++
|
||||
}
|
||||
for i < n && expr[i] >= '0' && expr[i] <= '9' {
|
||||
i++
|
||||
}
|
||||
if i > signStart {
|
||||
numLiteral := expr[signStart:i]
|
||||
builder.WriteString(fmt.Sprintf("(%s != 0)", numLiteral))
|
||||
} else {
|
||||
builder.WriteString(expr[signStart:i])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
25
plugin/filter/helpers.go
Normal file
25
plugin/filter/helpers.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// AppendConditions compiles the provided filters and appends the resulting SQL fragments and args.
|
||||
func AppendConditions(ctx context.Context, engine *Engine, filters []string, dialect DialectName, where *[]string, args *[]any) error {
|
||||
for _, filterStr := range filters {
|
||||
stmt, err := engine.CompileToStatement(ctx, filterStr, RenderOptions{
|
||||
Dialect: dialect,
|
||||
PlaceholderOffset: len(*args),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stmt.SQL == "" {
|
||||
continue
|
||||
}
|
||||
*where = append(*where, fmt.Sprintf("(%s)", stmt.SQL))
|
||||
*args = append(*args, stmt.Args...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
159
plugin/filter/ir.go
Normal file
159
plugin/filter/ir.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package filter
|
||||
|
||||
// Condition represents a boolean expression derived from the CEL filter.
|
||||
type Condition interface {
|
||||
isCondition()
|
||||
}
|
||||
|
||||
// LogicalOperator enumerates the supported logical operators.
|
||||
type LogicalOperator string
|
||||
|
||||
const (
|
||||
LogicalAnd LogicalOperator = "AND"
|
||||
LogicalOr LogicalOperator = "OR"
|
||||
)
|
||||
|
||||
// LogicalCondition composes two conditions with a logical operator.
|
||||
type LogicalCondition struct {
|
||||
Operator LogicalOperator
|
||||
Left Condition
|
||||
Right Condition
|
||||
}
|
||||
|
||||
func (*LogicalCondition) isCondition() {}
|
||||
|
||||
// NotCondition negates a child condition.
|
||||
type NotCondition struct {
|
||||
Expr Condition
|
||||
}
|
||||
|
||||
func (*NotCondition) isCondition() {}
|
||||
|
||||
// FieldPredicateCondition asserts that a field evaluates to true.
|
||||
type FieldPredicateCondition struct {
|
||||
Field string
|
||||
}
|
||||
|
||||
func (*FieldPredicateCondition) isCondition() {}
|
||||
|
||||
// ComparisonOperator lists supported comparison operators.
|
||||
type ComparisonOperator string
|
||||
|
||||
const (
|
||||
CompareEq ComparisonOperator = "="
|
||||
CompareNeq ComparisonOperator = "!="
|
||||
CompareLt ComparisonOperator = "<"
|
||||
CompareLte ComparisonOperator = "<="
|
||||
CompareGt ComparisonOperator = ">"
|
||||
CompareGte ComparisonOperator = ">="
|
||||
)
|
||||
|
||||
// ComparisonCondition represents a binary comparison.
|
||||
type ComparisonCondition struct {
|
||||
Left ValueExpr
|
||||
Operator ComparisonOperator
|
||||
Right ValueExpr
|
||||
}
|
||||
|
||||
func (*ComparisonCondition) isCondition() {}
|
||||
|
||||
// InCondition represents an IN predicate with literal list values.
|
||||
type InCondition struct {
|
||||
Left ValueExpr
|
||||
Values []ValueExpr
|
||||
}
|
||||
|
||||
func (*InCondition) isCondition() {}
|
||||
|
||||
// ElementInCondition represents the CEL syntax `"value" in field`.
|
||||
type ElementInCondition struct {
|
||||
Element ValueExpr
|
||||
Field string
|
||||
}
|
||||
|
||||
func (*ElementInCondition) isCondition() {}
|
||||
|
||||
// ContainsCondition models the <field>.contains(<value>) call.
|
||||
type ContainsCondition struct {
|
||||
Field string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (*ContainsCondition) isCondition() {}
|
||||
|
||||
// ConstantCondition captures a literal boolean outcome.
|
||||
type ConstantCondition struct {
|
||||
Value bool
|
||||
}
|
||||
|
||||
func (*ConstantCondition) isCondition() {}
|
||||
|
||||
// ValueExpr models arithmetic or scalar expressions whose result feeds a comparison.
|
||||
type ValueExpr interface {
|
||||
isValueExpr()
|
||||
}
|
||||
|
||||
// FieldRef references a named schema field.
|
||||
type FieldRef struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (*FieldRef) isValueExpr() {}
|
||||
|
||||
// LiteralValue holds a literal scalar.
|
||||
type LiteralValue struct {
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (*LiteralValue) isValueExpr() {}
|
||||
|
||||
// FunctionValue captures simple function calls like size(tags).
|
||||
type FunctionValue struct {
|
||||
Name string
|
||||
Args []ValueExpr
|
||||
}
|
||||
|
||||
func (*FunctionValue) isValueExpr() {}
|
||||
|
||||
// ListComprehensionCondition represents CEL macros like exists(), all(), filter().
|
||||
type ListComprehensionCondition struct {
|
||||
Kind ComprehensionKind
|
||||
Field string // The list field to iterate over (e.g., "tags")
|
||||
IterVar string // The iteration variable name (e.g., "t")
|
||||
Predicate PredicateExpr // The predicate to evaluate for each element
|
||||
}
|
||||
|
||||
func (*ListComprehensionCondition) isCondition() {}
|
||||
|
||||
// ComprehensionKind enumerates the types of list comprehensions.
|
||||
type ComprehensionKind string
|
||||
|
||||
const (
|
||||
ComprehensionExists ComprehensionKind = "exists"
|
||||
)
|
||||
|
||||
// PredicateExpr represents predicates used in comprehensions.
|
||||
type PredicateExpr interface {
|
||||
isPredicateExpr()
|
||||
}
|
||||
|
||||
// StartsWithPredicate represents t.startsWith("prefix").
|
||||
type StartsWithPredicate struct {
|
||||
Prefix string
|
||||
}
|
||||
|
||||
func (*StartsWithPredicate) isPredicateExpr() {}
|
||||
|
||||
// EndsWithPredicate represents t.endsWith("suffix").
|
||||
type EndsWithPredicate struct {
|
||||
Suffix string
|
||||
}
|
||||
|
||||
func (*EndsWithPredicate) isPredicateExpr() {}
|
||||
|
||||
// ContainsPredicate represents t.contains("substring").
|
||||
type ContainsPredicate struct {
|
||||
Substring string
|
||||
}
|
||||
|
||||
func (*ContainsPredicate) isPredicateExpr() {}
|
||||
586
plugin/filter/parser.go
Normal file
586
plugin/filter/parser.go
Normal file
@@ -0,0 +1,586 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
|
||||
switch v := expr.ExprKind.(type) {
|
||||
case *exprv1.Expr_CallExpr:
|
||||
return buildCallCondition(v.CallExpr, schema)
|
||||
case *exprv1.Expr_ConstExpr:
|
||||
val, err := getConstValue(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
return &ConstantCondition{Value: v}, nil
|
||||
case int64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
case float64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
default:
|
||||
return nil, errors.New("filter must evaluate to a boolean value")
|
||||
}
|
||||
case *exprv1.Expr_IdentExpr:
|
||||
name := v.IdentExpr.GetName()
|
||||
field, ok := schema.Field(name)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", name)
|
||||
}
|
||||
if field.Type != FieldTypeBool {
|
||||
return nil, errors.Errorf("identifier %q is not boolean", name)
|
||||
}
|
||||
return &FieldPredicateCondition{Field: name}, nil
|
||||
case *exprv1.Expr_ComprehensionExpr:
|
||||
return buildComprehensionCondition(v.ComprehensionExpr, schema)
|
||||
default:
|
||||
return nil, errors.New("unsupported top-level expression")
|
||||
}
|
||||
}
|
||||
|
||||
func buildCallCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
switch call.Function {
|
||||
case "_&&_":
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("logical AND expects two arguments")
|
||||
}
|
||||
left, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := buildCondition(call.Args[1], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LogicalCondition{
|
||||
Operator: LogicalAnd,
|
||||
Left: left,
|
||||
Right: right,
|
||||
}, nil
|
||||
case "_||_":
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("logical OR expects two arguments")
|
||||
}
|
||||
left, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := buildCondition(call.Args[1], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LogicalCondition{
|
||||
Operator: LogicalOr,
|
||||
Left: left,
|
||||
Right: right,
|
||||
}, nil
|
||||
case "!_":
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("logical NOT expects one argument")
|
||||
}
|
||||
child, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &NotCondition{Expr: child}, nil
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
return buildComparisonCondition(call, schema)
|
||||
case "@in":
|
||||
return buildInCondition(call, schema)
|
||||
case "contains":
|
||||
return buildContainsCondition(call, schema)
|
||||
default:
|
||||
val, ok, err := evaluateBool(call)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return &ConstantCondition{Value: val}, nil
|
||||
}
|
||||
return nil, errors.Errorf("unsupported call expression %q", call.Function)
|
||||
}
|
||||
}
|
||||
|
||||
func buildComparisonCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("comparison expects two arguments")
|
||||
}
|
||||
op, err := toComparisonOperator(call.Function)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
left, err := buildValueExpr(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := buildValueExpr(call.Args[1], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the left side is a field, validate allowed operators.
|
||||
if field, ok := left.(*FieldRef); ok {
|
||||
def, exists := schema.Field(field.Name)
|
||||
if !exists {
|
||||
return nil, errors.Errorf("unknown identifier %q", field.Name)
|
||||
}
|
||||
if def.Kind == FieldKindVirtualAlias {
|
||||
def, exists = schema.ResolveAlias(field.Name)
|
||||
if !exists {
|
||||
return nil, errors.Errorf("invalid alias %q", field.Name)
|
||||
}
|
||||
}
|
||||
if def.AllowedComparisonOps != nil {
|
||||
if _, allowed := def.AllowedComparisonOps[op]; !allowed {
|
||||
return nil, errors.Errorf("operator %s not allowed for field %q", op, field.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ComparisonCondition{
|
||||
Left: left,
|
||||
Operator: op,
|
||||
Right: right,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildInCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("in operator expects two arguments")
|
||||
}
|
||||
|
||||
// Handle identifier in list syntax.
|
||||
if identName, err := getIdentName(call.Args[0]); err == nil {
|
||||
if field, ok := schema.Field(identName); ok && field.Kind == FieldKindVirtualAlias {
|
||||
if _, aliasOk := schema.ResolveAlias(identName); !aliasOk {
|
||||
return nil, errors.Errorf("invalid alias %q", identName)
|
||||
}
|
||||
} else if !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", identName)
|
||||
}
|
||||
|
||||
if listExpr := call.Args[1].GetListExpr(); listExpr != nil {
|
||||
values := make([]ValueExpr, 0, len(listExpr.Elements))
|
||||
for _, element := range listExpr.Elements {
|
||||
value, err := buildValueExpr(element, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
return &InCondition{
|
||||
Left: &FieldRef{Name: identName},
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle "value in identifier" syntax.
|
||||
if identName, err := getIdentName(call.Args[1]); err == nil {
|
||||
if _, ok := schema.Field(identName); !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", identName)
|
||||
}
|
||||
element, err := buildValueExpr(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ElementInCondition{
|
||||
Element: element,
|
||||
Field: identName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid use of in operator")
|
||||
}
|
||||
|
||||
func buildContainsCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
if call.Target == nil {
|
||||
return nil, errors.New("contains requires a target")
|
||||
}
|
||||
targetName, err := getIdentName(call.Target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
field, ok := schema.Field(targetName)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", targetName)
|
||||
}
|
||||
if !field.SupportsContains {
|
||||
return nil, errors.Errorf("identifier %q does not support contains()", targetName)
|
||||
}
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("contains expects exactly one argument")
|
||||
}
|
||||
value, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "contains only supports literal arguments")
|
||||
}
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("contains argument must be a string")
|
||||
}
|
||||
return &ContainsCondition{
|
||||
Field: targetName,
|
||||
Value: str,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildValueExpr(expr *exprv1.Expr, schema Schema) (ValueExpr, error) {
|
||||
if identName, err := getIdentName(expr); err == nil {
|
||||
if _, ok := schema.Field(identName); !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", identName)
|
||||
}
|
||||
return &FieldRef{Name: identName}, nil
|
||||
}
|
||||
|
||||
if literal, err := getConstValue(expr); err == nil {
|
||||
return &LiteralValue{Value: literal}, nil
|
||||
}
|
||||
|
||||
if value, ok, err := evaluateNumeric(expr); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
return &LiteralValue{Value: value}, nil
|
||||
}
|
||||
|
||||
if boolVal, ok, err := evaluateBoolExpr(expr); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
return &LiteralValue{Value: boolVal}, nil
|
||||
}
|
||||
|
||||
if call := expr.GetCallExpr(); call != nil {
|
||||
switch call.Function {
|
||||
case "size":
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("size() expects one argument")
|
||||
}
|
||||
arg, err := buildValueExpr(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FunctionValue{
|
||||
Name: "size",
|
||||
Args: []ValueExpr{arg},
|
||||
}, nil
|
||||
case "now":
|
||||
return &LiteralValue{Value: timeNowUnix()}, nil
|
||||
case "_+_", "_-_", "_*_":
|
||||
value, ok, err := evaluateNumeric(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return &LiteralValue{Value: value}, nil
|
||||
}
|
||||
default:
|
||||
// Fall through to error return below
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported value expression")
|
||||
}
|
||||
|
||||
func toComparisonOperator(fn string) (ComparisonOperator, error) {
|
||||
switch fn {
|
||||
case "_==_":
|
||||
return CompareEq, nil
|
||||
case "_!=_":
|
||||
return CompareNeq, nil
|
||||
case "_<_":
|
||||
return CompareLt, nil
|
||||
case "_>_":
|
||||
return CompareGt, nil
|
||||
case "_<=_":
|
||||
return CompareLte, nil
|
||||
case "_>=_":
|
||||
return CompareGte, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported comparison operator %q", fn)
|
||||
}
|
||||
}
|
||||
|
||||
func getIdentName(expr *exprv1.Expr) (string, error) {
|
||||
if ident := expr.GetIdentExpr(); ident != nil {
|
||||
return ident.GetName(), nil
|
||||
}
|
||||
return "", errors.New("expression is not an identifier")
|
||||
}
|
||||
|
||||
func getConstValue(expr *exprv1.Expr) (interface{}, error) {
|
||||
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("expression is not a literal")
|
||||
}
|
||||
switch x := v.ConstExpr.ConstantKind.(type) {
|
||||
case *exprv1.Constant_StringValue:
|
||||
return v.ConstExpr.GetStringValue(), nil
|
||||
case *exprv1.Constant_Int64Value:
|
||||
return v.ConstExpr.GetInt64Value(), nil
|
||||
case *exprv1.Constant_Uint64Value:
|
||||
return int64(v.ConstExpr.GetUint64Value()), nil
|
||||
case *exprv1.Constant_DoubleValue:
|
||||
return v.ConstExpr.GetDoubleValue(), nil
|
||||
case *exprv1.Constant_BoolValue:
|
||||
return v.ConstExpr.GetBoolValue(), nil
|
||||
case *exprv1.Constant_NullValue:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported constant %T", x)
|
||||
}
|
||||
}
|
||||
|
||||
func evaluateBool(call *exprv1.Expr_Call) (bool, bool, error) {
|
||||
val, ok, err := evaluateBoolExpr(&exprv1.Expr{ExprKind: &exprv1.Expr_CallExpr{CallExpr: call}})
|
||||
return val, ok, err
|
||||
}
|
||||
|
||||
func evaluateBoolExpr(expr *exprv1.Expr) (bool, bool, error) {
|
||||
if literal, err := getConstValue(expr); err == nil {
|
||||
if b, ok := literal.(bool); ok {
|
||||
return b, true, nil
|
||||
}
|
||||
return false, false, nil
|
||||
}
|
||||
if call := expr.GetCallExpr(); call != nil && call.Function == "!_" {
|
||||
if len(call.Args) != 1 {
|
||||
return false, false, errors.New("NOT expects exactly one argument")
|
||||
}
|
||||
val, ok, err := evaluateBoolExpr(call.Args[0])
|
||||
if err != nil || !ok {
|
||||
return false, false, err
|
||||
}
|
||||
return !val, true, nil
|
||||
}
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
|
||||
if literal, err := getConstValue(expr); err == nil {
|
||||
switch v := literal.(type) {
|
||||
case int64:
|
||||
return v, true, nil
|
||||
case float64:
|
||||
return int64(v), true, nil
|
||||
}
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
call := expr.GetCallExpr()
|
||||
if call == nil {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
switch call.Function {
|
||||
case "now":
|
||||
return timeNowUnix(), true, nil
|
||||
case "_+_", "_-_", "_*_":
|
||||
if len(call.Args) != 2 {
|
||||
return 0, false, errors.New("arithmetic requires two arguments")
|
||||
}
|
||||
left, ok, err := evaluateNumeric(call.Args[0])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
right, ok, err := evaluateNumeric(call.Args[1])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
switch call.Function {
|
||||
case "_+_":
|
||||
return left + right, true, nil
|
||||
case "_-_":
|
||||
return left - right, true, nil
|
||||
case "_*_":
|
||||
return left * right, true, nil
|
||||
default:
|
||||
return 0, false, errors.Errorf("unsupported arithmetic operator %q", call.Function)
|
||||
}
|
||||
default:
|
||||
return 0, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func timeNowUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
// buildComprehensionCondition handles CEL comprehension expressions (exists, all, etc.).
|
||||
func buildComprehensionCondition(comp *exprv1.Expr_Comprehension, schema Schema) (Condition, error) {
|
||||
// Determine the comprehension kind by examining the loop initialization and step
|
||||
kind, err := detectComprehensionKind(comp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the field being iterated over
|
||||
iterRangeIdent := comp.IterRange.GetIdentExpr()
|
||||
if iterRangeIdent == nil {
|
||||
return nil, errors.New("comprehension range must be a field identifier")
|
||||
}
|
||||
fieldName := iterRangeIdent.GetName()
|
||||
|
||||
// Validate the field
|
||||
field, ok := schema.Field(fieldName)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown field %q in comprehension", fieldName)
|
||||
}
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return nil, errors.Errorf("field %q does not support comprehension (must be a list)", fieldName)
|
||||
}
|
||||
|
||||
// Extract the predicate from the loop step
|
||||
predicate, err := extractPredicate(comp, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ListComprehensionCondition{
|
||||
Kind: kind,
|
||||
Field: fieldName,
|
||||
IterVar: comp.IterVar,
|
||||
Predicate: predicate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// detectComprehensionKind determines if this is an exists() macro.
|
||||
// Only exists() is currently supported.
|
||||
func detectComprehensionKind(comp *exprv1.Expr_Comprehension) (ComprehensionKind, error) {
|
||||
// Check the accumulator initialization
|
||||
accuInit := comp.AccuInit.GetConstExpr()
|
||||
if accuInit == nil {
|
||||
return "", errors.New("comprehension accumulator must be initialized with a constant")
|
||||
}
|
||||
|
||||
// exists() starts with false and uses OR (||) in loop step
|
||||
if !accuInit.GetBoolValue() {
|
||||
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_||_" {
|
||||
return ComprehensionExists, nil
|
||||
}
|
||||
}
|
||||
|
||||
// all() starts with true and uses AND (&&) - not supported
|
||||
if accuInit.GetBoolValue() {
|
||||
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_&&_" {
|
||||
return "", errors.New("all() comprehension is not supported; use exists() instead")
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("unsupported comprehension type; only exists() is supported")
|
||||
}
|
||||
|
||||
// extractPredicate extracts the predicate expression from the comprehension loop step.
|
||||
func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, error) {
|
||||
// The loop step is: @result || predicate(t) for exists
|
||||
// or: @result && predicate(t) for all
|
||||
step := comp.LoopStep.GetCallExpr()
|
||||
if step == nil {
|
||||
return nil, errors.New("comprehension loop step must be a call expression")
|
||||
}
|
||||
|
||||
if len(step.Args) != 2 {
|
||||
return nil, errors.New("comprehension loop step must have two arguments")
|
||||
}
|
||||
|
||||
// The predicate is the second argument
|
||||
predicateExpr := step.Args[1]
|
||||
predicateCall := predicateExpr.GetCallExpr()
|
||||
if predicateCall == nil {
|
||||
return nil, errors.New("comprehension predicate must be a function call")
|
||||
}
|
||||
|
||||
// Handle different predicate functions
|
||||
switch predicateCall.Function {
|
||||
case "startsWith":
|
||||
return buildStartsWithPredicate(predicateCall, comp.IterVar)
|
||||
case "endsWith":
|
||||
return buildEndsWithPredicate(predicateCall, comp.IterVar)
|
||||
case "contains":
|
||||
return buildContainsPredicate(predicateCall, comp.IterVar)
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
|
||||
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
// Verify the target is the iteration variable
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("startsWith target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("startsWith expects exactly one argument")
|
||||
}
|
||||
|
||||
prefix, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "startsWith argument must be a constant string")
|
||||
}
|
||||
|
||||
prefixStr, ok := prefix.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("startsWith argument must be a string")
|
||||
}
|
||||
|
||||
return &StartsWithPredicate{Prefix: prefixStr}, nil
|
||||
}
|
||||
|
||||
// buildEndsWithPredicate extracts the pattern from t.endsWith("suffix").
|
||||
func buildEndsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("endsWith target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("endsWith expects exactly one argument")
|
||||
}
|
||||
|
||||
suffix, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "endsWith argument must be a constant string")
|
||||
}
|
||||
|
||||
suffixStr, ok := suffix.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("endsWith argument must be a string")
|
||||
}
|
||||
|
||||
return &EndsWithPredicate{Suffix: suffixStr}, nil
|
||||
}
|
||||
|
||||
// buildContainsPredicate extracts the pattern from t.contains("substring").
|
||||
func buildContainsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("contains target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("contains expects exactly one argument")
|
||||
}
|
||||
|
||||
substring, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "contains argument must be a constant string")
|
||||
}
|
||||
|
||||
substringStr, ok := substring.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("contains argument must be a string")
|
||||
}
|
||||
|
||||
return &ContainsPredicate{Substring: substringStr}, nil
|
||||
}
|
||||
748
plugin/filter/render.go
Normal file
748
plugin/filter/render.go
Normal file
@@ -0,0 +1,748 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type renderer struct {
|
||||
schema Schema
|
||||
dialect DialectName
|
||||
placeholderOffset int
|
||||
placeholderCounter int
|
||||
args []any
|
||||
}
|
||||
|
||||
type renderResult struct {
|
||||
sql string
|
||||
trivial bool
|
||||
unsatisfiable bool
|
||||
}
|
||||
|
||||
func newRenderer(schema Schema, opts RenderOptions) *renderer {
|
||||
return &renderer{
|
||||
schema: schema,
|
||||
dialect: opts.Dialect,
|
||||
placeholderOffset: opts.PlaceholderOffset,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) Render(cond Condition) (Statement, error) {
|
||||
result, err := r.renderCondition(cond)
|
||||
if err != nil {
|
||||
return Statement{}, err
|
||||
}
|
||||
args := r.args
|
||||
if args == nil {
|
||||
args = []any{}
|
||||
}
|
||||
|
||||
switch {
|
||||
case result.unsatisfiable:
|
||||
return Statement{
|
||||
SQL: "1 = 0",
|
||||
Args: args,
|
||||
}, nil
|
||||
case result.trivial:
|
||||
return Statement{
|
||||
SQL: "",
|
||||
Args: args,
|
||||
}, nil
|
||||
default:
|
||||
return Statement{
|
||||
SQL: result.sql,
|
||||
Args: args,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
|
||||
switch c := cond.(type) {
|
||||
case *LogicalCondition:
|
||||
return r.renderLogicalCondition(c)
|
||||
case *NotCondition:
|
||||
return r.renderNotCondition(c)
|
||||
case *FieldPredicateCondition:
|
||||
return r.renderFieldPredicate(c)
|
||||
case *ComparisonCondition:
|
||||
return r.renderComparison(c)
|
||||
case *InCondition:
|
||||
return r.renderInCondition(c)
|
||||
case *ElementInCondition:
|
||||
return r.renderElementInCondition(c)
|
||||
case *ContainsCondition:
|
||||
return r.renderContainsCondition(c)
|
||||
case *ListComprehensionCondition:
|
||||
return r.renderListComprehension(c)
|
||||
case *ConstantCondition:
|
||||
if c.Value {
|
||||
return renderResult{trivial: true}, nil
|
||||
}
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported condition type %T", c)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) {
|
||||
left, err := r.renderCondition(cond.Left)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
right, err := r.renderCondition(cond.Right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
switch cond.Operator {
|
||||
case LogicalAnd:
|
||||
return combineAnd(left, right), nil
|
||||
case LogicalOr:
|
||||
return combineOr(left, right), nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) {
|
||||
child, err := r.renderCondition(cond.Expr)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
if child.trivial {
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
|
||||
}
|
||||
if child.unsatisfiable {
|
||||
return renderResult{trivial: true}, nil
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("NOT (%s)", child.sql),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
|
||||
switch field.Kind {
|
||||
case FieldKindBoolColumn:
|
||||
column := qualifyColumn(r.dialect, field.Column)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s IS TRUE", column),
|
||||
}, nil
|
||||
case FieldKindJSONBool:
|
||||
sql, err := r.jsonBoolPredicate(field)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) {
|
||||
switch left := cond.Left.(type) {
|
||||
case *FieldRef:
|
||||
field, ok := r.schema.Field(left.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", left.Name)
|
||||
}
|
||||
switch field.Kind {
|
||||
case FieldKindBoolColumn:
|
||||
return r.renderBoolColumnComparison(field, cond.Operator, cond.Right)
|
||||
case FieldKindJSONBool:
|
||||
return r.renderJSONBoolComparison(field, cond.Operator, cond.Right)
|
||||
case FieldKindScalar:
|
||||
return r.renderScalarComparison(field, cond.Operator, cond.Right)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name)
|
||||
}
|
||||
case *FunctionValue:
|
||||
return r.renderFunctionComparison(left, cond.Operator, cond.Right)
|
||||
default:
|
||||
return renderResult{}, errors.New("comparison must start with a field reference or supported function")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
if fn.Name != "size" {
|
||||
return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name)
|
||||
}
|
||||
if len(fn.Args) != 1 {
|
||||
return renderResult{}, errors.New("size() expects one argument")
|
||||
}
|
||||
fieldArg, ok := fn.Args[0].(*FieldRef)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("size() argument must be a field")
|
||||
}
|
||||
|
||||
field, ok := r.schema.Field(fieldArg.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", fieldArg.Name)
|
||||
}
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name)
|
||||
}
|
||||
|
||||
value, err := expectNumericLiteral(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
expr := jsonArrayLengthExpr(r.dialect, field)
|
||||
placeholder := r.addArg(value)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", expr, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
lit, err := expectLiteral(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
columnExpr := field.columnExpr(r.dialect)
|
||||
if lit == nil {
|
||||
switch op {
|
||||
case CompareEq:
|
||||
return renderResult{sql: fmt.Sprintf("%s IS NULL", columnExpr)}, nil
|
||||
case CompareNeq:
|
||||
return renderResult{sql: fmt.Sprintf("%s IS NOT NULL", columnExpr)}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("operator %s not supported for null comparison", op)
|
||||
}
|
||||
}
|
||||
|
||||
placeholder := ""
|
||||
switch field.Type {
|
||||
case FieldTypeString:
|
||||
value, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("field %q expects string value", field.Name)
|
||||
}
|
||||
placeholder = r.addArg(value)
|
||||
case FieldTypeInt, FieldTypeTimestamp:
|
||||
num, err := toInt64(lit)
|
||||
if err != nil {
|
||||
return renderResult{}, errors.Wrapf(err, "field %q expects integer value", field.Name)
|
||||
}
|
||||
placeholder = r.addArg(num)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name)
|
||||
}
|
||||
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", columnExpr, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
value, err := expectBool(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
placeholder := r.addBoolArg(value)
|
||||
column := qualifyColumn(r.dialect, field.Column)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", column, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
value, err := expectBool(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
jsonExpr := jsonExtractExpr(r.dialect, field)
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
switch op {
|
||||
case CompareEq:
|
||||
if field.Name == "has_task_list" {
|
||||
target := "0"
|
||||
if value {
|
||||
target = "1"
|
||||
}
|
||||
return renderResult{sql: fmt.Sprintf("%s = %s", jsonExpr, target)}, nil
|
||||
}
|
||||
if value {
|
||||
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
|
||||
}
|
||||
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
|
||||
case CompareNeq:
|
||||
if field.Name == "has_task_list" {
|
||||
target := "0"
|
||||
if value {
|
||||
target = "1"
|
||||
}
|
||||
return renderResult{sql: fmt.Sprintf("%s != %s", jsonExpr, target)}, nil
|
||||
}
|
||||
if value {
|
||||
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
|
||||
}
|
||||
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op)
|
||||
}
|
||||
case DialectMySQL:
|
||||
boolStr := "false"
|
||||
if value {
|
||||
boolStr = "true"
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", jsonExpr, sqlOperator(op), boolStr),
|
||||
}, nil
|
||||
case DialectPostgres:
|
||||
placeholder := r.addArg(value)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s)::boolean %s %s", jsonExpr, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) {
|
||||
fieldRef, ok := cond.Left.(*FieldRef)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("IN operator requires a field on the left-hand side")
|
||||
}
|
||||
|
||||
if fieldRef.Name == "tag" {
|
||||
return r.renderTagInList(cond.Values)
|
||||
}
|
||||
|
||||
field, ok := r.schema.Field(fieldRef.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", fieldRef.Name)
|
||||
}
|
||||
|
||||
if field.Kind != FieldKindScalar {
|
||||
return renderResult{}, errors.Errorf("field %q does not support IN()", fieldRef.Name)
|
||||
}
|
||||
|
||||
return r.renderScalarInCondition(field, cond.Values)
|
||||
}
|
||||
|
||||
func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) {
|
||||
field, ok := r.schema.ResolveAlias("tag")
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tag attribute is not configured")
|
||||
}
|
||||
|
||||
conditions := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
lit, err := expectLiteral(v)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tags must be compared with string literals")
|
||||
}
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
// Support hierarchical tags: match exact tag OR tags with this prefix (e.g., "book" matches "book" and "book/something")
|
||||
exactMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
|
||||
prefixMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
|
||||
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
conditions = append(conditions, expr)
|
||||
case DialectMySQL:
|
||||
// Support hierarchical tags: match exact tag OR tags with this prefix
|
||||
exactMatch := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
prefixMatch := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
|
||||
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
conditions = append(conditions, expr)
|
||||
case DialectPostgres:
|
||||
// Support hierarchical tags: match exact tag OR tags with this prefix
|
||||
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s/%%`, str)))
|
||||
expr := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
conditions = append(conditions, expr)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 1 {
|
||||
return renderResult{sql: conditions[0]}, nil
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field)
|
||||
}
|
||||
|
||||
lit, err := expectLiteral(cond.Element)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tags membership requires string literal")
|
||||
}
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
|
||||
return renderResult{sql: sql}, nil
|
||||
case DialectMySQL:
|
||||
sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
return renderResult{sql: sql}, nil
|
||||
case DialectPostgres:
|
||||
sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) {
|
||||
placeholders := make([]string, 0, len(values))
|
||||
|
||||
for _, v := range values {
|
||||
lit, err := expectLiteral(v)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
switch field.Type {
|
||||
case FieldTypeString:
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("field %q expects string values", field.Name)
|
||||
}
|
||||
placeholders = append(placeholders, r.addArg(str))
|
||||
case FieldTypeInt:
|
||||
num, err := toInt64(lit)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
placeholders = append(placeholders, r.addArg(num))
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
column := field.columnExpr(r.dialect)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
column := field.columnExpr(r.dialect)
|
||||
arg := fmt.Sprintf("%%%s%%", cond.Value)
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
// Use custom Unicode-aware case folding function for case-insensitive comparison.
|
||||
// This overcomes SQLite's ASCII-only LOWER() limitation.
|
||||
sql := fmt.Sprintf("memos_unicode_lower(%s) LIKE memos_unicode_lower(%s)", column, r.addArg(arg))
|
||||
return renderResult{sql: sql}, nil
|
||||
case DialectPostgres:
|
||||
sql := fmt.Sprintf("%s ILIKE %s", column, r.addArg(arg))
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
sql := fmt.Sprintf("%s LIKE %s", column, r.addArg(arg))
|
||||
return renderResult{sql: sql}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("field %q is not a JSON list", cond.Field)
|
||||
}
|
||||
|
||||
// Render based on predicate type
|
||||
switch pred := cond.Predicate.(type) {
|
||||
case *StartsWithPredicate:
|
||||
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
|
||||
case *EndsWithPredicate:
|
||||
return r.renderTagEndsWith(field, pred.Suffix, cond.Kind)
|
||||
case *ContainsPredicate:
|
||||
return r.renderTagContains(field, pred.Substring, cond.Kind)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported predicate type %T in comprehension", pred)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
|
||||
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
// Match exact tag or tags with this prefix (hierarchical support)
|
||||
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, prefix))
|
||||
prefixMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s%%`, prefix))
|
||||
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
|
||||
|
||||
case DialectPostgres:
|
||||
// Use PostgreSQL's powerful JSON operators
|
||||
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, prefix)))
|
||||
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(fmt.Sprintf(`%%"%s%%`, prefix)))
|
||||
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
|
||||
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagEndsWith generates SQL for tags.exists(t, t.endsWith("suffix")).
|
||||
func (r *renderer) renderTagEndsWith(field Field, suffix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
pattern := fmt.Sprintf(`%%%s"%%`, suffix)
|
||||
|
||||
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
|
||||
}
|
||||
|
||||
// renderTagContains generates SQL for tags.exists(t, t.contains("substring")).
|
||||
func (r *renderer) renderTagContains(field Field, substring string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
pattern := fmt.Sprintf(`%%%s%%`, substring)
|
||||
|
||||
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
|
||||
}
|
||||
|
||||
// buildJSONArrayLike builds a LIKE expression for matching within a JSON array.
|
||||
// Returns the LIKE clause without NULL/empty checks.
|
||||
func (r *renderer) buildJSONArrayLike(arrayExpr, pattern string) string {
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("%s LIKE %s", arrayExpr, r.addArg(pattern))
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(pattern))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapWithNullCheck wraps a condition with NULL and empty array checks.
|
||||
// This ensures we don't match against NULL or empty JSON arrays.
|
||||
func (r *renderer) wrapWithNullCheck(arrayExpr, condition string) string {
|
||||
var nullCheck string
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND %s != '[]'", arrayExpr, arrayExpr)
|
||||
case DialectMySQL:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND JSON_LENGTH(%s) > 0", arrayExpr, arrayExpr)
|
||||
case DialectPostgres:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND jsonb_array_length(%s) > 0", arrayExpr, arrayExpr)
|
||||
default:
|
||||
return condition
|
||||
}
|
||||
return fmt.Sprintf("(%s AND %s)", condition, nullCheck)
|
||||
}
|
||||
|
||||
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
|
||||
expr := jsonExtractExpr(r.dialect, field)
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
return fmt.Sprintf("%s IS TRUE", expr), nil
|
||||
case DialectMySQL:
|
||||
return fmt.Sprintf("COALESCE(%s, CAST('false' AS JSON)) = CAST('true' AS JSON)", expr), nil
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func combineAnd(left, right renderResult) renderResult {
|
||||
if left.unsatisfiable || right.unsatisfiable {
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}
|
||||
}
|
||||
if left.trivial {
|
||||
return right
|
||||
}
|
||||
if right.trivial {
|
||||
return left
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s AND %s)", left.sql, right.sql),
|
||||
}
|
||||
}
|
||||
|
||||
func combineOr(left, right renderResult) renderResult {
|
||||
if left.trivial || right.trivial {
|
||||
return renderResult{trivial: true}
|
||||
}
|
||||
if left.unsatisfiable {
|
||||
return right
|
||||
}
|
||||
if right.unsatisfiable {
|
||||
return left
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s OR %s)", left.sql, right.sql),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) addArg(value any) string {
|
||||
r.placeholderCounter++
|
||||
r.args = append(r.args, value)
|
||||
if r.dialect == DialectPostgres {
|
||||
return fmt.Sprintf("$%d", r.placeholderOffset+r.placeholderCounter)
|
||||
}
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (r *renderer) addBoolArg(value bool) string {
|
||||
var v any
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
if value {
|
||||
v = 1
|
||||
} else {
|
||||
v = 0
|
||||
}
|
||||
default:
|
||||
v = value
|
||||
}
|
||||
return r.addArg(v)
|
||||
}
|
||||
|
||||
func expectLiteral(expr ValueExpr) (any, error) {
|
||||
lit, ok := expr.(*LiteralValue)
|
||||
if !ok {
|
||||
return nil, errors.New("expression must be a literal")
|
||||
}
|
||||
return lit.Value, nil
|
||||
}
|
||||
|
||||
func expectBool(expr ValueExpr) (bool, error) {
|
||||
lit, err := expectLiteral(expr)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
value, ok := lit.(bool)
|
||||
if !ok {
|
||||
return false, errors.New("boolean literal required")
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func expectNumericLiteral(expr ValueExpr) (int64, error) {
|
||||
lit, err := expectLiteral(expr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return toInt64(lit)
|
||||
}
|
||||
|
||||
func toInt64(value any) (int64, error) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return int64(v), nil
|
||||
case int32:
|
||||
return int64(v), nil
|
||||
case int64:
|
||||
return v, nil
|
||||
case uint32:
|
||||
return int64(v), nil
|
||||
case uint64:
|
||||
return int64(v), nil
|
||||
case float32:
|
||||
return int64(v), nil
|
||||
case float64:
|
||||
return int64(v), nil
|
||||
default:
|
||||
return 0, errors.Errorf("cannot convert %T to int64", value)
|
||||
}
|
||||
}
|
||||
|
||||
func sqlOperator(op ComparisonOperator) string {
|
||||
return string(op)
|
||||
}
|
||||
|
||||
func qualifyColumn(d DialectName, col Column) string {
|
||||
switch d {
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("%s.%s", col.Table, col.Name)
|
||||
default:
|
||||
return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func jsonPath(field Field) string {
|
||||
return "$." + strings.Join(field.JSONPath, ".")
|
||||
}
|
||||
|
||||
func jsonExtractExpr(d DialectName, field Field) string {
|
||||
column := qualifyColumn(d, field.Column)
|
||||
switch d {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
|
||||
case DialectPostgres:
|
||||
return buildPostgresJSONAccessor(column, field.JSONPath, true)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func jsonArrayExpr(d DialectName, field Field) string {
|
||||
column := qualifyColumn(d, field.Column)
|
||||
switch d {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
|
||||
case DialectPostgres:
|
||||
return buildPostgresJSONAccessor(column, field.JSONPath, false)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func jsonArrayLengthExpr(d DialectName, field Field) string {
|
||||
arrayExpr := jsonArrayExpr(d, field)
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
|
||||
case DialectMySQL:
|
||||
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s, '[]'::jsonb))", arrayExpr)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string {
|
||||
expr := base
|
||||
for idx, part := range path {
|
||||
if idx == len(path)-1 && terminalText {
|
||||
expr = fmt.Sprintf("%s->>'%s'", expr, part)
|
||||
} else {
|
||||
expr = fmt.Sprintf("%s->'%s'", expr, part)
|
||||
}
|
||||
}
|
||||
return expr
|
||||
}
|
||||
319
plugin/filter/schema.go
Normal file
319
plugin/filter/schema.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
)
|
||||
|
||||
// DialectName enumerates supported SQL dialects.
|
||||
type DialectName string
|
||||
|
||||
const (
|
||||
DialectSQLite DialectName = "sqlite"
|
||||
DialectMySQL DialectName = "mysql"
|
||||
DialectPostgres DialectName = "postgres"
|
||||
)
|
||||
|
||||
// FieldType represents the logical type of a field.
|
||||
type FieldType string
|
||||
|
||||
const (
|
||||
FieldTypeString FieldType = "string"
|
||||
FieldTypeInt FieldType = "int"
|
||||
FieldTypeBool FieldType = "bool"
|
||||
FieldTypeTimestamp FieldType = "timestamp"
|
||||
)
|
||||
|
||||
// FieldKind describes how a field is stored.
|
||||
type FieldKind string
|
||||
|
||||
const (
|
||||
FieldKindScalar FieldKind = "scalar"
|
||||
FieldKindBoolColumn FieldKind = "bool_column"
|
||||
FieldKindJSONBool FieldKind = "json_bool"
|
||||
FieldKindJSONList FieldKind = "json_list"
|
||||
FieldKindVirtualAlias FieldKind = "virtual_alias"
|
||||
)
|
||||
|
||||
// Column identifies the backing table column.
|
||||
type Column struct {
|
||||
Table string
|
||||
Name string
|
||||
}
|
||||
|
||||
// Field captures the schema metadata for an exposed CEL identifier.
|
||||
type Field struct {
|
||||
Name string
|
||||
Kind FieldKind
|
||||
Type FieldType
|
||||
Column Column
|
||||
JSONPath []string
|
||||
AliasFor string
|
||||
SupportsContains bool
|
||||
Expressions map[DialectName]string
|
||||
AllowedComparisonOps map[ComparisonOperator]bool
|
||||
}
|
||||
|
||||
// Schema collects CEL environment options and field metadata.
|
||||
type Schema struct {
|
||||
Name string
|
||||
Fields map[string]Field
|
||||
EnvOptions []cel.EnvOption
|
||||
}
|
||||
|
||||
// Field returns the field metadata if present.
|
||||
func (s Schema) Field(name string) (Field, bool) {
|
||||
f, ok := s.Fields[name]
|
||||
return f, ok
|
||||
}
|
||||
|
||||
// ResolveAlias resolves a virtual alias to its target field.
|
||||
func (s Schema) ResolveAlias(name string) (Field, bool) {
|
||||
field, ok := s.Fields[name]
|
||||
if !ok {
|
||||
return Field{}, false
|
||||
}
|
||||
if field.Kind == FieldKindVirtualAlias {
|
||||
target, ok := s.Fields[field.AliasFor]
|
||||
if !ok {
|
||||
return Field{}, false
|
||||
}
|
||||
return target, true
|
||||
}
|
||||
return field, true
|
||||
}
|
||||
|
||||
var nowFunction = cel.Function("now",
|
||||
cel.Overload("now",
|
||||
[]*cel.Type{},
|
||||
cel.IntType,
|
||||
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
|
||||
return types.Int(time.Now().Unix())
|
||||
}),
|
||||
),
|
||||
)
|
||||
|
||||
// NewSchema constructs the memo filter schema and CEL environment.
|
||||
func NewSchema() Schema {
|
||||
fields := map[string]Field{
|
||||
"content": {
|
||||
Name: "content",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "memo", Name: "content"},
|
||||
SupportsContains: true,
|
||||
Expressions: map[DialectName]string{},
|
||||
},
|
||||
"creator_id": {
|
||||
Name: "creator_id",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeInt,
|
||||
Column: Column{Table: "memo", Name: "creator_id"},
|
||||
Expressions: map[DialectName]string{},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"created_ts": {
|
||||
Name: "created_ts",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeTimestamp,
|
||||
Column: Column{Table: "memo", Name: "created_ts"},
|
||||
Expressions: map[DialectName]string{
|
||||
// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
|
||||
DialectMySQL: "UNIX_TIMESTAMP(%s)",
|
||||
// PostgreSQL and SQLite store created_ts as BIGINT (epoch), no conversion needed
|
||||
DialectPostgres: "%s",
|
||||
DialectSQLite: "%s",
|
||||
},
|
||||
},
|
||||
"updated_ts": {
|
||||
Name: "updated_ts",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeTimestamp,
|
||||
Column: Column{Table: "memo", Name: "updated_ts"},
|
||||
Expressions: map[DialectName]string{
|
||||
// MySQL stores updated_ts as TIMESTAMP, needs conversion to epoch
|
||||
DialectMySQL: "UNIX_TIMESTAMP(%s)",
|
||||
// PostgreSQL and SQLite store updated_ts as BIGINT (epoch), no conversion needed
|
||||
DialectPostgres: "%s",
|
||||
DialectSQLite: "%s",
|
||||
},
|
||||
},
|
||||
"pinned": {
|
||||
Name: "pinned",
|
||||
Kind: FieldKindBoolColumn,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "pinned"},
|
||||
Expressions: map[DialectName]string{},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"visibility": {
|
||||
Name: "visibility",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "memo", Name: "visibility"},
|
||||
Expressions: map[DialectName]string{},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"tags": {
|
||||
Name: "tags",
|
||||
Kind: FieldKindJSONList,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"tags"},
|
||||
},
|
||||
"tag": {
|
||||
Name: "tag",
|
||||
Kind: FieldKindVirtualAlias,
|
||||
Type: FieldTypeString,
|
||||
AliasFor: "tags",
|
||||
},
|
||||
"has_task_list": {
|
||||
Name: "has_task_list",
|
||||
Kind: FieldKindJSONBool,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"property", "hasTaskList"},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"has_link": {
|
||||
Name: "has_link",
|
||||
Kind: FieldKindJSONBool,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"property", "hasLink"},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"has_code": {
|
||||
Name: "has_code",
|
||||
Kind: FieldKindJSONBool,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"property", "hasCode"},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"has_incomplete_tasks": {
|
||||
Name: "has_incomplete_tasks",
|
||||
Kind: FieldKindJSONBool,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"property", "hasIncompleteTasks"},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
envOptions := []cel.EnvOption{
|
||||
cel.Variable("content", cel.StringType),
|
||||
cel.Variable("creator_id", cel.IntType),
|
||||
cel.Variable("created_ts", cel.IntType),
|
||||
cel.Variable("updated_ts", cel.IntType),
|
||||
cel.Variable("pinned", cel.BoolType),
|
||||
cel.Variable("tag", cel.StringType),
|
||||
cel.Variable("tags", cel.ListType(cel.StringType)),
|
||||
cel.Variable("visibility", cel.StringType),
|
||||
cel.Variable("has_task_list", cel.BoolType),
|
||||
cel.Variable("has_link", cel.BoolType),
|
||||
cel.Variable("has_code", cel.BoolType),
|
||||
cel.Variable("has_incomplete_tasks", cel.BoolType),
|
||||
nowFunction,
|
||||
}
|
||||
|
||||
return Schema{
|
||||
Name: "memo",
|
||||
Fields: fields,
|
||||
EnvOptions: envOptions,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAttachmentSchema constructs the attachment filter schema and CEL environment.
|
||||
func NewAttachmentSchema() Schema {
|
||||
fields := map[string]Field{
|
||||
"filename": {
|
||||
Name: "filename",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "attachment", Name: "filename"},
|
||||
SupportsContains: true,
|
||||
Expressions: map[DialectName]string{},
|
||||
},
|
||||
"mime_type": {
|
||||
Name: "mime_type",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "attachment", Name: "type"},
|
||||
Expressions: map[DialectName]string{},
|
||||
},
|
||||
"create_time": {
|
||||
Name: "create_time",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeTimestamp,
|
||||
Column: Column{Table: "attachment", Name: "created_ts"},
|
||||
Expressions: map[DialectName]string{
|
||||
// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
|
||||
DialectMySQL: "UNIX_TIMESTAMP(%s)",
|
||||
// PostgreSQL and SQLite store created_ts as BIGINT (epoch), no conversion needed
|
||||
DialectPostgres: "%s",
|
||||
DialectSQLite: "%s",
|
||||
},
|
||||
},
|
||||
"memo_id": {
|
||||
Name: "memo_id",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeInt,
|
||||
Column: Column{Table: "attachment", Name: "memo_id"},
|
||||
Expressions: map[DialectName]string{},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
envOptions := []cel.EnvOption{
|
||||
cel.Variable("filename", cel.StringType),
|
||||
cel.Variable("mime_type", cel.StringType),
|
||||
cel.Variable("create_time", cel.IntType),
|
||||
cel.Variable("memo_id", cel.AnyType),
|
||||
nowFunction,
|
||||
}
|
||||
|
||||
return Schema{
|
||||
Name: "attachment",
|
||||
Fields: fields,
|
||||
EnvOptions: envOptions,
|
||||
}
|
||||
}
|
||||
|
||||
// columnExpr returns the field expression for the given dialect, applying
|
||||
// any schema-specific overrides (e.g. UNIX timestamp conversions).
|
||||
func (f Field) columnExpr(d DialectName) string {
|
||||
base := qualifyColumn(d, f.Column)
|
||||
if expr, ok := f.Expressions[d]; ok && expr != "" {
|
||||
return fmt.Sprintf(expr, base)
|
||||
}
|
||||
return base
|
||||
}
|
||||
Reference in New Issue
Block a user