first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled

This commit is contained in:
2026-03-04 06:30:47 +00:00
commit bb402d4ccc
777 changed files with 135661 additions and 0 deletions

View File

@@ -0,0 +1,81 @@
package postgres
import (
"context"
"strings"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
payloadString := "{}"
if create.Payload != nil {
bytes, err := protojson.Marshal(create.Payload)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal activity payload")
}
payloadString = string(bytes)
}
fields := []string{"creator_id", "type", "level", "payload"}
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
&create.CreatedTs,
); err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
if find.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
}
query := "SELECT id, creator_id, type, level, payload, created_ts FROM activity WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.Activity{}
for rows.Next() {
activity := &store.Activity{}
var payloadBytes []byte
if err := rows.Scan(
&activity.ID,
&activity.CreatorID,
&activity.Type,
&activity.Level,
&payloadBytes,
&activity.CreatedTs,
); err != nil {
return nil, err
}
payload := &storepb.ActivityPayload{}
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
return nil, err
}
activity.Payload = payload
list = append(list, activity)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}

View File

@@ -0,0 +1,221 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*store.Attachment, error) {
fields := []string{"uid", "filename", "blob", "type", "size", "creator_id", "memo_id", "storage_type", "reference", "payload"}
storageType := ""
if create.StorageType != storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED {
storageType = create.StorageType.String()
}
payloadString := "{}"
if create.Payload != nil {
bytes, err := protojson.Marshal(create.Payload)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal attachment payload")
}
payloadString = string(bytes)
}
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
stmt := "INSERT INTO attachment (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "attachment.id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "attachment.uid = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "attachment.creator_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Filename; v != nil {
where, args = append(where, "attachment.filename = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.FilenameSearch; v != nil {
where, args = append(where, "attachment.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
}
if v := find.MemoID; v != nil {
where, args = append(where, "attachment.memo_id = "+placeholder(len(args)+1)), append(args, *v)
}
if len(find.MemoIDList) > 0 {
holders := make([]string, 0, len(find.MemoIDList))
for _, id := range find.MemoIDList {
holders = append(holders, placeholder(len(args)+1))
args = append(args, id)
}
where = append(where, "attachment.memo_id IN ("+strings.Join(holders, ", ")+")")
}
if find.HasRelatedMemo {
where = append(where, "attachment.memo_id IS NOT NULL")
}
if v := find.StorageType; v != nil {
where, args = append(where, "attachment.storage_type = "+placeholder(len(args)+1)), append(args, v.String())
}
if len(find.Filters) > 0 {
engine, err := filter.DefaultAttachmentEngine()
if err != nil {
return nil, errors.Wrap(err, "failed to get filter engine")
}
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectPostgres, &where, &args); err != nil {
return nil, errors.Wrap(err, "failed to append filter conditions")
}
}
fields := []string{
"attachment.id AS id",
"attachment.uid AS uid",
"attachment.filename AS filename",
"attachment.type AS type",
"attachment.size AS size",
"attachment.creator_id AS creator_id",
"attachment.created_ts AS created_ts",
"attachment.updated_ts AS updated_ts",
"attachment.memo_id AS memo_id",
"attachment.storage_type AS storage_type",
"attachment.reference AS reference",
"attachment.payload AS payload",
"CASE WHEN memo.uid IS NOT NULL THEN memo.uid ELSE NULL END AS memo_uid",
}
if find.GetBlob {
fields = append(fields, "attachment.blob AS blob")
}
query := fmt.Sprintf(`
SELECT
%s
FROM attachment
LEFT JOIN memo ON attachment.memo_id = memo.id
WHERE %s
ORDER BY attachment.updated_ts DESC
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.Attachment, 0)
for rows.Next() {
attachment := store.Attachment{}
var memoID sql.NullInt32
var storageType string
var payloadBytes []byte
dests := []any{
&attachment.ID,
&attachment.UID,
&attachment.Filename,
&attachment.Type,
&attachment.Size,
&attachment.CreatorID,
&attachment.CreatedTs,
&attachment.UpdatedTs,
&memoID,
&storageType,
&attachment.Reference,
&payloadBytes,
&attachment.MemoUID,
}
if find.GetBlob {
dests = append(dests, &attachment.Blob)
}
if err := rows.Scan(dests...); err != nil {
return nil, err
}
if memoID.Valid {
attachment.MemoID = &memoID.Int32
}
attachment.StorageType = storepb.AttachmentStorageType(storepb.AttachmentStorageType_value[storageType])
payload := &storepb.AttachmentPayload{}
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
return nil, err
}
attachment.Payload = payload
list = append(list, &attachment)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachment) error {
set, args := []string{}, []any{}
if v := update.UID; v != nil {
set, args = append(set, "uid = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Filename; v != nil {
set, args = append(set, "filename = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.MemoID; v != nil {
set, args = append(set, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Reference; v != nil {
set, args = append(set, "reference = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Payload; v != nil {
bytes, err := protojson.Marshal(v)
if err != nil {
return errors.Wrap(err, "failed to marshal attachment payload")
}
set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(bytes))
}
stmt := `UPDATE attachment SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
args = append(args, update.ID)
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
stmt := `DELETE FROM attachment WHERE id = $1`
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,26 @@
package postgres
import (
"fmt"
"strings"
"google.golang.org/protobuf/encoding/protojson"
)
var (
protojsonUnmarshaler = protojson.UnmarshalOptions{
DiscardUnknown: true,
}
)
func placeholder(n int) string {
return "$" + fmt.Sprint(n)
}
func placeholders(n int) string {
list := []string{}
for i := 0; i < n; i++ {
list = append(list, placeholder(i+1))
}
return strings.Join(list, ", ")
}

117
store/db/postgres/idp.go Normal file
View File

@@ -0,0 +1,117 @@
package postgres
import (
"context"
"strings"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
fields := []string{"name", "type", "identifier_filter", "config"}
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
return nil, err
}
identityProvider := create
return identityProvider, nil
}
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
name,
type,
identifier_filter,
config
FROM idp
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
var identityProviders []*store.IdentityProvider
for rows.Next() {
var identityProvider store.IdentityProvider
var typeString string
if err := rows.Scan(
&identityProvider.ID,
&identityProvider.Name,
&typeString,
&identityProvider.IdentifierFilter,
&identityProvider.Config,
); err != nil {
return nil, err
}
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
identityProviders = append(identityProviders, &identityProvider)
}
if err := rows.Err(); err != nil {
return nil, err
}
return identityProviders, nil
}
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
set, args := []string{}, []any{}
if v := update.Name; v != nil {
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.IdentifierFilter; v != nil {
set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Config; v != nil {
set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, *v)
}
stmt := `
UPDATE idp
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, name, type, identifier_filter, config
`
args = append(args, update.ID)
var identityProvider store.IdentityProvider
var typeString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID,
&identityProvider.Name,
&typeString,
&identityProvider.IdentifierFilter,
&identityProvider.Config,
); err != nil {
return nil, err
}
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
return &identityProvider, nil
}
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
where, args := []string{"id = $1"}, []any{delete.ID}
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}

151
store/db/postgres/inbox.go Normal file
View File

@@ -0,0 +1,151 @@
package postgres
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
messageString := "{}"
if create.Message != nil {
bytes, err := protojson.Marshal(create.Message)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal inbox message")
}
messageString = string(bytes)
}
fields := []string{"sender_id", "receiver_id", "status", "message"}
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
stmt := "INSERT INTO inbox (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
&create.CreatedTs,
); err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
if find.SenderID != nil {
where, args = append(where, "sender_id = "+placeholder(len(args)+1)), append(args, *find.SenderID)
}
if find.ReceiverID != nil {
where, args = append(where, "receiver_id = "+placeholder(len(args)+1)), append(args, *find.ReceiverID)
}
if find.Status != nil {
where, args = append(where, "status = "+placeholder(len(args)+1)), append(args, *find.Status)
}
if find.MessageType != nil {
// Filter by message type using PostgreSQL JSON extraction
// Note: The type field in JSON is stored as string representation of the enum name
// Cast to JSONB since the column is TEXT
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
where, args = append(where, "(message::JSONB->>'type' IS NULL OR message::JSONB->>'type' = "+placeholder(len(args)+1)+")"), append(args, find.MessageType.String())
} else {
where, args = append(where, "message::JSONB->>'type' = "+placeholder(len(args)+1)), append(args, find.MessageType.String())
}
}
query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.Inbox{}
for rows.Next() {
inbox := &store.Inbox{}
var messageBytes []byte
if err := rows.Scan(
&inbox.ID,
&inbox.CreatedTs,
&inbox.SenderID,
&inbox.ReceiverID,
&inbox.Status,
&messageBytes,
); err != nil {
return nil, err
}
message := &storepb.InboxMessage{}
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
return nil, err
}
inbox.Message = message
list = append(list, inbox)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
list, err := d.ListInboxes(ctx, find)
if err != nil {
return nil, errors.Wrap(err, "failed to get inbox")
}
if len(list) != 1 {
return nil, errors.Errorf("unexpected inbox count: %d", len(list))
}
return list[0], nil
}
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
set, args := []string{"status = $1"}, []any{update.Status.String()}
args = append(args, update.ID)
query := "UPDATE inbox SET " + strings.Join(set, ", ") + " WHERE id = $2 RETURNING id, created_ts, sender_id, receiver_id, status, message"
inbox := &store.Inbox{}
var messageBytes []byte
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
&inbox.ID,
&inbox.CreatedTs,
&inbox.SenderID,
&inbox.ReceiverID,
&inbox.Status,
&messageBytes,
); err != nil {
return nil, err
}
message := &storepb.InboxMessage{}
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
return nil, err
}
inbox.Message = message
return inbox, nil
}
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
result, err := d.db.ExecContext(ctx, "DELETE FROM inbox WHERE id = $1", delete.ID)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,72 @@
package postgres
import (
"context"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertInstanceSetting(ctx context.Context, upsert *store.InstanceSetting) (*store.InstanceSetting, error) {
stmt := `
INSERT INTO system_setting (
name, value, description
)
VALUES ($1, $2, $3)
ON CONFLICT(name) DO UPDATE
SET
value = EXCLUDED.value,
description = EXCLUDED.description
`
if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil {
return nil, err
}
return upsert, nil
}
func (d *DB) ListInstanceSettings(ctx context.Context, find *store.FindInstanceSetting) ([]*store.InstanceSetting, error) {
where, args := []string{"1 = 1"}, []any{}
if find.Name != "" {
where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, find.Name)
}
query := `
SELECT
name,
value,
description
FROM system_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.InstanceSetting{}
for rows.Next() {
systemSettingMessage := &store.InstanceSetting{}
if err := rows.Scan(
&systemSettingMessage.Name,
&systemSettingMessage.Value,
&systemSettingMessage.Description,
); err != nil {
return nil, err
}
list = append(list, systemSettingMessage)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteInstanceSetting(ctx context.Context, delete *store.DeleteInstanceSetting) error {
stmt := `DELETE FROM system_setting WHERE name = $1`
_, err := d.db.ExecContext(ctx, stmt, delete.Name)
return err
}

254
store/db/postgres/memo.go Normal file
View File

@@ -0,0 +1,254 @@
package postgres
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
fields := []string{"uid", "creator_id", "content", "visibility", "payload"}
payload := "{}"
if create.Payload != nil {
payloadBytes, err := protojson.Marshal(create.Payload)
if err != nil {
return nil, err
}
payload = string(payloadBytes)
}
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
// Add custom timestamps if provided
if create.CreatedTs != 0 {
fields = append(fields, "created_ts")
args = append(args, create.CreatedTs)
}
if create.UpdatedTs != 0 {
fields = append(fields, "updated_ts")
args = append(args, create.UpdatedTs)
}
stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
where, args := []string{"1 = 1"}, []any{}
engine, err := filter.DefaultEngine()
if err != nil {
return nil, err
}
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectPostgres, &where, &args); err != nil {
return nil, err
}
if v := find.ID; v != nil {
where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v)
}
if len(find.IDList) > 0 {
holders := make([]string, 0, len(find.IDList))
for _, id := range find.IDList {
holders = append(holders, placeholder(len(args)+1))
args = append(args, id)
}
where = append(where, "memo.id IN ("+strings.Join(holders, ", ")+")")
}
if v := find.UID; v != nil {
where, args = append(where, "memo.uid = "+placeholder(len(args)+1)), append(args, *v)
}
if len(find.UIDList) > 0 {
holders := make([]string, 0, len(find.UIDList))
for _, uid := range find.UIDList {
holders = append(holders, placeholder(len(args)+1))
args = append(args, uid)
}
where = append(where, "memo.uid IN ("+strings.Join(holders, ", ")+")")
}
if v := find.CreatorID; v != nil {
where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "memo.row_status = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
holders := []string{}
for _, visibility := range v {
holders = append(holders, placeholder(len(args)+1))
args = append(args, visibility.String())
}
where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(holders, ", ")))
}
if find.ExcludeComments {
where = append(where, "memo_relation.related_memo_id IS NULL")
}
order := "DESC"
if find.OrderByTimeAsc {
order = "ASC"
}
orderBy := []string{}
if find.OrderByPinned {
orderBy = append(orderBy, "pinned DESC")
}
if find.OrderByUpdatedTs {
orderBy = append(orderBy, "updated_ts "+order)
} else {
orderBy = append(orderBy, "created_ts "+order)
}
// Add id as final tie-breaker
orderBy = append(orderBy, "id DESC")
fields := []string{
`memo.id AS id`,
`memo.uid AS uid`,
`memo.creator_id AS creator_id`,
`memo.created_ts AS created_ts`,
`memo.updated_ts AS updated_ts`,
`memo.row_status AS row_status`,
`memo.visibility AS visibility`,
`memo.pinned AS pinned`,
`memo.payload AS payload`,
`CASE WHEN parent_memo.uid IS NOT NULL THEN parent_memo.uid ELSE NULL END AS parent_uid`,
}
if !find.ExcludeContent {
fields = append(fields, `memo.content AS content`)
}
query := `SELECT ` + strings.Join(fields, ", ") + `
FROM memo
LEFT JOIN memo_relation ON memo.id = memo_relation.memo_id AND memo_relation.type = 'COMMENT'
LEFT JOIN memo AS parent_memo ON memo_relation.related_memo_id = parent_memo.id
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY ` + strings.Join(orderBy, ", ")
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.Memo, 0)
for rows.Next() {
var memo store.Memo
var payloadBytes []byte
dests := []any{
&memo.ID,
&memo.UID,
&memo.CreatorID,
&memo.CreatedTs,
&memo.UpdatedTs,
&memo.RowStatus,
&memo.Visibility,
&memo.Pinned,
&payloadBytes,
&memo.ParentUID,
}
if !find.ExcludeContent {
dests = append(dests, &memo.Content)
}
if err := rows.Scan(dests...); err != nil {
return nil, err
}
payload := &storepb.MemoPayload{}
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal payload")
}
memo.Payload = payload
list = append(list, &memo)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) {
list, err := d.ListMemos(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
memo := list[0]
return memo, nil
}
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
set, args := []string{}, []any{}
if v := update.UID; v != nil {
set, args = append(set, "uid = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.CreatedTs; v != nil {
set, args = append(set, "created_ts = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Content; v != nil {
set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Visibility; v != nil {
set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Pinned; v != nil {
set, args = append(set, "pinned = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Payload; v != nil {
payloadBytes, err := protojson.Marshal(v)
if err != nil {
return err
}
set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(payloadBytes))
}
if len(set) == 0 {
return nil
}
stmt := `UPDATE memo SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
args = append(args, update.ID)
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
return err
}
return nil
}
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
where, args := []string{"id = " + placeholder(1)}, []any{delete.ID}
stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return errors.Wrap(err, "failed to delete memo")
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,134 @@
package postgres
import (
"context"
"fmt"
"strings"
"github.com/usememos/memos/plugin/filter"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
stmt := `
INSERT INTO memo_relation (
memo_id,
related_memo_id,
type
)
VALUES (` + placeholders(3) + `)
ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET type = EXCLUDED.type
RETURNING memo_id, related_memo_id, type
`
memoRelation := &store.MemoRelation{}
if err := d.db.QueryRowContext(
ctx,
stmt,
create.MemoID,
create.RelatedMemoID,
create.Type,
).Scan(
&memoRelation.MemoID,
&memoRelation.RelatedMemoID,
&memoRelation.Type,
); err != nil {
return nil, err
}
return memoRelation, nil
}
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
where, args := []string{"1 = 1"}, []any{}
if find.MemoID != nil {
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
}
if find.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID)
}
if find.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
}
if find.MemoFilter != nil {
engine, err := filter.DefaultEngine()
if err != nil {
return nil, err
}
stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{
Dialect: filter.DialectPostgres,
PlaceholderOffset: len(args),
})
if err != nil {
return nil, err
}
if stmt.SQL != "" {
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
args = append(args, stmt.Args...)
stmtRelated, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{
Dialect: filter.DialectPostgres,
PlaceholderOffset: len(args),
})
if err != nil {
return nil, err
}
if stmtRelated.SQL != "" {
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmtRelated.SQL))
args = append(args, stmtRelated.Args...)
}
}
}
rows, err := d.db.QueryContext(ctx, `
SELECT
memo_id,
related_memo_id,
type
FROM memo_relation
WHERE `+strings.Join(where, " AND "), args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.MemoRelation{}
for rows.Next() {
memoRelation := &store.MemoRelation{}
if err := rows.Scan(
&memoRelation.MemoID,
&memoRelation.RelatedMemoID,
&memoRelation.Type,
); err != nil {
return nil, err
}
list = append(list, memoRelation)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
where, args := []string{"1 = 1"}, []any{}
if delete.MemoID != nil {
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID)
}
if delete.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID)
}
if delete.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type)
}
stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,57 @@
package postgres
import (
"context"
"database/sql"
"log"
// Import the PostgreSQL driver.
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/store"
)
type DB struct {
db *sql.DB
profile *profile.Profile
}
func NewDB(profile *profile.Profile) (store.Driver, error) {
if profile == nil {
return nil, errors.New("profile is nil")
}
// Open the PostgreSQL connection
db, err := sql.Open("postgres", profile.DSN)
if err != nil {
log.Printf("Failed to open database: %s", err)
return nil, errors.Wrapf(err, "failed to open database: %s", profile.DSN)
}
var driver store.Driver = &DB{
db: db,
profile: profile,
}
// Return the DB struct
return driver, nil
}
func (d *DB) GetDB() *sql.DB {
return d.db
}
func (d *DB) Close() error {
return d.db.Close()
}
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
var exists bool
err := d.db.QueryRowContext(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_catalog = current_database() AND table_name = 'memo' AND table_type = 'BASE TABLE')").Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "failed to check if database is initialized")
}
return exists, nil
}

View File

@@ -0,0 +1,101 @@
package postgres
import (
"context"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store.Reaction, error) {
fields := []string{"creator_id", "content_id", "reaction_type"}
args := []interface{}{upsert.CreatorID, upsert.ContentID, upsert.ReactionType}
stmt := "INSERT INTO reaction (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&upsert.ID,
&upsert.CreatedTs,
); err != nil {
return nil, err
}
reaction := upsert
return reaction, nil
}
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
if find.CreatorID != nil {
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *find.CreatorID)
}
if find.ContentID != nil {
where, args = append(where, "content_id = "+placeholder(len(args)+1)), append(args, *find.ContentID)
}
if len(find.ContentIDList) > 0 {
holders := make([]string, 0, len(find.ContentIDList))
for _, id := range find.ContentIDList {
holders = append(holders, placeholder(len(args)+1))
args = append(args, id)
}
where = append(where, "content_id IN ("+strings.Join(holders, ", ")+")")
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
created_ts,
creator_id,
content_id,
reaction_type
FROM reaction
WHERE `+strings.Join(where, " AND ")+`
ORDER BY id ASC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.Reaction{}
for rows.Next() {
reaction := &store.Reaction{}
if err := rows.Scan(
&reaction.ID,
&reaction.CreatedTs,
&reaction.CreatorID,
&reaction.ContentID,
&reaction.ReactionType,
); err != nil {
return nil, err
}
list = append(list, reaction)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) {
list, err := d.ListReactions(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
reaction := list[0]
return reaction, nil
}
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
_, err := d.db.ExecContext(ctx, "DELETE FROM reaction WHERE id = $1", delete.ID)
return err
}

172
store/db/postgres/user.go Normal file
View File

@@ -0,0 +1,172 @@
package postgres
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
fields := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
stmt := "INSERT INTO \"user\" (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, description, created_ts, updated_ts, row_status"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
&create.Description,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
return create, nil
}
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Username; v != nil {
set, args = append(set, "username = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.AvatarURL; v != nil {
set, args = append(set, "avatar_url = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Description; v != nil {
set, args = append(set, "description = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Role; v != nil {
set, args = append(set, "role = "+placeholder(len(args)+1)), append(args, *v)
}
query := `
UPDATE "user"
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, username, role, email, nickname, password_hash, avatar_url, description, created_ts, updated_ts, row_status
`
args = append(args, update.ID)
user := &store.User{}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.Description,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err
}
return user, nil
}
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
where, args := []string{"1 = 1"}, []any{}
if len(find.Filters) > 0 {
return nil, errors.Errorf("user filters are not supported")
}
if v := find.ID; v != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Username; v != nil {
where, args = append(where, "username = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Email; v != nil {
where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
}
orderBy := []string{"created_ts DESC", "row_status DESC"}
query := `
SELECT
id,
username,
role,
email,
nickname,
password_hash,
avatar_url,
description,
created_ts,
updated_ts,
row_status
FROM "user"
WHERE ` + strings.Join(where, " AND ") + ` ORDER BY ` + strings.Join(orderBy, ", ")
if v := find.Limit; v != nil {
query += fmt.Sprintf(" LIMIT %d", *v)
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.User, 0)
for rows.Next() {
var user store.User
if err := rows.Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.Description,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err
}
list = append(list, &user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
result, err := d.db.ExecContext(ctx, `DELETE FROM "user" WHERE id = $1`, delete.ID)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,121 @@
package postgres
import (
"context"
"strings"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
stmt := `
INSERT INTO user_setting (
user_id, key, value
)
VALUES ($1, $2, $3)
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
`
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil {
return nil, err
}
return upsert, nil
}
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
}
if v := find.UserID; v != nil {
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
}
query := `
SELECT
user_id,
key,
value
FROM user_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
userSettingList := make([]*store.UserSetting, 0)
for rows.Next() {
userSetting := &store.UserSetting{}
var keyString string
if err := rows.Scan(
&userSetting.UserID,
&keyString,
&userSetting.Value,
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
userSettingList = append(userSettingList, userSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
return userSettingList, nil
}
func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) {
// Simplified query: fetch all PERSONAL_ACCESS_TOKENS rows and search in Go
// This matches SQLite/MySQL behavior and avoids PostgreSQL's strict JSONB errors
query := `
SELECT
user_id,
value
FROM user_setting
WHERE key = 'PERSONAL_ACCESS_TOKENS'
`
rows, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
// Iterate through all users with PAT settings
for rows.Next() {
var userID int32
var tokensJSON string
if err := rows.Scan(&userID, &tokensJSON); err != nil {
continue // Skip malformed rows
}
// Try to unmarshal - skip if invalid JSON
patsUserSetting := &storepb.PersonalAccessTokensUserSetting{}
if err := protojsonUnmarshaler.Unmarshal([]byte(tokensJSON), patsUserSetting); err != nil {
continue // Skip invalid JSON
}
// Search for matching token hash
for _, pat := range patsUserSetting.Tokens {
if pat.TokenHash == tokenHash {
return &store.PATQueryResult{
UserID: userID,
PAT: pat,
}, nil
}
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return nil, errors.New("PAT not found")
}

View File

@@ -0,0 +1,290 @@
package postgres
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// TestGetUserByPATHashWithMissingData tests the fix for #5611 and #5612.
// Verifies that GetUserByPATHash handles missing/malformed data gracefully
// instead of throwing PostgreSQL JSONB errors.
func TestGetUserByPATHashWithMissingData(t *testing.T) {
if testing.Short() {
t.Skip("Skipping PostgreSQL integration test in short mode")
}
// This test requires a real PostgreSQL connection
// If DSN is not provided, skip the test
dsn := getTestDSN()
if dsn == "" {
t.Skip("PostgreSQL DSN not provided, skipping test")
}
db, err := sql.Open("postgres", dsn)
require.NoError(t, err)
defer db.Close()
// Create test database
ctx := context.Background()
driver := &DB{db: db}
// Setup: Create user_setting table if needed
_, err = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_setting (
user_id INTEGER NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
UNIQUE(user_id, key)
)
`)
require.NoError(t, err)
// Cleanup
defer func() {
db.ExecContext(ctx, "DELETE FROM user_setting WHERE user_id IN (1001, 1002, 1003)")
}()
t.Run("NoTokensKeyAtAll", func(t *testing.T) {
// Test case: User has no PERSONAL_ACCESS_TOKENS key
// This simulates fresh users or users upgraded from v0.25.3
result, err := driver.GetUserByPATHash(ctx, "any-hash")
assert.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "PAT not found")
})
t.Run("EmptyTokensArray", func(t *testing.T) {
// Insert user with empty tokens array
_, err := db.ExecContext(ctx, `
INSERT INTO user_setting (user_id, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
`, 1001, "PERSONAL_ACCESS_TOKENS", `{"tokens":[]}`)
require.NoError(t, err)
result, err := driver.GetUserByPATHash(ctx, "any-hash")
assert.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "PAT not found")
})
t.Run("MalformedJSON", func(t *testing.T) {
// Insert user with malformed JSON
_, err := db.ExecContext(ctx, `
INSERT INTO user_setting (user_id, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
`, 1002, "PERSONAL_ACCESS_TOKENS", `{invalid json}`)
require.NoError(t, err)
// Should handle gracefully without crashing
result, err := driver.GetUserByPATHash(ctx, "any-hash")
assert.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "PAT not found")
})
t.Run("MissingTokensField", func(t *testing.T) {
// Insert user with valid JSON but missing 'tokens' field
_, err := db.ExecContext(ctx, `
INSERT INTO user_setting (user_id, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
`, 1003, "PERSONAL_ACCESS_TOKENS", `{"someOtherField":"value"}`)
require.NoError(t, err)
// Should handle gracefully
result, err := driver.GetUserByPATHash(ctx, "any-hash")
assert.Error(t, err)
assert.Nil(t, result)
})
t.Run("ValidTokenFound", func(t *testing.T) {
// Insert user with valid PAT
validJSON := `{
"tokens": [
{
"tokenId": "pat-test",
"tokenHash": "hash-test-123",
"description": "Test PAT"
}
]
}`
_, err := db.ExecContext(ctx, `
INSERT INTO user_setting (user_id, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
`, 1001, "PERSONAL_ACCESS_TOKENS", validJSON)
require.NoError(t, err)
// Should find the token
result, err := driver.GetUserByPATHash(ctx, "hash-test-123")
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int32(1001), result.UserID)
assert.Equal(t, "pat-test", result.PAT.TokenId)
assert.Equal(t, "hash-test-123", result.PAT.TokenHash)
})
t.Run("MultipleUsersWithMixedData", func(t *testing.T) {
// User 1001: Valid PAT
validJSON := `{
"tokens": [
{
"tokenId": "pat-user1",
"tokenHash": "hash-user1",
"description": "User 1 PAT"
}
]
}`
_, err := db.ExecContext(ctx, `
INSERT INTO user_setting (user_id, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
`, 1001, "PERSONAL_ACCESS_TOKENS", validJSON)
require.NoError(t, err)
// User 1002: Malformed JSON (should be skipped)
_, err = db.ExecContext(ctx, `
INSERT INTO user_setting (user_id, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
`, 1002, "PERSONAL_ACCESS_TOKENS", `{invalid}`)
require.NoError(t, err)
// User 1003: Empty array (should be skipped)
_, err = db.ExecContext(ctx, `
INSERT INTO user_setting (user_id, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
`, 1003, "PERSONAL_ACCESS_TOKENS", `{"tokens":[]}`)
require.NoError(t, err)
// Should still find user 1001's token despite other users having bad data
result, err := driver.GetUserByPATHash(ctx, "hash-user1")
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int32(1001), result.UserID)
})
}
// TestGetUserByPATHashPerformance ensures the simplified query doesn't cause performance issues.
func TestGetUserByPATHashPerformance(t *testing.T) {
if testing.Short() {
t.Skip("Skipping performance test in short mode")
}
dsn := getTestDSN()
if dsn == "" {
t.Skip("PostgreSQL DSN not provided, skipping test")
}
db, err := sql.Open("postgres", dsn)
require.NoError(t, err)
defer db.Close()
ctx := context.Background()
driver := &DB{db: db}
// Setup table
_, err = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_setting (
user_id INTEGER NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
UNIQUE(user_id, key)
)
`)
require.NoError(t, err)
// Cleanup
defer func() {
db.ExecContext(ctx, "DELETE FROM user_setting WHERE user_id >= 2000 AND user_id < 2100")
}()
// Insert 100 users with PATs
for i := 2000; i < 2100; i++ {
json := `{
"tokens": [
{
"tokenId": "pat-` + string(rune(i)) + `",
"tokenHash": "hash-` + string(rune(i)) + `",
"description": "Test PAT"
}
]
}`
_, err = db.ExecContext(ctx, `
INSERT INTO user_setting (user_id, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
`, i, "PERSONAL_ACCESS_TOKENS", json)
require.NoError(t, err)
}
// Query should complete quickly even with 100 users
result, err := driver.GetUserByPATHash(ctx, "hash-"+string(rune(2050)))
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int32(2050), result.UserID)
}
// getTestDSN returns PostgreSQL DSN from environment or returns empty string.
func getTestDSN() string {
// For unit tests, we expect TEST_POSTGRES_DSN to be set.
// Example: TEST_POSTGRES_DSN="postgresql://user:pass@localhost:5432/memos_test?sslmode=disable".
return ""
}
// TestUpsertUserSetting tests basic upsert functionality.
func TestUpsertUserSetting(t *testing.T) {
dsn := getTestDSN()
if dsn == "" {
t.Skip("PostgreSQL DSN not provided, skipping test")
}
db, err := sql.Open("postgres", dsn)
require.NoError(t, err)
defer db.Close()
ctx := context.Background()
driver := &DB{db: db}
// Setup
_, err = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_setting (
user_id INTEGER NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
UNIQUE(user_id, key)
)
`)
require.NoError(t, err)
defer func() {
db.ExecContext(ctx, "DELETE FROM user_setting WHERE user_id = 9999")
}()
// Test insert
setting := &store.UserSetting{
UserID: 9999,
Key: storepb.UserSetting_GENERAL,
Value: `{"locale":"en"}`,
}
result, err := driver.UpsertUserSetting(ctx, setting)
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int32(9999), result.UserID)
// Test update (upsert on conflict)
setting.Value = `{"locale":"zh"}`
result, err = driver.UpsertUserSetting(ctx, setting)
assert.NoError(t, err)
assert.Equal(t, `{"locale":"zh"}`, result.Value)
}