first commit
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
Some checks failed
Backend Tests / Static Checks (push) Has been cancelled
Backend Tests / Tests (other) (push) Has been cancelled
Backend Tests / Tests (plugin) (push) Has been cancelled
Backend Tests / Tests (server) (push) Has been cancelled
Backend Tests / Tests (store) (push) Has been cancelled
Build Canary Image / build-frontend (push) Has been cancelled
Build Canary Image / build-push (linux/amd64) (push) Has been cancelled
Build Canary Image / build-push (linux/arm64) (push) Has been cancelled
Build Canary Image / merge (push) Has been cancelled
Frontend Tests / Lint (push) Has been cancelled
Frontend Tests / Build (push) Has been cancelled
Proto Linter / Lint Protos (push) Has been cancelled
This commit is contained in:
32
store/db/db.go
Normal file
32
store/db/db.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
"github.com/usememos/memos/store/db/mysql"
|
||||
"github.com/usememos/memos/store/db/postgres"
|
||||
"github.com/usememos/memos/store/db/sqlite"
|
||||
)
|
||||
|
||||
// NewDBDriver creates new db driver based on profile.
|
||||
func NewDBDriver(profile *profile.Profile) (store.Driver, error) {
|
||||
var driver store.Driver
|
||||
var err error
|
||||
|
||||
switch profile.Driver {
|
||||
case "sqlite":
|
||||
driver, err = sqlite.NewDB(profile)
|
||||
case "mysql":
|
||||
driver, err = mysql.NewDB(profile)
|
||||
case "postgres":
|
||||
driver, err = postgres.NewDB(profile)
|
||||
default:
|
||||
return nil, errors.New("unknown db driver")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create db driver")
|
||||
}
|
||||
return driver, nil
|
||||
}
|
||||
93
store/db/mysql/activity.go
Normal file
93
store/db/mysql/activity.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal activity payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
fields := []string{"`creator_id`", "`type`", "`level`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
|
||||
stmt := "INSERT INTO `activity` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute statement")
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get last insert id")
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
|
||||
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id32})
|
||||
if err != nil || len(list) == 0 {
|
||||
return nil, errors.Wrap(err, "failed to find activity")
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, find.Type.String())
|
||||
}
|
||||
|
||||
query := "SELECT `id`, `creator_id`, `type`, `level`, `payload`, UNIX_TIMESTAMP(`created_ts`) FROM `activity` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Activity{}
|
||||
for rows.Next() {
|
||||
activity := &store.Activity{}
|
||||
var payloadBytes []byte
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&payloadBytes,
|
||||
&activity.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := &storepb.ActivityPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
activity.Payload = payload
|
||||
list = append(list, activity)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
241
store/db/mysql/attachment.go
Normal file
241
store/db/mysql/attachment.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*store.Attachment, error) {
|
||||
fields := []string{"`uid`", "`filename`", "`blob`", "`type`", "`size`", "`creator_id`", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?", "?", "?", "?", "?"}
|
||||
storageType := ""
|
||||
if create.StorageType != storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED {
|
||||
storageType = create.StorageType.String()
|
||||
}
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO `attachment` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
return d.GetAttachment(ctx, &store.FindAttachment{ID: &id32})
|
||||
}
|
||||
|
||||
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`attachment`.`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`attachment`.`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`attachment`.`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "`attachment`.`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.FilenameSearch; v != nil {
|
||||
where, args = append(where, "`attachment`.`filename` LIKE ?"), append(args, "%"+*v+"%")
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "`attachment`.`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.MemoIDList))
|
||||
for range find.MemoIDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
where = append(where, "`attachment`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, id := range find.MemoIDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "`attachment`.`memo_id` IS NOT NULL")
|
||||
}
|
||||
if find.StorageType != nil {
|
||||
where, args = append(where, "`attachment`.`storage_type` = ?"), append(args, find.StorageType.String())
|
||||
}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
engine, err := filter.DefaultAttachmentEngine()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get filter engine")
|
||||
}
|
||||
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectMySQL, &where, &args); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to append filter conditions")
|
||||
}
|
||||
}
|
||||
|
||||
fields := []string{
|
||||
"`attachment`.`id` AS `id`",
|
||||
"`attachment`.`uid` AS `uid`",
|
||||
"`attachment`.`filename` AS `filename`",
|
||||
"`attachment`.`type` AS `type`",
|
||||
"`attachment`.`size` AS `size`",
|
||||
"`attachment`.`creator_id` AS `creator_id`",
|
||||
"UNIX_TIMESTAMP(`attachment`.`created_ts`) AS `created_ts`",
|
||||
"UNIX_TIMESTAMP(`attachment`.`updated_ts`) AS `updated_ts`",
|
||||
"`attachment`.`memo_id` AS `memo_id`",
|
||||
"`attachment`.`storage_type` AS `storage_type`",
|
||||
"`attachment`.`reference` AS `reference`",
|
||||
"`attachment`.`payload` AS `payload`",
|
||||
"CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`",
|
||||
}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "`attachment`.`blob` AS `blob`")
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `attachment`" + " " +
|
||||
"LEFT JOIN `memo` ON `attachment`.`memo_id` = `memo`.`id`" + " " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
"ORDER BY `updated_ts` DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Attachment, 0)
|
||||
for rows.Next() {
|
||||
attachment := store.Attachment{}
|
||||
var memoID sql.NullInt32
|
||||
var storageType string
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&attachment.ID,
|
||||
&attachment.UID,
|
||||
&attachment.Filename,
|
||||
&attachment.Type,
|
||||
&attachment.Size,
|
||||
&attachment.CreatorID,
|
||||
&attachment.CreatedTs,
|
||||
&attachment.UpdatedTs,
|
||||
&memoID,
|
||||
&storageType,
|
||||
&attachment.Reference,
|
||||
&payloadBytes,
|
||||
&attachment.MemoUID,
|
||||
}
|
||||
if find.GetBlob {
|
||||
dests = append(dests, &attachment.Blob)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if memoID.Valid {
|
||||
attachment.MemoID = &memoID.Int32
|
||||
}
|
||||
attachment.StorageType = storepb.AttachmentStorageType(storepb.AttachmentStorageType_value[storageType])
|
||||
payload := &storepb.AttachmentPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attachment.Payload = payload
|
||||
list = append(list, &attachment)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetAttachment(ctx context.Context, find *store.FindAttachment) (*store.Attachment, error) {
|
||||
list, err := d.ListAttachments(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachment) error {
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
set, args = append(set, "`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
set, args = append(set, "`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Reference; v != nil {
|
||||
set, args = append(set, "`reference` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
bytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(bytes))
|
||||
}
|
||||
|
||||
args = append(args, update.ID)
|
||||
stmt := "UPDATE `attachment` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := "DELETE FROM `attachment` WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
10
store/db/mysql/common.go
Normal file
10
store/db/mysql/common.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package mysql
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
var (
|
||||
protojsonUnmarshaler = protojson.UnmarshalOptions{
|
||||
AllowPartial: true,
|
||||
DiscardUnknown: true,
|
||||
}
|
||||
)
|
||||
126
store/db/mysql/idp.go
Normal file
126
store/db/mysql/idp.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
placeholders := []string{"?", "?", "?", "?"}
|
||||
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
|
||||
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create.ID = int32(id)
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var identityProviders []*store.IdentityProvider
|
||||
for rows.Next() {
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProvider.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
|
||||
identityProviders = append(identityProviders, &identityProvider)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identityProviders, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
|
||||
list, err := d.ListIdentityProviders(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
identityProvider := list[0]
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.Name; v != nil {
|
||||
set, args = append(set, "`name` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
set, args = append(set, "`identifier_filter` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
set, args = append(set, "`config` = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := "UPDATE `idp` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider, err := d.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &update.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, errors.Errorf("idp %d not found", update.ID)
|
||||
}
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
where, args := []string{"`id` = ?"}, []any{delete.ID}
|
||||
stmt := "DELETE FROM `idp` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
150
store/db/mysql/inbox.go
Normal file
150
store/db/mysql/inbox.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
|
||||
messageString := "{}"
|
||||
if create.Message != nil {
|
||||
bytes, err := protojson.Marshal(create.Message)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal inbox message")
|
||||
}
|
||||
messageString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"`sender_id`", "`receiver_id`", "`status`", "`message`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
|
||||
|
||||
stmt := "INSERT INTO `inbox` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
inbox, err := d.GetInbox(ctx, &store.FindInbox{ID: &id32})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.SenderID != nil {
|
||||
where, args = append(where, "`sender_id` = ?"), append(args, *find.SenderID)
|
||||
}
|
||||
if find.ReceiverID != nil {
|
||||
where, args = append(where, "`receiver_id` = ?"), append(args, *find.ReceiverID)
|
||||
}
|
||||
if find.Status != nil {
|
||||
where, args = append(where, "`status` = ?"), append(args, *find.Status)
|
||||
}
|
||||
if find.MessageType != nil {
|
||||
// Filter by message type using JSON extraction
|
||||
// Note: The type field in JSON is stored as string representation of the enum name
|
||||
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
|
||||
where, args = append(where, "(JSON_EXTRACT(`message`, '$.type') IS NULL OR JSON_EXTRACT(`message`, '$.type') = ?)"), append(args, find.MessageType.String())
|
||||
} else {
|
||||
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
|
||||
}
|
||||
}
|
||||
|
||||
query := "SELECT `id`, UNIX_TIMESTAMP(`created_ts`), `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Inbox{}
|
||||
for rows.Next() {
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := rows.Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
list = append(list, inbox)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
|
||||
list, err := d.ListInboxes(ctx, find)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get inbox")
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected inbox count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
set, args := []string{"`status` = ?"}, []any{update.Status.String()}
|
||||
args = append(args, update.ID)
|
||||
query := "UPDATE `inbox` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to update inbox")
|
||||
}
|
||||
inbox, err := d.GetInbox(ctx, &store.FindInbox{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM `inbox` WHERE `id` = ?", delete.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to delete inbox")
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
65
store/db/mysql/instance_setting.go
Normal file
65
store/db/mysql/instance_setting.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertInstanceSetting(ctx context.Context, upsert *store.InstanceSetting) (*store.InstanceSetting, error) {
|
||||
stmt := "INSERT INTO `system_setting` (`name`, `value`, `description`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `value` = ?, `description` = ?"
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
upsert.Name,
|
||||
upsert.Value,
|
||||
upsert.Description,
|
||||
upsert.Value,
|
||||
upsert.Description,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInstanceSettings(ctx context.Context, find *store.FindInstanceSetting) ([]*store.InstanceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Name != "" {
|
||||
where, args = append(where, "`name` = ?"), append(args, find.Name)
|
||||
}
|
||||
|
||||
query := "SELECT `name`, `value`, `description` FROM `system_setting` WHERE " + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.InstanceSetting{}
|
||||
for rows.Next() {
|
||||
systemSettingMessage := &store.InstanceSetting{}
|
||||
if err := rows.Scan(
|
||||
&systemSettingMessage.Name,
|
||||
&systemSettingMessage.Value,
|
||||
&systemSettingMessage.Description,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSettingMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInstanceSetting(ctx context.Context, delete *store.DeleteInstanceSetting) error {
|
||||
stmt := "DELETE FROM `system_setting` WHERE `name` = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, delete.Name)
|
||||
return err
|
||||
}
|
||||
269
store/db/mysql/memo.go
Normal file
269
store/db/mysql/memo.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
fields := []string{"`uid`", "`creator_id`", "`content`", "`visibility`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?"}
|
||||
payload := "{}"
|
||||
if create.Payload != nil {
|
||||
payloadBytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = string(payloadBytes)
|
||||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
// Add custom timestamps if provided
|
||||
if create.CreatedTs != 0 {
|
||||
fields = append(fields, "`created_ts`")
|
||||
placeholder = append(placeholder, "FROM_UNIXTIME(?)")
|
||||
args = append(args, create.CreatedTs)
|
||||
}
|
||||
if create.UpdatedTs != 0 {
|
||||
fields = append(fields, "`updated_ts`")
|
||||
placeholder = append(placeholder, "FROM_UNIXTIME(?)")
|
||||
args = append(args, create.UpdatedTs)
|
||||
}
|
||||
|
||||
stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := int32(rawID)
|
||||
memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, errors.Errorf("failed to create memo")
|
||||
}
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
|
||||
where, having, args := []string{"1 = 1"}, []string{"1 = 1"}, []any{}
|
||||
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectMySQL, &where, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`memo`.`id` = ?"), append(args, *v)
|
||||
}
|
||||
if len(find.IDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.IDList))
|
||||
for range find.IDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
where = append(where, "`memo`.`id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, id := range find.IDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`memo`.`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if len(find.UIDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.UIDList))
|
||||
for range find.UIDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
where = append(where, "`memo`.`uid` IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, uid := range find.UIDList {
|
||||
args = append(args, uid)
|
||||
}
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "`memo`.`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
placeholder := []string{}
|
||||
for _, visibility := range v {
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, visibility.String())
|
||||
}
|
||||
where = append(where, fmt.Sprintf("`memo`.`visibility` in (%s)", strings.Join(placeholder, ",")))
|
||||
}
|
||||
if find.ExcludeComments {
|
||||
having = append(having, "`parent_uid` IS NULL")
|
||||
}
|
||||
|
||||
order := "DESC"
|
||||
if find.OrderByTimeAsc {
|
||||
order = "ASC"
|
||||
}
|
||||
orderBy := []string{}
|
||||
if find.OrderByPinned {
|
||||
orderBy = append(orderBy, "`pinned` DESC")
|
||||
}
|
||||
if find.OrderByUpdatedTs {
|
||||
orderBy = append(orderBy, "`updated_ts` "+order)
|
||||
} else {
|
||||
orderBy = append(orderBy, "`created_ts` "+order)
|
||||
}
|
||||
// Add id as final tie-breaker
|
||||
orderBy = append(orderBy, "`id` DESC")
|
||||
fields := []string{
|
||||
"`memo`.`id` AS `id`",
|
||||
"`memo`.`uid` AS `uid`",
|
||||
"`memo`.`creator_id` AS `creator_id`",
|
||||
"UNIX_TIMESTAMP(`memo`.`created_ts`) AS `created_ts`",
|
||||
"UNIX_TIMESTAMP(`memo`.`updated_ts`) AS `updated_ts`",
|
||||
"`memo`.`row_status` AS `row_status`",
|
||||
"`memo`.`visibility` AS `visibility`",
|
||||
"`memo`.`pinned` AS `pinned`",
|
||||
"`memo`.`payload` AS `payload`",
|
||||
"CASE WHEN `parent_memo`.`uid` IS NOT NULL THEN `parent_memo`.`uid` ELSE NULL END AS `parent_uid`",
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
fields = append(fields, "`memo`.`content` AS `content`")
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `memo`" + " " +
|
||||
"LEFT JOIN `memo_relation` ON `memo`.`id` = `memo_relation`.`memo_id` AND `memo_relation`.`type` = 'COMMENT'" + " " +
|
||||
"LEFT JOIN `memo` AS `parent_memo` ON `memo_relation`.`related_memo_id` = `parent_memo`.`id`" + " " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
"HAVING " + strings.Join(having, " AND ") + " " +
|
||||
"ORDER BY " + strings.Join(orderBy, ", ")
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Memo, 0)
|
||||
for rows.Next() {
|
||||
var memo store.Memo
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&memo.ID,
|
||||
&memo.UID,
|
||||
&memo.CreatorID,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&memo.RowStatus,
|
||||
&memo.Visibility,
|
||||
&memo.Pinned,
|
||||
&payloadBytes,
|
||||
&memo.ParentUID,
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
dests = append(dests, &memo.Content)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := &storepb.MemoPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal payload")
|
||||
}
|
||||
memo.Payload = payload
|
||||
list = append(list, &memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) {
|
||||
list, err := d.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memo := list[0]
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.CreatedTs; v != nil {
|
||||
set, args = append(set, "`created_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Content; v != nil {
|
||||
set, args = append(set, "`content` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Visibility; v != nil {
|
||||
set, args = append(set, "`visibility` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Pinned; v != nil {
|
||||
set, args = append(set, "`pinned` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
payloadBytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(payloadBytes))
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := "UPDATE `memo` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
where, args := []string{"`id` = ?"}, []any{delete.ID}
|
||||
stmt := "DELETE FROM `memo` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
110
store/db/mysql/memo_relation.go
Normal file
110
store/db/mysql/memo_relation.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `type` = `type`"
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.MemoID,
|
||||
create.RelatedMemoID,
|
||||
create.Type,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memoRelation := store.MemoRelation{
|
||||
MemoID: create.MemoID,
|
||||
RelatedMemoID: create.RelatedMemoID,
|
||||
Type: create.Type,
|
||||
}
|
||||
|
||||
return &memoRelation, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if find.MemoID != nil {
|
||||
where, args = append(where, "`memo_id` = ?"), append(args, find.MemoID)
|
||||
}
|
||||
if find.RelatedMemoID != nil {
|
||||
where, args = append(where, "`related_memo_id` = ?"), append(args, find.RelatedMemoID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, find.Type)
|
||||
}
|
||||
if find.MemoFilter != nil {
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{
|
||||
Dialect: filter.DialectMySQL,
|
||||
PlaceholderOffset: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if stmt.SQL != "" {
|
||||
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
|
||||
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
|
||||
args = append(args, append(stmt.Args, stmt.Args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, "SELECT `memo_id`, `related_memo_id`, `type` FROM `memo_relation` WHERE "+strings.Join(where, " AND "), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.MemoRelation{}
|
||||
for rows.Next() {
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := rows.Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, memoRelation)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
where, args = append(where, "`memo_id` = ?"), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
where, args = append(where, "`related_memo_id` = ?"), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, delete.Type)
|
||||
}
|
||||
stmt := "DELETE FROM `memo_relation` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
68
store/db/mysql/mysql.go
Normal file
68
store/db/mysql/mysql.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
config *mysql.Config
|
||||
}
|
||||
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
// Open MySQL connection with parameter.
|
||||
// multiStatements=true is required for migration.
|
||||
// See more in: https://github.com/go-sql-driver/mysql#multistatements
|
||||
dsn, err := mergeDSN(profile.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
driver := DB{profile: profile}
|
||||
driver.config, err = mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, errors.New("Parse DSN eroor")
|
||||
}
|
||||
|
||||
driver.db, err = sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to open db: %s", profile.DSN)
|
||||
}
|
||||
|
||||
return &driver, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
var exists bool
|
||||
err := d.db.QueryRowContext(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'memo' AND TABLE_TYPE = 'BASE TABLE')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func mergeDSN(baseDSN string) (string, error) {
|
||||
config, err := mysql.ParseDSN(baseDSN)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to parse DSN: %s", baseDSN)
|
||||
}
|
||||
|
||||
config.MultiStatements = true
|
||||
return config.FormatDSN(), nil
|
||||
}
|
||||
113
store/db/mysql/reaction.go
Normal file
113
store/db/mysql/reaction.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store.Reaction, error) {
|
||||
fields := []string{"`creator_id`", "`content_id`", "`reaction_type`"}
|
||||
placeholder := []string{"?", "?", "?"}
|
||||
args := []interface{}{upsert.CreatorID, upsert.ContentID, upsert.ReactionType}
|
||||
stmt := "INSERT INTO `reaction` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := int32(rawID)
|
||||
reaction, err := d.GetReaction(ctx, &store.FindReaction{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reaction == nil {
|
||||
return nil, errors.Errorf("failed to create reaction")
|
||||
}
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
where, args = append(where, "`creator_id` = ?"), append(args, *find.CreatorID)
|
||||
}
|
||||
if find.ContentID != nil {
|
||||
where, args = append(where, "`content_id` = ?"), append(args, *find.ContentID)
|
||||
}
|
||||
if len(find.ContentIDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.ContentIDList))
|
||||
for _, id := range find.ContentIDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, id)
|
||||
}
|
||||
where = append(where, "`content_id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
UNIX_TIMESTAMP(created_ts) AS created_ts,
|
||||
creator_id,
|
||||
content_id,
|
||||
reaction_type
|
||||
FROM reaction
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Reaction{}
|
||||
for rows.Next() {
|
||||
reaction := &store.Reaction{}
|
||||
if err := rows.Scan(
|
||||
&reaction.ID,
|
||||
&reaction.CreatedTs,
|
||||
&reaction.CreatorID,
|
||||
&reaction.ContentID,
|
||||
&reaction.ReactionType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, reaction)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) {
|
||||
list, err := d.ListReactions(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reaction := list[0]
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM `reaction` WHERE `id` = ?", delete.ID)
|
||||
return err
|
||||
}
|
||||
166
store/db/mysql/user.go
Normal file
166
store/db/mysql/user.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
fields := []string{"`username`", "`role`", "`email`", "`nickname`", "`password_hash`", "`avatar_url`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?"}
|
||||
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
|
||||
stmt := "INSERT INTO user (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
list, err := d.ListUsers(ctx, &store.FindUser{ID: &id32})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected user count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "`username` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "`email` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "`nickname` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "`avatar_url` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "`password_hash` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Description; v != nil {
|
||||
set, args = append(set, "`description` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "`role` = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
query := "UPDATE `user` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
return nil, errors.Errorf("user filters are not supported")
|
||||
}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
where, args = append(where, "`username` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "`role` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "`email` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "`nickname` = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
orderBy := []string{"`created_ts` DESC", "`row_status` DESC"}
|
||||
query := "SELECT `id`, `username`, `role`, `email`, `nickname`, `password_hash`, `avatar_url`, `description`, UNIX_TIMESTAMP(`created_ts`), UNIX_TIMESTAMP(`updated_ts`), `row_status` FROM `user` WHERE " + strings.Join(where, " AND ") + " ORDER BY " + strings.Join(orderBy, ", ")
|
||||
if v := find.Limit; v != nil {
|
||||
query += fmt.Sprintf(" LIMIT %d", *v)
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.User, 0)
|
||||
for rows.Next() {
|
||||
var user store.User
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
|
||||
list, err := d.ListUsers(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected user count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM `user` WHERE `id` = ?", delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
93
store/db/mysql/user_setting.go
Normal file
93
store/db/mysql/user_setting.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
|
||||
stmt := "INSERT INTO `user_setting` (`user_id`, `key`, `value`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `value` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value, upsert.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "`key` = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "`user_id` = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := "SELECT `user_id`, `key`, `value` FROM `user_setting` WHERE " + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*store.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &store.UserSetting{}
|
||||
var keyString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserID,
|
||||
&keyString,
|
||||
&userSetting.Value,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) {
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + "`key`" + ` = 'PERSONAL_ACCESS_TOKENS'
|
||||
AND JSON_SEARCH(value, 'one', ?, NULL, '$.tokens[*].tokenHash') IS NOT NULL
|
||||
`
|
||||
|
||||
var userID int32
|
||||
var tokensJSON string
|
||||
|
||||
err := d.db.QueryRowContext(ctx, query, tokenHash).Scan(&userID, &tokensJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
patsUserSetting := &storepb.PersonalAccessTokensUserSetting{}
|
||||
if err := protojsonUnmarshaler.Unmarshal([]byte(tokensJSON), patsUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, pat := range patsUserSetting.Tokens {
|
||||
if pat.TokenHash == tokenHash {
|
||||
return &store.PATQueryResult{
|
||||
UserID: userID,
|
||||
PAT: pat,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("PAT not found")
|
||||
}
|
||||
81
store/db/postgres/activity.go
Normal file
81
store/db/postgres/activity.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal activity payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"creator_id", "type", "level", "payload"}
|
||||
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
|
||||
}
|
||||
|
||||
query := "SELECT id, creator_id, type, level, payload, created_ts FROM activity WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Activity{}
|
||||
for rows.Next() {
|
||||
activity := &store.Activity{}
|
||||
var payloadBytes []byte
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&payloadBytes,
|
||||
&activity.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := &storepb.ActivityPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
activity.Payload = payload
|
||||
list = append(list, activity)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
221
store/db/postgres/attachment.go
Normal file
221
store/db/postgres/attachment.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*store.Attachment, error) {
|
||||
fields := []string{"uid", "filename", "blob", "type", "size", "creator_id", "memo_id", "storage_type", "reference", "payload"}
|
||||
storageType := ""
|
||||
if create.StorageType != storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED {
|
||||
storageType = create.StorageType.String()
|
||||
}
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO attachment (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "attachment.id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "attachment.uid = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "attachment.creator_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "attachment.filename = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.FilenameSearch; v != nil {
|
||||
where, args = append(where, "attachment.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "attachment.memo_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
holders := make([]string, 0, len(find.MemoIDList))
|
||||
for _, id := range find.MemoIDList {
|
||||
holders = append(holders, placeholder(len(args)+1))
|
||||
args = append(args, id)
|
||||
}
|
||||
where = append(where, "attachment.memo_id IN ("+strings.Join(holders, ", ")+")")
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "attachment.memo_id IS NOT NULL")
|
||||
}
|
||||
if v := find.StorageType; v != nil {
|
||||
where, args = append(where, "attachment.storage_type = "+placeholder(len(args)+1)), append(args, v.String())
|
||||
}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
engine, err := filter.DefaultAttachmentEngine()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get filter engine")
|
||||
}
|
||||
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectPostgres, &where, &args); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to append filter conditions")
|
||||
}
|
||||
}
|
||||
|
||||
fields := []string{
|
||||
"attachment.id AS id",
|
||||
"attachment.uid AS uid",
|
||||
"attachment.filename AS filename",
|
||||
"attachment.type AS type",
|
||||
"attachment.size AS size",
|
||||
"attachment.creator_id AS creator_id",
|
||||
"attachment.created_ts AS created_ts",
|
||||
"attachment.updated_ts AS updated_ts",
|
||||
"attachment.memo_id AS memo_id",
|
||||
"attachment.storage_type AS storage_type",
|
||||
"attachment.reference AS reference",
|
||||
"attachment.payload AS payload",
|
||||
"CASE WHEN memo.uid IS NOT NULL THEN memo.uid ELSE NULL END AS memo_uid",
|
||||
}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "attachment.blob AS blob")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
%s
|
||||
FROM attachment
|
||||
LEFT JOIN memo ON attachment.memo_id = memo.id
|
||||
WHERE %s
|
||||
ORDER BY attachment.updated_ts DESC
|
||||
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Attachment, 0)
|
||||
for rows.Next() {
|
||||
attachment := store.Attachment{}
|
||||
var memoID sql.NullInt32
|
||||
var storageType string
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&attachment.ID,
|
||||
&attachment.UID,
|
||||
&attachment.Filename,
|
||||
&attachment.Type,
|
||||
&attachment.Size,
|
||||
&attachment.CreatorID,
|
||||
&attachment.CreatedTs,
|
||||
&attachment.UpdatedTs,
|
||||
&memoID,
|
||||
&storageType,
|
||||
&attachment.Reference,
|
||||
&payloadBytes,
|
||||
&attachment.MemoUID,
|
||||
}
|
||||
if find.GetBlob {
|
||||
dests = append(dests, &attachment.Blob)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if memoID.Valid {
|
||||
attachment.MemoID = &memoID.Int32
|
||||
}
|
||||
attachment.StorageType = storepb.AttachmentStorageType(storepb.AttachmentStorageType_value[storageType])
|
||||
payload := &storepb.AttachmentPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attachment.Payload = payload
|
||||
list = append(list, &attachment)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachment) error {
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "uid = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
set, args = append(set, "filename = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
set, args = append(set, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Reference; v != nil {
|
||||
set, args = append(set, "reference = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
bytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(bytes))
|
||||
}
|
||||
|
||||
stmt := `UPDATE attachment SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
|
||||
args = append(args, update.ID)
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := `DELETE FROM attachment WHERE id = $1`
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
26
store/db/postgres/common.go
Normal file
26
store/db/postgres/common.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
var (
|
||||
protojsonUnmarshaler = protojson.UnmarshalOptions{
|
||||
DiscardUnknown: true,
|
||||
}
|
||||
)
|
||||
|
||||
func placeholder(n int) string {
|
||||
return "$" + fmt.Sprint(n)
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
list := []string{}
|
||||
for i := 0; i < n; i++ {
|
||||
list = append(list, placeholder(i+1))
|
||||
}
|
||||
return strings.Join(list, ", ")
|
||||
}
|
||||
117
store/db/postgres/idp.go
Normal file
117
store/db/postgres/idp.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
fields := []string{"name", "type", "identifier_filter", "config"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider := create
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
config
|
||||
FROM idp
|
||||
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var identityProviders []*store.IdentityProvider
|
||||
for rows.Next() {
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProvider.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
|
||||
identityProviders = append(identityProviders, &identityProvider)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identityProviders, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.Name; v != nil {
|
||||
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
stmt := `
|
||||
UPDATE idp
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, name, type, identifier_filter, config
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProvider.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
|
||||
return &identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
where, args := []string{"id = $1"}, []any{delete.ID}
|
||||
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
151
store/db/postgres/inbox.go
Normal file
151
store/db/postgres/inbox.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
|
||||
messageString := "{}"
|
||||
if create.Message != nil {
|
||||
bytes, err := protojson.Marshal(create.Message)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal inbox message")
|
||||
}
|
||||
messageString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"sender_id", "receiver_id", "status", "message"}
|
||||
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
|
||||
stmt := "INSERT INTO inbox (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
if find.SenderID != nil {
|
||||
where, args = append(where, "sender_id = "+placeholder(len(args)+1)), append(args, *find.SenderID)
|
||||
}
|
||||
if find.ReceiverID != nil {
|
||||
where, args = append(where, "receiver_id = "+placeholder(len(args)+1)), append(args, *find.ReceiverID)
|
||||
}
|
||||
if find.Status != nil {
|
||||
where, args = append(where, "status = "+placeholder(len(args)+1)), append(args, *find.Status)
|
||||
}
|
||||
if find.MessageType != nil {
|
||||
// Filter by message type using PostgreSQL JSON extraction
|
||||
// Note: The type field in JSON is stored as string representation of the enum name
|
||||
// Cast to JSONB since the column is TEXT
|
||||
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
|
||||
where, args = append(where, "(message::JSONB->>'type' IS NULL OR message::JSONB->>'type' = "+placeholder(len(args)+1)+")"), append(args, find.MessageType.String())
|
||||
} else {
|
||||
where, args = append(where, "message::JSONB->>'type' = "+placeholder(len(args)+1)), append(args, find.MessageType.String())
|
||||
}
|
||||
}
|
||||
|
||||
query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Inbox{}
|
||||
for rows.Next() {
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := rows.Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
list = append(list, inbox)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
|
||||
list, err := d.ListInboxes(ctx, find)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get inbox")
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected inbox count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
set, args := []string{"status = $1"}, []any{update.Status.String()}
|
||||
args = append(args, update.ID)
|
||||
query := "UPDATE inbox SET " + strings.Join(set, ", ") + " WHERE id = $2 RETURNING id, created_ts, sender_id, receiver_id, status, message"
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM inbox WHERE id = $1", delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
72
store/db/postgres/instance_setting.go
Normal file
72
store/db/postgres/instance_setting.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertInstanceSetting(ctx context.Context, upsert *store.InstanceSetting) (*store.InstanceSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO system_setting (
|
||||
name, value, description
|
||||
)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT(name) DO UPDATE
|
||||
SET
|
||||
value = EXCLUDED.value,
|
||||
description = EXCLUDED.description
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInstanceSettings(ctx context.Context, find *store.FindInstanceSetting) ([]*store.InstanceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Name != "" {
|
||||
where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, find.Name)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
name,
|
||||
value,
|
||||
description
|
||||
FROM system_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.InstanceSetting{}
|
||||
for rows.Next() {
|
||||
systemSettingMessage := &store.InstanceSetting{}
|
||||
if err := rows.Scan(
|
||||
&systemSettingMessage.Name,
|
||||
&systemSettingMessage.Value,
|
||||
&systemSettingMessage.Description,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSettingMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInstanceSetting(ctx context.Context, delete *store.DeleteInstanceSetting) error {
|
||||
stmt := `DELETE FROM system_setting WHERE name = $1`
|
||||
_, err := d.db.ExecContext(ctx, stmt, delete.Name)
|
||||
return err
|
||||
}
|
||||
254
store/db/postgres/memo.go
Normal file
254
store/db/postgres/memo.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
fields := []string{"uid", "creator_id", "content", "visibility", "payload"}
|
||||
payload := "{}"
|
||||
if create.Payload != nil {
|
||||
payloadBytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = string(payloadBytes)
|
||||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
// Add custom timestamps if provided
|
||||
if create.CreatedTs != 0 {
|
||||
fields = append(fields, "created_ts")
|
||||
args = append(args, create.CreatedTs)
|
||||
}
|
||||
if create.UpdatedTs != 0 {
|
||||
fields = append(fields, "updated_ts")
|
||||
args = append(args, create.UpdatedTs)
|
||||
}
|
||||
|
||||
stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectPostgres, &where, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if len(find.IDList) > 0 {
|
||||
holders := make([]string, 0, len(find.IDList))
|
||||
for _, id := range find.IDList {
|
||||
holders = append(holders, placeholder(len(args)+1))
|
||||
args = append(args, id)
|
||||
}
|
||||
where = append(where, "memo.id IN ("+strings.Join(holders, ", ")+")")
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "memo.uid = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if len(find.UIDList) > 0 {
|
||||
holders := make([]string, 0, len(find.UIDList))
|
||||
for _, uid := range find.UIDList {
|
||||
holders = append(holders, placeholder(len(args)+1))
|
||||
args = append(args, uid)
|
||||
}
|
||||
where = append(where, "memo.uid IN ("+strings.Join(holders, ", ")+")")
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "memo.row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
holders := []string{}
|
||||
for _, visibility := range v {
|
||||
holders = append(holders, placeholder(len(args)+1))
|
||||
args = append(args, visibility.String())
|
||||
}
|
||||
where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(holders, ", ")))
|
||||
}
|
||||
if find.ExcludeComments {
|
||||
where = append(where, "memo_relation.related_memo_id IS NULL")
|
||||
}
|
||||
|
||||
order := "DESC"
|
||||
if find.OrderByTimeAsc {
|
||||
order = "ASC"
|
||||
}
|
||||
orderBy := []string{}
|
||||
if find.OrderByPinned {
|
||||
orderBy = append(orderBy, "pinned DESC")
|
||||
}
|
||||
if find.OrderByUpdatedTs {
|
||||
orderBy = append(orderBy, "updated_ts "+order)
|
||||
} else {
|
||||
orderBy = append(orderBy, "created_ts "+order)
|
||||
}
|
||||
// Add id as final tie-breaker
|
||||
orderBy = append(orderBy, "id DESC")
|
||||
fields := []string{
|
||||
`memo.id AS id`,
|
||||
`memo.uid AS uid`,
|
||||
`memo.creator_id AS creator_id`,
|
||||
`memo.created_ts AS created_ts`,
|
||||
`memo.updated_ts AS updated_ts`,
|
||||
`memo.row_status AS row_status`,
|
||||
`memo.visibility AS visibility`,
|
||||
`memo.pinned AS pinned`,
|
||||
`memo.payload AS payload`,
|
||||
`CASE WHEN parent_memo.uid IS NOT NULL THEN parent_memo.uid ELSE NULL END AS parent_uid`,
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
fields = append(fields, `memo.content AS content`)
|
||||
}
|
||||
|
||||
query := `SELECT ` + strings.Join(fields, ", ") + `
|
||||
FROM memo
|
||||
LEFT JOIN memo_relation ON memo.id = memo_relation.memo_id AND memo_relation.type = 'COMMENT'
|
||||
LEFT JOIN memo AS parent_memo ON memo_relation.related_memo_id = parent_memo.id
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY ` + strings.Join(orderBy, ", ")
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Memo, 0)
|
||||
for rows.Next() {
|
||||
var memo store.Memo
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&memo.ID,
|
||||
&memo.UID,
|
||||
&memo.CreatorID,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&memo.RowStatus,
|
||||
&memo.Visibility,
|
||||
&memo.Pinned,
|
||||
&payloadBytes,
|
||||
&memo.ParentUID,
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
dests = append(dests, &memo.Content)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := &storepb.MemoPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal payload")
|
||||
}
|
||||
memo.Payload = payload
|
||||
list = append(list, &memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) {
|
||||
list, err := d.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memo := list[0]
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "uid = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.CreatedTs; v != nil {
|
||||
set, args = append(set, "created_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Content; v != nil {
|
||||
set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Visibility; v != nil {
|
||||
set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Pinned; v != nil {
|
||||
set, args = append(set, "pinned = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
payloadBytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(payloadBytes))
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := `UPDATE memo SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
|
||||
args = append(args, update.ID)
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
where, args := []string{"id = " + placeholder(1)}, []any{delete.ID}
|
||||
stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to delete memo")
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
134
store/db/postgres/memo_relation.go
Normal file
134
store/db/postgres/memo_relation.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
stmt := `
|
||||
INSERT INTO memo_relation (
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
)
|
||||
VALUES (` + placeholders(3) + `)
|
||||
ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET type = EXCLUDED.type
|
||||
RETURNING memo_id, related_memo_id, type
|
||||
`
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := d.db.QueryRowContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.MemoID,
|
||||
create.RelatedMemoID,
|
||||
create.Type,
|
||||
).Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return memoRelation, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.MemoID != nil {
|
||||
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
|
||||
}
|
||||
if find.RelatedMemoID != nil {
|
||||
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
|
||||
}
|
||||
if find.MemoFilter != nil {
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{
|
||||
Dialect: filter.DialectPostgres,
|
||||
PlaceholderOffset: len(args),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if stmt.SQL != "" {
|
||||
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
|
||||
args = append(args, stmt.Args...)
|
||||
|
||||
stmtRelated, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{
|
||||
Dialect: filter.DialectPostgres,
|
||||
PlaceholderOffset: len(args),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if stmtRelated.SQL != "" {
|
||||
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmtRelated.SQL))
|
||||
args = append(args, stmtRelated.Args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
FROM memo_relation
|
||||
WHERE `+strings.Join(where, " AND "), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.MemoRelation{}
|
||||
for rows.Next() {
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := rows.Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, memoRelation)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type)
|
||||
}
|
||||
stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
57
store/db/postgres/postgres.go
Normal file
57
store/db/postgres/postgres.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log"
|
||||
|
||||
// Import the PostgreSQL driver.
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
}
|
||||
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
if profile == nil {
|
||||
return nil, errors.New("profile is nil")
|
||||
}
|
||||
|
||||
// Open the PostgreSQL connection
|
||||
db, err := sql.Open("postgres", profile.DSN)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open database: %s", err)
|
||||
return nil, errors.Wrapf(err, "failed to open database: %s", profile.DSN)
|
||||
}
|
||||
|
||||
var driver store.Driver = &DB{
|
||||
db: db,
|
||||
profile: profile,
|
||||
}
|
||||
|
||||
// Return the DB struct
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
var exists bool
|
||||
err := d.db.QueryRowContext(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_catalog = current_database() AND table_name = 'memo' AND table_type = 'BASE TABLE')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
101
store/db/postgres/reaction.go
Normal file
101
store/db/postgres/reaction.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store.Reaction, error) {
|
||||
fields := []string{"creator_id", "content_id", "reaction_type"}
|
||||
args := []interface{}{upsert.CreatorID, upsert.ContentID, upsert.ReactionType}
|
||||
stmt := "INSERT INTO reaction (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&upsert.ID,
|
||||
&upsert.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reaction := upsert
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *find.CreatorID)
|
||||
}
|
||||
if find.ContentID != nil {
|
||||
where, args = append(where, "content_id = "+placeholder(len(args)+1)), append(args, *find.ContentID)
|
||||
}
|
||||
if len(find.ContentIDList) > 0 {
|
||||
holders := make([]string, 0, len(find.ContentIDList))
|
||||
for _, id := range find.ContentIDList {
|
||||
holders = append(holders, placeholder(len(args)+1))
|
||||
args = append(args, id)
|
||||
}
|
||||
where = append(where, "content_id IN ("+strings.Join(holders, ", ")+")")
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
created_ts,
|
||||
creator_id,
|
||||
content_id,
|
||||
reaction_type
|
||||
FROM reaction
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Reaction{}
|
||||
for rows.Next() {
|
||||
reaction := &store.Reaction{}
|
||||
if err := rows.Scan(
|
||||
&reaction.ID,
|
||||
&reaction.CreatedTs,
|
||||
&reaction.CreatorID,
|
||||
&reaction.ContentID,
|
||||
&reaction.ReactionType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, reaction)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) {
|
||||
list, err := d.ListReactions(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reaction := list[0]
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM reaction WHERE id = $1", delete.ID)
|
||||
return err
|
||||
}
|
||||
172
store/db/postgres/user.go
Normal file
172
store/db/postgres/user.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
fields := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
|
||||
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
stmt := "INSERT INTO \"user\" (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, description, created_ts, updated_ts, row_status"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.Description,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "username = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "avatar_url = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Description; v != nil {
|
||||
set, args = append(set, "description = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "role = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
query := `
|
||||
UPDATE "user"
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, username, role, email, nickname, password_hash, avatar_url, description, created_ts, updated_ts, row_status
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
user := &store.User{}
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
return nil, errors.Errorf("user filters are not supported")
|
||||
}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
where, args = append(where, "username = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
orderBy := []string{"created_ts DESC", "row_status DESC"}
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
avatar_url,
|
||||
description,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status
|
||||
FROM "user"
|
||||
WHERE ` + strings.Join(where, " AND ") + ` ORDER BY ` + strings.Join(orderBy, ", ")
|
||||
if v := find.Limit; v != nil {
|
||||
query += fmt.Sprintf(" LIMIT %d", *v)
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.User, 0)
|
||||
for rows.Next() {
|
||||
var user store.User
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
result, err := d.db.ExecContext(ctx, `DELETE FROM "user" WHERE id = $1`, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
121
store/db/postgres/user_setting.go
Normal file
121
store/db/postgres/user_setting.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO user_setting (
|
||||
user_id, key, value
|
||||
)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT(user_id, key) DO UPDATE
|
||||
SET value = EXCLUDED.value
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
key,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*store.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &store.UserSetting{}
|
||||
var keyString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserID,
|
||||
&keyString,
|
||||
&userSetting.Value,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) {
|
||||
// Simplified query: fetch all PERSONAL_ACCESS_TOKENS rows and search in Go
|
||||
// This matches SQLite/MySQL behavior and avoids PostgreSQL's strict JSONB errors
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE key = 'PERSONAL_ACCESS_TOKENS'
|
||||
`
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Iterate through all users with PAT settings
|
||||
for rows.Next() {
|
||||
var userID int32
|
||||
var tokensJSON string
|
||||
|
||||
if err := rows.Scan(&userID, &tokensJSON); err != nil {
|
||||
continue // Skip malformed rows
|
||||
}
|
||||
|
||||
// Try to unmarshal - skip if invalid JSON
|
||||
patsUserSetting := &storepb.PersonalAccessTokensUserSetting{}
|
||||
if err := protojsonUnmarshaler.Unmarshal([]byte(tokensJSON), patsUserSetting); err != nil {
|
||||
continue // Skip invalid JSON
|
||||
}
|
||||
|
||||
// Search for matching token hash
|
||||
for _, pat := range patsUserSetting.Tokens {
|
||||
if pat.TokenHash == tokenHash {
|
||||
return &store.PATQueryResult{
|
||||
UserID: userID,
|
||||
PAT: pat,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, errors.New("PAT not found")
|
||||
}
|
||||
290
store/db/postgres/user_setting_test.go
Normal file
290
store/db/postgres/user_setting_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// TestGetUserByPATHashWithMissingData tests the fix for #5611 and #5612.
|
||||
// Verifies that GetUserByPATHash handles missing/malformed data gracefully
|
||||
// instead of throwing PostgreSQL JSONB errors.
|
||||
func TestGetUserByPATHashWithMissingData(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping PostgreSQL integration test in short mode")
|
||||
}
|
||||
|
||||
// This test requires a real PostgreSQL connection
|
||||
// If DSN is not provided, skip the test
|
||||
dsn := getTestDSN()
|
||||
if dsn == "" {
|
||||
t.Skip("PostgreSQL DSN not provided, skipping test")
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Create test database
|
||||
ctx := context.Background()
|
||||
driver := &DB{db: db}
|
||||
|
||||
// Setup: Create user_setting table if needed
|
||||
_, err = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS user_setting (
|
||||
user_id INTEGER NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE(user_id, key)
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanup
|
||||
defer func() {
|
||||
db.ExecContext(ctx, "DELETE FROM user_setting WHERE user_id IN (1001, 1002, 1003)")
|
||||
}()
|
||||
|
||||
t.Run("NoTokensKeyAtAll", func(t *testing.T) {
|
||||
// Test case: User has no PERSONAL_ACCESS_TOKENS key
|
||||
// This simulates fresh users or users upgraded from v0.25.3
|
||||
result, err := driver.GetUserByPATHash(ctx, "any-hash")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "PAT not found")
|
||||
})
|
||||
|
||||
t.Run("EmptyTokensArray", func(t *testing.T) {
|
||||
// Insert user with empty tokens array
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO user_setting (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
`, 1001, "PERSONAL_ACCESS_TOKENS", `{"tokens":[]}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := driver.GetUserByPATHash(ctx, "any-hash")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "PAT not found")
|
||||
})
|
||||
|
||||
t.Run("MalformedJSON", func(t *testing.T) {
|
||||
// Insert user with malformed JSON
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO user_setting (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
`, 1002, "PERSONAL_ACCESS_TOKENS", `{invalid json}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should handle gracefully without crashing
|
||||
result, err := driver.GetUserByPATHash(ctx, "any-hash")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "PAT not found")
|
||||
})
|
||||
|
||||
t.Run("MissingTokensField", func(t *testing.T) {
|
||||
// Insert user with valid JSON but missing 'tokens' field
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO user_setting (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
`, 1003, "PERSONAL_ACCESS_TOKENS", `{"someOtherField":"value"}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should handle gracefully
|
||||
result, err := driver.GetUserByPATHash(ctx, "any-hash")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("ValidTokenFound", func(t *testing.T) {
|
||||
// Insert user with valid PAT
|
||||
validJSON := `{
|
||||
"tokens": [
|
||||
{
|
||||
"tokenId": "pat-test",
|
||||
"tokenHash": "hash-test-123",
|
||||
"description": "Test PAT"
|
||||
}
|
||||
]
|
||||
}`
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO user_setting (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
`, 1001, "PERSONAL_ACCESS_TOKENS", validJSON)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should find the token
|
||||
result, err := driver.GetUserByPATHash(ctx, "hash-test-123")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int32(1001), result.UserID)
|
||||
assert.Equal(t, "pat-test", result.PAT.TokenId)
|
||||
assert.Equal(t, "hash-test-123", result.PAT.TokenHash)
|
||||
})
|
||||
|
||||
t.Run("MultipleUsersWithMixedData", func(t *testing.T) {
|
||||
// User 1001: Valid PAT
|
||||
validJSON := `{
|
||||
"tokens": [
|
||||
{
|
||||
"tokenId": "pat-user1",
|
||||
"tokenHash": "hash-user1",
|
||||
"description": "User 1 PAT"
|
||||
}
|
||||
]
|
||||
}`
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO user_setting (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
`, 1001, "PERSONAL_ACCESS_TOKENS", validJSON)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User 1002: Malformed JSON (should be skipped)
|
||||
_, err = db.ExecContext(ctx, `
|
||||
INSERT INTO user_setting (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
`, 1002, "PERSONAL_ACCESS_TOKENS", `{invalid}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User 1003: Empty array (should be skipped)
|
||||
_, err = db.ExecContext(ctx, `
|
||||
INSERT INTO user_setting (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
`, 1003, "PERSONAL_ACCESS_TOKENS", `{"tokens":[]}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should still find user 1001's token despite other users having bad data
|
||||
result, err := driver.GetUserByPATHash(ctx, "hash-user1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int32(1001), result.UserID)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetUserByPATHashPerformance ensures the simplified query doesn't cause performance issues.
|
||||
func TestGetUserByPATHashPerformance(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping performance test in short mode")
|
||||
}
|
||||
|
||||
dsn := getTestDSN()
|
||||
if dsn == "" {
|
||||
t.Skip("PostgreSQL DSN not provided, skipping test")
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
driver := &DB{db: db}
|
||||
|
||||
// Setup table
|
||||
_, err = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS user_setting (
|
||||
user_id INTEGER NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE(user_id, key)
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanup
|
||||
defer func() {
|
||||
db.ExecContext(ctx, "DELETE FROM user_setting WHERE user_id >= 2000 AND user_id < 2100")
|
||||
}()
|
||||
|
||||
// Insert 100 users with PATs
|
||||
for i := 2000; i < 2100; i++ {
|
||||
json := `{
|
||||
"tokens": [
|
||||
{
|
||||
"tokenId": "pat-` + string(rune(i)) + `",
|
||||
"tokenHash": "hash-` + string(rune(i)) + `",
|
||||
"description": "Test PAT"
|
||||
}
|
||||
]
|
||||
}`
|
||||
_, err = db.ExecContext(ctx, `
|
||||
INSERT INTO user_setting (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
`, i, "PERSONAL_ACCESS_TOKENS", json)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Query should complete quickly even with 100 users
|
||||
result, err := driver.GetUserByPATHash(ctx, "hash-"+string(rune(2050)))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int32(2050), result.UserID)
|
||||
}
|
||||
|
||||
// getTestDSN returns PostgreSQL DSN from environment or returns empty string.
|
||||
func getTestDSN() string {
|
||||
// For unit tests, we expect TEST_POSTGRES_DSN to be set.
|
||||
// Example: TEST_POSTGRES_DSN="postgresql://user:pass@localhost:5432/memos_test?sslmode=disable".
|
||||
return ""
|
||||
}
|
||||
|
||||
// TestUpsertUserSetting tests basic upsert functionality.
|
||||
func TestUpsertUserSetting(t *testing.T) {
|
||||
dsn := getTestDSN()
|
||||
if dsn == "" {
|
||||
t.Skip("PostgreSQL DSN not provided, skipping test")
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
driver := &DB{db: db}
|
||||
|
||||
// Setup
|
||||
_, err = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS user_setting (
|
||||
user_id INTEGER NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE(user_id, key)
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
db.ExecContext(ctx, "DELETE FROM user_setting WHERE user_id = 9999")
|
||||
}()
|
||||
|
||||
// Test insert
|
||||
setting := &store.UserSetting{
|
||||
UserID: 9999,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
Value: `{"locale":"en"}`,
|
||||
}
|
||||
result, err := driver.UpsertUserSetting(ctx, setting)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int32(9999), result.UserID)
|
||||
|
||||
// Test update (upsert on conflict)
|
||||
setting.Value = `{"locale":"zh"}`
|
||||
result, err = driver.UpsertUserSetting(ctx, setting)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"locale":"zh"}`, result.Value)
|
||||
}
|
||||
83
store/db/sqlite/activity.go
Normal file
83
store/db/sqlite/activity.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal activity payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"`creator_id`", "`type`", "`level`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
|
||||
stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, find.Type.String())
|
||||
}
|
||||
|
||||
query := "SELECT `id`, `creator_id`, `type`, `level`, `payload`, `created_ts` FROM `activity` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Activity{}
|
||||
for rows.Next() {
|
||||
activity := &store.Activity{}
|
||||
var payloadBytes []byte
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&payloadBytes,
|
||||
&activity.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := &storepb.ActivityPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
activity.Payload = payload
|
||||
list = append(list, activity)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
221
store/db/sqlite/attachment.go
Normal file
221
store/db/sqlite/attachment.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*store.Attachment, error) {
|
||||
fields := []string{"`uid`", "`filename`", "`blob`", "`type`", "`size`", "`creator_id`", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?", "?", "?", "?", "?"}
|
||||
storageType := ""
|
||||
if create.StorageType != storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED {
|
||||
storageType = create.StorageType.String()
|
||||
}
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO `attachment` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`attachment`.`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`attachment`.`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`attachment`.`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "`attachment`.`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.FilenameSearch; v != nil {
|
||||
where, args = append(where, "`attachment`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v))
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "`attachment`.`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.MemoIDList))
|
||||
for range find.MemoIDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
where = append(where, "`attachment`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, id := range find.MemoIDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "`attachment`.`memo_id` IS NOT NULL")
|
||||
}
|
||||
if find.StorageType != nil {
|
||||
where, args = append(where, "`attachment`.`storage_type` = ?"), append(args, find.StorageType.String())
|
||||
}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
engine, err := filter.DefaultAttachmentEngine()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get filter engine")
|
||||
}
|
||||
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectSQLite, &where, &args); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to append filter conditions")
|
||||
}
|
||||
}
|
||||
|
||||
fields := []string{
|
||||
"`attachment`.`id` AS `id`",
|
||||
"`attachment`.`uid` AS `uid`",
|
||||
"`attachment`.`filename` AS `filename`",
|
||||
"`attachment`.`type` AS `type`",
|
||||
"`attachment`.`size` AS `size`",
|
||||
"`attachment`.`creator_id` AS `creator_id`",
|
||||
"`attachment`.`created_ts` AS `created_ts`",
|
||||
"`attachment`.`updated_ts` AS `updated_ts`",
|
||||
"`attachment`.`memo_id` AS `memo_id`",
|
||||
"`attachment`.`storage_type` AS `storage_type`",
|
||||
"`attachment`.`reference` AS `reference`",
|
||||
"`attachment`.`payload` AS `payload`",
|
||||
"CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`",
|
||||
}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "`attachment`.`blob` AS `blob`")
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `attachment`" + " " +
|
||||
"LEFT JOIN `memo` ON `attachment`.`memo_id` = `memo`.`id`" + " " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
"ORDER BY `attachment`.`updated_ts` DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Attachment, 0)
|
||||
for rows.Next() {
|
||||
attachment := store.Attachment{}
|
||||
var memoID sql.NullInt32
|
||||
var storageType string
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&attachment.ID,
|
||||
&attachment.UID,
|
||||
&attachment.Filename,
|
||||
&attachment.Type,
|
||||
&attachment.Size,
|
||||
&attachment.CreatorID,
|
||||
&attachment.CreatedTs,
|
||||
&attachment.UpdatedTs,
|
||||
&memoID,
|
||||
&storageType,
|
||||
&attachment.Reference,
|
||||
&payloadBytes,
|
||||
&attachment.MemoUID,
|
||||
}
|
||||
if find.GetBlob {
|
||||
dests = append(dests, &attachment.Blob)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if memoID.Valid {
|
||||
attachment.MemoID = &memoID.Int32
|
||||
}
|
||||
attachment.StorageType = storepb.AttachmentStorageType(storepb.AttachmentStorageType_value[storageType])
|
||||
payload := &storepb.AttachmentPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attachment.Payload = payload
|
||||
list = append(list, &attachment)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachment) error {
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
set, args = append(set, "`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
set, args = append(set, "`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Reference; v != nil {
|
||||
set, args = append(set, "`reference` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
bytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(bytes))
|
||||
}
|
||||
|
||||
args = append(args, update.ID)
|
||||
stmt := "UPDATE `attachment` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to update attachment")
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := "DELETE FROM `attachment` WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
9
store/db/sqlite/common.go
Normal file
9
store/db/sqlite/common.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package sqlite
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
var (
|
||||
protojsonUnmarshaler = protojson.UnmarshalOptions{
|
||||
DiscardUnknown: true,
|
||||
}
|
||||
)
|
||||
44
store/db/sqlite/functions.go
Normal file
44
store/db/sqlite/functions.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Package sqlite provides SQLite driver implementation with custom functions.
|
||||
// Custom functions are registered globally on first use to extend SQLite's
|
||||
// limited ASCII-only text operations with proper Unicode support.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
msqlite "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
var (
|
||||
registerUnicodeLowerOnce sync.Once
|
||||
registerUnicodeLowerErr error
|
||||
// unicodeFold provides Unicode case folding for case-insensitive comparisons.
|
||||
// It's safe to use concurrently and reused across all function calls.
|
||||
unicodeFold = cases.Fold()
|
||||
)
|
||||
|
||||
// ensureUnicodeLowerRegistered registers the memos_unicode_lower custom function
|
||||
// with SQLite. This function provides proper Unicode case folding for case-insensitive
|
||||
// text comparisons, overcoming modernc.org/sqlite's lack of ICU extension.
|
||||
//
|
||||
// The function is registered once globally and is safe to call multiple times.
|
||||
func ensureUnicodeLowerRegistered() error {
|
||||
registerUnicodeLowerOnce.Do(func() {
|
||||
registerUnicodeLowerErr = msqlite.RegisterScalarFunction("memos_unicode_lower", 1, func(_ *msqlite.FunctionContext, args []driver.Value) (driver.Value, error) {
|
||||
if len(args) == 0 || args[0] == nil {
|
||||
return nil, nil
|
||||
}
|
||||
switch v := args[0].(type) {
|
||||
case string:
|
||||
return unicodeFold.String(v), nil
|
||||
case []byte:
|
||||
return unicodeFold.String(string(v)), nil
|
||||
default:
|
||||
return v, nil
|
||||
}
|
||||
})
|
||||
})
|
||||
return registerUnicodeLowerErr
|
||||
}
|
||||
117
store/db/sqlite/idp.go
Normal file
117
store/db/sqlite/idp.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
placeholders := []string{"?", "?", "?", "?"}
|
||||
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
|
||||
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider := create
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
config
|
||||
FROM idp
|
||||
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var identityProviders []*store.IdentityProvider
|
||||
for rows.Next() {
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProvider.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
|
||||
identityProviders = append(identityProviders, &identityProvider)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identityProviders, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.Name; v != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
set, args = append(set, "identifier_filter = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
set, args = append(set, "config = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE idp
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, name, type, identifier_filter, config
|
||||
`
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProvider.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
|
||||
return &identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
where, args := []string{"id = ?"}, []any{delete.ID}
|
||||
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
141
store/db/sqlite/inbox.go
Normal file
141
store/db/sqlite/inbox.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
|
||||
messageString := "{}"
|
||||
if create.Message != nil {
|
||||
bytes, err := protojson.Marshal(create.Message)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal inbox message")
|
||||
}
|
||||
messageString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"`sender_id`", "`receiver_id`", "`status`", "`message`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
|
||||
|
||||
stmt := "INSERT INTO `inbox` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.SenderID != nil {
|
||||
where, args = append(where, "`sender_id` = ?"), append(args, *find.SenderID)
|
||||
}
|
||||
if find.ReceiverID != nil {
|
||||
where, args = append(where, "`receiver_id` = ?"), append(args, *find.ReceiverID)
|
||||
}
|
||||
if find.Status != nil {
|
||||
where, args = append(where, "`status` = ?"), append(args, *find.Status)
|
||||
}
|
||||
if find.MessageType != nil {
|
||||
// Filter by message type using JSON extraction
|
||||
// Note: The type field in JSON is stored as string representation of the enum name
|
||||
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
|
||||
where, args = append(where, "(JSON_EXTRACT(`message`, '$.type') IS NULL OR JSON_EXTRACT(`message`, '$.type') = ?)"), append(args, find.MessageType.String())
|
||||
} else {
|
||||
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
|
||||
}
|
||||
}
|
||||
|
||||
query := "SELECT `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Inbox{}
|
||||
for rows.Next() {
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := rows.Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
list = append(list, inbox)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
set, args := []string{"`status` = ?"}, []any{update.Status.String()}
|
||||
args = append(args, update.ID)
|
||||
query := "UPDATE `inbox` SET " + strings.Join(set, ", ") + " WHERE `id` = ? RETURNING `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message`"
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM `inbox` WHERE `id` = ?", delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
72
store/db/sqlite/instance_setting.go
Normal file
72
store/db/sqlite/instance_setting.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertInstanceSetting(ctx context.Context, upsert *store.InstanceSetting) (*store.InstanceSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO system_setting (
|
||||
name, value, description
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(name) DO UPDATE
|
||||
SET
|
||||
value = EXCLUDED.value,
|
||||
description = EXCLUDED.description
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInstanceSettings(ctx context.Context, find *store.FindInstanceSetting) ([]*store.InstanceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Name != "" {
|
||||
where, args = append(where, "name = ?"), append(args, find.Name)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
name,
|
||||
value,
|
||||
description
|
||||
FROM system_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.InstanceSetting{}
|
||||
for rows.Next() {
|
||||
systemSettingMessage := &store.InstanceSetting{}
|
||||
if err := rows.Scan(
|
||||
&systemSettingMessage.Name,
|
||||
&systemSettingMessage.Value,
|
||||
&systemSettingMessage.Description,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSettingMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInstanceSetting(ctx context.Context, delete *store.DeleteInstanceSetting) error {
|
||||
stmt := "DELETE FROM system_setting WHERE name = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, delete.Name)
|
||||
return err
|
||||
}
|
||||
247
store/db/sqlite/memo.go
Normal file
247
store/db/sqlite/memo.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
fields := []string{"`uid`", "`creator_id`", "`content`", "`visibility`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?"}
|
||||
payload := "{}"
|
||||
if create.Payload != nil {
|
||||
payloadBytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = string(payloadBytes)
|
||||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
// Add custom timestamps if provided
|
||||
if create.CreatedTs != 0 {
|
||||
fields = append(fields, "`created_ts`")
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, create.CreatedTs)
|
||||
}
|
||||
if create.UpdatedTs != 0 {
|
||||
fields = append(fields, "`updated_ts`")
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, create.UpdatedTs)
|
||||
}
|
||||
|
||||
stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`, `row_status`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectSQLite, &where, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`memo`.`id` = ?"), append(args, *v)
|
||||
}
|
||||
if len(find.IDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.IDList))
|
||||
for range find.IDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
where = append(where, "`memo`.`id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, id := range find.IDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`memo`.`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if len(find.UIDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.UIDList))
|
||||
for range find.UIDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
where = append(where, "`memo`.`uid` IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, uid := range find.UIDList {
|
||||
args = append(args, uid)
|
||||
}
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "`memo`.`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
placeholder := []string{}
|
||||
for _, visibility := range v {
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, visibility.String())
|
||||
}
|
||||
where = append(where, fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ",")))
|
||||
}
|
||||
if find.ExcludeComments {
|
||||
where = append(where, "`parent_uid` IS NULL")
|
||||
}
|
||||
|
||||
order := "DESC"
|
||||
if find.OrderByTimeAsc {
|
||||
order = "ASC"
|
||||
}
|
||||
orderBy := []string{}
|
||||
if find.OrderByPinned {
|
||||
orderBy = append(orderBy, "`pinned` DESC")
|
||||
}
|
||||
if find.OrderByUpdatedTs {
|
||||
orderBy = append(orderBy, "`updated_ts` "+order)
|
||||
} else {
|
||||
orderBy = append(orderBy, "`created_ts` "+order)
|
||||
}
|
||||
// Add id as final tie-breaker
|
||||
orderBy = append(orderBy, "`id` DESC")
|
||||
fields := []string{
|
||||
"`memo`.`id` AS `id`",
|
||||
"`memo`.`uid` AS `uid`",
|
||||
"`memo`.`creator_id` AS `creator_id`",
|
||||
"`memo`.`created_ts` AS `created_ts`",
|
||||
"`memo`.`updated_ts` AS `updated_ts`",
|
||||
"`memo`.`row_status` AS `row_status`",
|
||||
"`memo`.`visibility` AS `visibility`",
|
||||
"`memo`.`pinned` AS `pinned`",
|
||||
"`memo`.`payload` AS `payload`",
|
||||
"CASE WHEN `parent_memo`.`uid` IS NOT NULL THEN `parent_memo`.`uid` ELSE NULL END AS `parent_uid`",
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
fields = append(fields, "`memo`.`content` AS `content`")
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + "FROM `memo` " +
|
||||
"LEFT JOIN `memo_relation` ON `memo`.`id` = `memo_relation`.`memo_id` AND `memo_relation`.`type` = \"COMMENT\" " +
|
||||
"LEFT JOIN `memo` AS `parent_memo` ON `memo_relation`.`related_memo_id` = `parent_memo`.`id` " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
"ORDER BY " + strings.Join(orderBy, ", ")
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Memo, 0)
|
||||
for rows.Next() {
|
||||
var memo store.Memo
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&memo.ID,
|
||||
&memo.UID,
|
||||
&memo.CreatorID,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&memo.RowStatus,
|
||||
&memo.Visibility,
|
||||
&memo.Pinned,
|
||||
&payloadBytes,
|
||||
&memo.ParentUID,
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
dests = append(dests, &memo.Content)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := &storepb.MemoPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal payload")
|
||||
}
|
||||
memo.Payload = payload
|
||||
list = append(list, &memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.CreatedTs; v != nil {
|
||||
set, args = append(set, "`created_ts` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Content; v != nil {
|
||||
set, args = append(set, "`content` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Visibility; v != nil {
|
||||
set, args = append(set, "`visibility` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Pinned; v != nil {
|
||||
set, args = append(set, "`pinned` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
payloadBytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(payloadBytes))
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := "UPDATE `memo` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
where, args := []string{"`id` = ?"}, []any{delete.ID}
|
||||
stmt := "DELETE FROM `memo` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
122
store/db/sqlite/memo_relation.go
Normal file
122
store/db/sqlite/memo_relation.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
stmt := `
|
||||
INSERT INTO memo_relation (
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(memo_id, related_memo_id, type) DO UPDATE SET type = excluded.type
|
||||
RETURNING memo_id, related_memo_id, type
|
||||
`
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := d.db.QueryRowContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.MemoID,
|
||||
create.RelatedMemoID,
|
||||
create.Type,
|
||||
).Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return memoRelation, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if find.MemoID != nil {
|
||||
where, args = append(where, "memo_id = ?"), append(args, find.MemoID)
|
||||
}
|
||||
if find.RelatedMemoID != nil {
|
||||
where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "type = ?"), append(args, find.Type)
|
||||
}
|
||||
if find.MemoFilter != nil {
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{Dialect: filter.DialectSQLite})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if stmt.SQL != "" {
|
||||
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
|
||||
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
|
||||
args = append(args, append(stmt.Args, stmt.Args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
FROM memo_relation
|
||||
WHERE `+strings.Join(where, " AND "), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.MemoRelation{}
|
||||
for rows.Next() {
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := rows.Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, memoRelation)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
where, args = append(where, "type = ?"), append(args, delete.Type)
|
||||
}
|
||||
stmt := `
|
||||
DELETE FROM memo_relation
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
136
store/db/sqlite/reaction.go
Normal file
136
store/db/sqlite/reaction.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store.Reaction, error) {
|
||||
fields := []string{"`creator_id`", "`content_id`", "`reaction_type`"}
|
||||
placeholder := []string{"?", "?", "?"}
|
||||
args := []interface{}{upsert.CreatorID, upsert.ContentID, upsert.ReactionType}
|
||||
stmt := "INSERT INTO `reaction` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&upsert.ID,
|
||||
&upsert.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reaction := upsert
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *find.CreatorID)
|
||||
}
|
||||
if find.ContentID != nil {
|
||||
where, args = append(where, "content_id = ?"), append(args, *find.ContentID)
|
||||
}
|
||||
if len(find.ContentIDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.ContentIDList))
|
||||
for range find.ContentIDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
if len(placeholders) > 0 {
|
||||
where = append(where, "content_id IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, id := range find.ContentIDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
created_ts,
|
||||
creator_id,
|
||||
content_id,
|
||||
reaction_type
|
||||
FROM reaction
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Reaction{}
|
||||
for rows.Next() {
|
||||
reaction := &store.Reaction{}
|
||||
if err := rows.Scan(
|
||||
&reaction.ID,
|
||||
&reaction.CreatedTs,
|
||||
&reaction.CreatorID,
|
||||
&reaction.ContentID,
|
||||
&reaction.ReactionType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, reaction)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *find.CreatorID)
|
||||
}
|
||||
if find.ContentID != nil {
|
||||
where, args = append(where, "content_id = ?"), append(args, *find.ContentID)
|
||||
}
|
||||
|
||||
reaction := &store.Reaction{}
|
||||
if err := d.db.QueryRowContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
created_ts,
|
||||
creator_id,
|
||||
content_id,
|
||||
reaction_type
|
||||
FROM reaction
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
LIMIT 1`,
|
||||
args...,
|
||||
).Scan(
|
||||
&reaction.ID,
|
||||
&reaction.CreatedTs,
|
||||
&reaction.CreatorID,
|
||||
&reaction.ContentID,
|
||||
&reaction.ReactionType,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM `reaction` WHERE `id` = ?", delete.ID)
|
||||
return err
|
||||
}
|
||||
75
store/db/sqlite/sqlite.go
Normal file
75
store/db/sqlite/sqlite.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
// Note: modernc.org/sqlite driver is imported in functions.go where
|
||||
// RegisterScalarFunction is used. No blank import needed here.
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
}
|
||||
|
||||
// NewDB opens a database specified by its database driver name and a
|
||||
// driver-specific data source name, usually consisting of at least a
|
||||
// database name and connection information.
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
// Ensure a DSN is set before attempting to open the database.
|
||||
if profile.DSN == "" {
|
||||
return nil, errors.New("dsn required")
|
||||
}
|
||||
|
||||
if err := ensureUnicodeLowerRegistered(); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to register sqlite unicode lower function")
|
||||
}
|
||||
|
||||
// Connect to the database with some sane settings:
|
||||
// - No shared-cache: it's obsolete; WAL journal mode is a better solution.
|
||||
// - No foreign key constraints: it's currently disabled by default, but it's a
|
||||
// good practice to be explicit and prevent future surprises on SQLite upgrades.
|
||||
// - Journal mode set to WAL: it's the recommended journal mode for most applications
|
||||
// as it prevents locking issues.
|
||||
// - mmap size set to 0: it disables memory mapping, which can cause OOM errors on some systems.
|
||||
//
|
||||
// Notes:
|
||||
// - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`.
|
||||
//
|
||||
// References:
|
||||
// - https://pkg.go.dev/modernc.org/sqlite#Driver.Open
|
||||
// - https://www.sqlite.org/sharedcache.html
|
||||
// - https://www.sqlite.org/pragma.html
|
||||
sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)&_pragma=mmap_size(0)")
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to open db with dsn: %s", profile.DSN)
|
||||
}
|
||||
|
||||
driver := DB{db: sqliteDB, profile: profile}
|
||||
|
||||
return &driver, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
// Check if the database is initialized by checking if the memo table exists.
|
||||
var exists bool
|
||||
err := d.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='memo')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
176
store/db/sqlite/user.go
Normal file
176
store/db/sqlite/user.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
fields := []string{"`username`", "`role`", "`email`", "`nickname`", "`password_hash`, `avatar_url`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?"}
|
||||
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
stmt := "INSERT INTO user (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING id, description, created_ts, updated_ts, row_status"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.Description,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "username = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "avatar_url = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Description; v != nil {
|
||||
set, args = append(set, "description = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "role = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
query := `
|
||||
UPDATE user
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, username, role, email, nickname, password_hash, avatar_url, description, created_ts, updated_ts, row_status
|
||||
`
|
||||
user := &store.User{}
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
return nil, errors.Errorf("user filters are not supported")
|
||||
}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
where, args = append(where, "username = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "role = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
orderBy := []string{"created_ts DESC", "row_status DESC"}
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
avatar_url,
|
||||
description,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status
|
||||
FROM user
|
||||
WHERE ` + strings.Join(where, " AND ") + ` ORDER BY ` + strings.Join(orderBy, ", ")
|
||||
if v := find.Limit; v != nil {
|
||||
query += fmt.Sprintf(" LIMIT %d", *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.User, 0)
|
||||
for rows.Next() {
|
||||
var user store.User
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
result, err := d.db.ExecContext(ctx, `
|
||||
DELETE FROM user WHERE id = ?
|
||||
`, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
109
store/db/sqlite/user_setting.go
Normal file
109
store/db/sqlite/user_setting.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO user_setting (
|
||||
user_id, key, value
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id, key) DO UPDATE
|
||||
SET value = EXCLUDED.value
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
key,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*store.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &store.UserSetting{}
|
||||
var keyString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserID,
|
||||
&keyString,
|
||||
&userSetting.Value,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) {
|
||||
query := `
|
||||
SELECT
|
||||
user_setting.user_id,
|
||||
user_setting.value
|
||||
FROM user_setting
|
||||
WHERE user_setting.key = 'PERSONAL_ACCESS_TOKENS'
|
||||
AND EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(json_extract(user_setting.value, '$.tokens')) AS token
|
||||
WHERE json_extract(token.value, '$.tokenHash') = ?
|
||||
)
|
||||
`
|
||||
|
||||
var userID int32
|
||||
var tokensJSON string
|
||||
|
||||
err := d.db.QueryRowContext(ctx, query, tokenHash).Scan(&userID, &tokensJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
patsUserSetting := &storepb.PersonalAccessTokensUserSetting{}
|
||||
if err := protojsonUnmarshaler.Unmarshal([]byte(tokensJSON), patsUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, pat := range patsUserSetting.Tokens {
|
||||
if pat.TokenHash == tokenHash {
|
||||
return &store.PATQueryResult{
|
||||
UserID: userID,
|
||||
PAT: pat,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("PAT not found")
|
||||
}
|
||||
Reference in New Issue
Block a user