Skip to content

Commit 34283a7

Browse files
lunnydelvh6543
authored
Allow detect whether it's in a database transaction for a context.Context (#21756)
Fix #19513 This PR introduce a new db method `InTransaction(context.Context)`, and also builtin check on `db.TxContext` and `db.WithTx`. There is also a new method `db.AutoTx` has been introduced but could be used by other PRs. `WithTx` will always open a new transaction, if a transaction exist in context, return an error. `AutoTx` will try to open a new transaction if no transaction exist in context. That means it will always enter a transaction if there is no error. Co-authored-by: delvh <dev.lh@web.de> Co-authored-by: 6543 <6543@obermui.de>
1 parent a0a425a commit 34283a7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+252
-176
lines changed

models/activities/action.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ func NotifyWatchers(actions ...*Action) error {
572572

573573
// NotifyWatchersActions creates batch of actions for every watcher.
574574
func NotifyWatchersActions(acts []*Action) error {
575-
ctx, committer, err := db.TxContext()
575+
ctx, committer, err := db.TxContext(db.DefaultContext)
576576
if err != nil {
577577
return err
578578
}

models/activities/notification.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ func CountNotifications(opts *FindNotificationOptions) (int64, error) {
142142

143143
// CreateRepoTransferNotification creates notification for the user a repository was transferred to
144144
func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_model.Repository) error {
145-
ctx, committer, err := db.TxContext()
145+
ctx, committer, err := db.TxContext(db.DefaultContext)
146146
if err != nil {
147147
return err
148148
}
@@ -185,7 +185,7 @@ func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_
185185
// for each watcher, or updates it if already exists
186186
// receiverID > 0 just send to receiver, else send to all watcher
187187
func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID, receiverID int64) error {
188-
ctx, committer, err := db.TxContext()
188+
ctx, committer, err := db.TxContext(db.DefaultContext)
189189
if err != nil {
190190
return err
191191
}

models/asymkey/gpg_key.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) {
234234
return ErrGPGKeyAccessDenied{doer.ID, key.ID}
235235
}
236236

237-
ctx, committer, err := db.TxContext()
237+
ctx, committer, err := db.TxContext(db.DefaultContext)
238238
if err != nil {
239239
return err
240240
}

models/asymkey/gpg_key_add.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
7373
return nil, err
7474
}
7575

76-
ctx, committer, err := db.TxContext()
76+
ctx, committer, err := db.TxContext(db.DefaultContext)
7777
if err != nil {
7878
return nil, err
7979
}

models/asymkey/gpg_key_verify.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import (
3131

3232
// VerifyGPGKey marks a GPG key as verified
3333
func VerifyGPGKey(ownerID int64, keyID, token, signature string) (string, error) {
34-
ctx, committer, err := db.TxContext()
34+
ctx, committer, err := db.TxContext(db.DefaultContext)
3535
if err != nil {
3636
return "", err
3737
}

models/asymkey/ssh_key.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
100100
return nil, err
101101
}
102102

103-
ctx, committer, err := db.TxContext()
103+
ctx, committer, err := db.TxContext(db.DefaultContext)
104104
if err != nil {
105105
return nil, err
106106
}
@@ -321,7 +321,7 @@ func PublicKeyIsExternallyManaged(id int64) (bool, error) {
321321
// deleteKeysMarkedForDeletion returns true if ssh keys needs update
322322
func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
323323
// Start session
324-
ctx, committer, err := db.TxContext()
324+
ctx, committer, err := db.TxContext(db.DefaultContext)
325325
if err != nil {
326326
return false, err
327327
}

models/asymkey/ssh_key_deploy.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey
126126
accessMode = perm.AccessModeWrite
127127
}
128128

129-
ctx, committer, err := db.TxContext()
129+
ctx, committer, err := db.TxContext(db.DefaultContext)
130130
if err != nil {
131131
return nil, err
132132
}

models/asymkey/ssh_key_principals.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626

2727
// AddPrincipalKey adds new principal to database and authorized_principals file.
2828
func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*PublicKey, error) {
29-
ctx, committer, err := db.TxContext()
29+
ctx, committer, err := db.TxContext(db.DefaultContext)
3030
if err != nil {
3131
return nil, err
3232
}

models/asymkey/ssh_key_verify.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515

1616
// VerifySSHKey marks a SSH key as verified
1717
func VerifySSHKey(ownerID int64, fingerprint, token, signature string) (string, error) {
18-
ctx, committer, err := db.TxContext()
18+
ctx, committer, err := db.TxContext(db.DefaultContext)
1919
if err != nil {
2020
return "", err
2121
}

models/auth/oauth2.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ type UpdateOAuth2ApplicationOptions struct {
201201

202202
// UpdateOAuth2Application updates an oauth2 application
203203
func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
204-
ctx, committer, err := db.TxContext()
204+
ctx, committer, err := db.TxContext(db.DefaultContext)
205205
if err != nil {
206206
return nil, err
207207
}
@@ -265,7 +265,7 @@ func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
265265

266266
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
267267
func DeleteOAuth2Application(id, userid int64) error {
268-
ctx, committer, err := db.TxContext()
268+
ctx, committer, err := db.TxContext(db.DefaultContext)
269269
if err != nil {
270270
return err
271271
}

models/auth/session.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func ReadSession(key string) (*Session, error) {
3737
Key: key,
3838
}
3939

40-
ctx, committer, err := db.TxContext()
40+
ctx, committer, err := db.TxContext(db.DefaultContext)
4141
if err != nil {
4242
return nil, err
4343
}
@@ -73,7 +73,7 @@ func DestroySession(key string) error {
7373

7474
// RegenerateSession regenerates a session from the old id
7575
func RegenerateSession(oldKey, newKey string) (*Session, error) {
76-
ctx, committer, err := db.TxContext()
76+
ctx, committer, err := db.TxContext(db.DefaultContext)
7777
if err != nil {
7878
return nil, err
7979
}

models/avatars/avatar.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func saveEmailHash(email string) string {
9797
Hash: emailHash,
9898
}
9999
// OK we're going to open a session just because I think that that might hide away any problems with postgres reporting errors
100-
if err := db.WithTx(func(ctx context.Context) error {
100+
if err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
101101
has, err := db.GetEngine(ctx).Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
102102
if has || err != nil {
103103
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time

models/db/context.go

+47-7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"context"
99
"database/sql"
1010

11+
"xorm.io/xorm"
1112
"xorm.io/xorm/schemas"
1213
)
1314

@@ -86,7 +87,11 @@ type Committer interface {
8687
}
8788

8889
// TxContext represents a transaction Context
89-
func TxContext() (*Context, Committer, error) {
90+
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
91+
if InTransaction(parentCtx) {
92+
return nil, nil, ErrAlreadyInTransaction
93+
}
94+
9095
sess := x.NewSession()
9196
if err := sess.Begin(); err != nil {
9297
sess.Close()
@@ -97,14 +102,24 @@ func TxContext() (*Context, Committer, error) {
97102
}
98103

99104
// WithTx represents executing database operations on a transaction
100-
// you can optionally change the context to a parent one
101-
func WithTx(f func(ctx context.Context) error, stdCtx ...context.Context) error {
102-
parentCtx := DefaultContext
103-
if len(stdCtx) != 0 && stdCtx[0] != nil {
104-
// TODO: make sure parent context has no open session
105-
parentCtx = stdCtx[0]
105+
// This function will always open a new transaction, if a transaction exist in parentCtx return an error.
106+
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
107+
if InTransaction(parentCtx) {
108+
return ErrAlreadyInTransaction
109+
}
110+
return txWithNoCheck(parentCtx, f)
111+
}
112+
113+
// AutoTx represents executing database operations on a transaction, if the transaction exist,
114+
// this function will reuse it otherwise will create a new one and close it when finished.
115+
func AutoTx(parentCtx context.Context, f func(ctx context.Context) error) error {
116+
if InTransaction(parentCtx) {
117+
return f(newContext(parentCtx, GetEngine(parentCtx), true))
106118
}
119+
return txWithNoCheck(parentCtx, f)
120+
}
107121

122+
func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
108123
sess := x.NewSession()
109124
defer sess.Close()
110125
if err := sess.Begin(); err != nil {
@@ -180,3 +195,28 @@ func EstimateCount(ctx context.Context, bean interface{}) (int64, error) {
180195
}
181196
return rows, err
182197
}
198+
199+
// InTransaction returns true if the engine is in a transaction otherwise return false
200+
func InTransaction(ctx context.Context) bool {
201+
var e Engine
202+
if engined, ok := ctx.(Engined); ok {
203+
e = engined.Engine()
204+
} else {
205+
enginedInterface := ctx.Value(enginedContextKey)
206+
if enginedInterface != nil {
207+
e = enginedInterface.(Engined).Engine()
208+
}
209+
}
210+
if e == nil {
211+
return false
212+
}
213+
214+
switch t := e.(type) {
215+
case *xorm.Engine:
216+
return false
217+
case *xorm.Session:
218+
return t.IsInTx()
219+
default:
220+
return false
221+
}
222+
}

models/db/context_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright 2022 The Gitea Authors. All rights reserved.
2+
// Use of this source code is governed by a MIT-style
3+
// license that can be found in the LICENSE file.
4+
5+
package db_test
6+
7+
import (
8+
"context"
9+
"testing"
10+
11+
"code.gitea.io/gitea/models/db"
12+
"code.gitea.io/gitea/models/unittest"
13+
14+
"github.com/stretchr/testify/assert"
15+
)
16+
17+
func TestInTransaction(t *testing.T) {
18+
assert.NoError(t, unittest.PrepareTestDatabase())
19+
assert.False(t, db.InTransaction(db.DefaultContext))
20+
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
21+
assert.True(t, db.InTransaction(ctx))
22+
return nil
23+
}))
24+
25+
ctx, committer, err := db.TxContext(db.DefaultContext)
26+
assert.NoError(t, err)
27+
defer committer.Close()
28+
assert.True(t, db.InTransaction(ctx))
29+
assert.Error(t, db.WithTx(ctx, func(ctx context.Context) error {
30+
assert.True(t, db.InTransaction(ctx))
31+
return nil
32+
}))
33+
}

models/db/error.go

+3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
package db
66

77
import (
8+
"errors"
89
"fmt"
910

1011
"code.gitea.io/gitea/modules/util"
1112
)
1213

14+
var ErrAlreadyInTransaction = errors.New("database connection has already been in a transaction")
15+
1316
// ErrCancelled represents an error due to context cancellation
1417
type ErrCancelled struct {
1518
Message string

models/db/index_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
5959
assert.EqualValues(t, 62, maxIndex)
6060

6161
// commit transaction
62-
err = db.WithTx(func(ctx context.Context) error {
62+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
6363
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
6464
assert.NoError(t, err)
6565
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
@@ -73,7 +73,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
7373
assert.EqualValues(t, 73, maxIndex)
7474

7575
// rollback transaction
76-
err = db.WithTx(func(ctx context.Context) error {
76+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
7777
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
7878
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
7979
assert.NoError(t, err)
@@ -102,7 +102,7 @@ func TestGetNextResourceIndex(t *testing.T) {
102102
assert.EqualValues(t, 2, maxIndex)
103103

104104
// commit transaction
105-
err = db.WithTx(func(ctx context.Context) error {
105+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
106106
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
107107
assert.NoError(t, err)
108108
assert.EqualValues(t, 3, maxIndex)
@@ -114,7 +114,7 @@ func TestGetNextResourceIndex(t *testing.T) {
114114
assert.EqualValues(t, 3, maxIndex)
115115

116116
// rollback transaction
117-
err = db.WithTx(func(ctx context.Context) error {
117+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
118118
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
119119
assert.NoError(t, err)
120120
assert.EqualValues(t, 4, maxIndex)

models/git/branches.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ func FindRenamedBranch(repoID int64, from string) (branch *RenamedBranch, exist
544544

545545
// RenameBranch rename a branch
546546
func RenameBranch(repo *repo_model.Repository, from, to string, gitAction func(isDefault bool) error) (err error) {
547-
ctx, committer, err := db.TxContext()
547+
ctx, committer, err := db.TxContext(db.DefaultContext)
548548
if err != nil {
549549
return err
550550
}

models/git/branches_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func TestRenameBranch(t *testing.T) {
102102
repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
103103
_isDefault := false
104104

105-
ctx, committer, err := db.TxContext()
105+
ctx, committer, err := db.TxContext(db.DefaultContext)
106106
defer committer.Close()
107107
assert.NoError(t, err)
108108
assert.NoError(t, git_model.UpdateProtectBranch(ctx, repo1, &git_model.ProtectedBranch{

models/git/commit_status.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func GetNextCommitStatusIndex(repoID int64, sha string) (int64, error) {
9494

9595
// getNextCommitStatusIndex return the next index
9696
func getNextCommitStatusIndex(repoID int64, sha string) (int64, error) {
97-
ctx, commiter, err := db.TxContext()
97+
ctx, commiter, err := db.TxContext(db.DefaultContext)
9898
if err != nil {
9999
return 0, err
100100
}
@@ -297,7 +297,7 @@ func NewCommitStatus(opts NewCommitStatusOptions) error {
297297
return fmt.Errorf("generate commit status index failed: %w", err)
298298
}
299299

300-
ctx, committer, err := db.TxContext()
300+
ctx, committer, err := db.TxContext(db.DefaultContext)
301301
if err != nil {
302302
return fmt.Errorf("NewCommitStatus[repo_id: %d, user_id: %d, sha: %s]: %w", opts.Repo.ID, opts.Creator.ID, opts.SHA, err)
303303
}

models/git/lfs.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ var ErrLFSObjectNotExist = db.ErrNotExist{Resource: "LFS Meta object"}
137137
func NewLFSMetaObject(m *LFSMetaObject) (*LFSMetaObject, error) {
138138
var err error
139139

140-
ctx, committer, err := db.TxContext()
140+
ctx, committer, err := db.TxContext(db.DefaultContext)
141141
if err != nil {
142142
return nil, err
143143
}
@@ -185,7 +185,7 @@ func RemoveLFSMetaObjectByOid(repoID int64, oid string) (int64, error) {
185185
return 0, ErrLFSObjectNotExist
186186
}
187187

188-
ctx, committer, err := db.TxContext()
188+
ctx, committer, err := db.TxContext(db.DefaultContext)
189189
if err != nil {
190190
return 0, err
191191
}
@@ -242,7 +242,7 @@ func LFSObjectIsAssociated(oid string) (bool, error) {
242242

243243
// LFSAutoAssociate auto associates accessible LFSMetaObjects
244244
func LFSAutoAssociate(metas []*LFSMetaObject, user *user_model.User, repoID int64) error {
245-
ctx, committer, err := db.TxContext()
245+
ctx, committer, err := db.TxContext(db.DefaultContext)
246246
if err != nil {
247247
return err
248248
}

models/git/lfs_lock.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func cleanPath(p string) string {
4444

4545
// CreateLFSLock creates a new lock.
4646
func CreateLFSLock(repo *repo_model.Repository, lock *LFSLock) (*LFSLock, error) {
47-
dbCtx, committer, err := db.TxContext()
47+
dbCtx, committer, err := db.TxContext(db.DefaultContext)
4848
if err != nil {
4949
return nil, err
5050
}
@@ -137,7 +137,7 @@ func CountLFSLockByRepoID(repoID int64) (int64, error) {
137137

138138
// DeleteLFSLockByID deletes a lock by given ID.
139139
func DeleteLFSLockByID(id int64, repo *repo_model.Repository, u *user_model.User, force bool) (*LFSLock, error) {
140-
dbCtx, committer, err := db.TxContext()
140+
dbCtx, committer, err := db.TxContext(db.DefaultContext)
141141
if err != nil {
142142
return nil, err
143143
}

models/issues/assignees.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.U
6464

6565
// ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
6666
func ToggleIssueAssignee(issue *Issue, doer *user_model.User, assigneeID int64) (removed bool, comment *Comment, err error) {
67-
ctx, committer, err := db.TxContext()
67+
ctx, committer, err := db.TxContext(db.DefaultContext)
6868
if err != nil {
6969
return false, nil, err
7070
}

0 commit comments

Comments
 (0)