Skip to content

Commit dd30d9d

Browse files
authored
Remove GetByBean method because sometimes it's danger when query condition parameter is zero and also introduce new generic methods (#28220)
The function `GetByBean` has an obvious defect that when the fields are empty values, it will be ignored. Then users will get a wrong result which is possibly used to make a security problem. To avoid the possibility, this PR removed function `GetByBean` and all references. And some new generic functions have been introduced to be used. The recommand usage like below. ```go // if query an object according id obj, err := db.GetByID[Object](ctx, id) // query with other conditions obj, err := db.Get[Object](ctx, builder.Eq{"a": a, "b":b}) ```
1 parent beb71f5 commit dd30d9d

28 files changed

+189
-174
lines changed

models/asymkey/ssh_key_deploy.go

+13-20
Original file line numberDiff line numberDiff line change
@@ -131,24 +131,22 @@ func AddDeployKey(ctx context.Context, repoID int64, name, content string, readO
131131
}
132132
defer committer.Close()
133133

134-
pkey := &PublicKey{
135-
Fingerprint: fingerprint,
136-
}
137-
has, err := db.GetByBean(ctx, pkey)
134+
pkey, exist, err := db.Get[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
138135
if err != nil {
139136
return nil, err
140-
}
141-
142-
if has {
137+
} else if exist {
143138
if pkey.Type != KeyTypeDeploy {
144139
return nil, ErrKeyAlreadyExist{0, fingerprint, ""}
145140
}
146141
} else {
147142
// First time use this deploy key.
148-
pkey.Mode = accessMode
149-
pkey.Type = KeyTypeDeploy
150-
pkey.Content = content
151-
pkey.Name = name
143+
pkey = &PublicKey{
144+
Fingerprint: fingerprint,
145+
Mode: accessMode,
146+
Type: KeyTypeDeploy,
147+
Content: content,
148+
Name: name,
149+
}
152150
if err = addKey(ctx, pkey); err != nil {
153151
return nil, fmt.Errorf("addKey: %w", err)
154152
}
@@ -164,26 +162,21 @@ func AddDeployKey(ctx context.Context, repoID int64, name, content string, readO
164162

165163
// GetDeployKeyByID returns deploy key by given ID.
166164
func GetDeployKeyByID(ctx context.Context, id int64) (*DeployKey, error) {
167-
key := new(DeployKey)
168-
has, err := db.GetEngine(ctx).ID(id).Get(key)
165+
key, exist, err := db.GetByID[DeployKey](ctx, id)
169166
if err != nil {
170167
return nil, err
171-
} else if !has {
168+
} else if !exist {
172169
return nil, ErrDeployKeyNotExist{id, 0, 0}
173170
}
174171
return key, nil
175172
}
176173

177174
// GetDeployKeyByRepo returns deploy key by given public key ID and repository ID.
178175
func GetDeployKeyByRepo(ctx context.Context, keyID, repoID int64) (*DeployKey, error) {
179-
key := &DeployKey{
180-
KeyID: keyID,
181-
RepoID: repoID,
182-
}
183-
has, err := db.GetByBean(ctx, key)
176+
key, exist, err := db.Get[DeployKey](ctx, builder.Eq{"key_id": keyID, "repo_id": repoID})
184177
if err != nil {
185178
return nil, err
186-
} else if !has {
179+
} else if !exist {
187180
return nil, ErrDeployKeyNotExist{0, keyID, repoID}
188181
}
189182
return key, nil

models/asymkey/ssh_key_fingerprint.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"code.gitea.io/gitea/modules/util"
1616

1717
"golang.org/x/crypto/ssh"
18+
"xorm.io/builder"
1819
)
1920

2021
// ___________.__ .__ __
@@ -31,9 +32,7 @@ import (
3132
// checkKeyFingerprint only checks if key fingerprint has been used as public key,
3233
// it is OK to use same key as deploy key for multiple repositories/users.
3334
func checkKeyFingerprint(ctx context.Context, fingerprint string) error {
34-
has, err := db.GetByBean(ctx, &PublicKey{
35-
Fingerprint: fingerprint,
36-
})
35+
has, err := db.Exist[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
3736
if err != nil {
3837
return err
3938
} else if has {

models/auth/session.go

+13-22
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99

1010
"code.gitea.io/gitea/models/db"
1111
"code.gitea.io/gitea/modules/timeutil"
12+
13+
"xorm.io/builder"
1214
)
1315

1416
// Session represents a session compatible for go-chi session
@@ -33,34 +35,28 @@ func UpdateSession(ctx context.Context, key string, data []byte) error {
3335

3436
// ReadSession reads the data for the provided session
3537
func ReadSession(ctx context.Context, key string) (*Session, error) {
36-
session := Session{
37-
Key: key,
38-
}
39-
4038
ctx, committer, err := db.TxContext(ctx)
4139
if err != nil {
4240
return nil, err
4341
}
4442
defer committer.Close()
4543

46-
if has, err := db.GetByBean(ctx, &session); err != nil {
44+
session, exist, err := db.Get[Session](ctx, builder.Eq{"key": key})
45+
if err != nil {
4746
return nil, err
48-
} else if !has {
47+
} else if !exist {
4948
session.Expiry = timeutil.TimeStampNow()
5049
if err := db.Insert(ctx, &session); err != nil {
5150
return nil, err
5251
}
5352
}
5453

55-
return &session, committer.Commit()
54+
return session, committer.Commit()
5655
}
5756

5857
// ExistSession checks if a session exists
5958
func ExistSession(ctx context.Context, key string) (bool, error) {
60-
session := Session{
61-
Key: key,
62-
}
63-
return db.GetEngine(ctx).Get(&session)
59+
return db.Exist[Session](ctx, builder.Eq{"key": key})
6460
}
6561

6662
// DestroySession destroys a session
@@ -79,17 +75,13 @@ func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, er
7975
}
8076
defer committer.Close()
8177

82-
if has, err := db.GetByBean(ctx, &Session{
83-
Key: newKey,
84-
}); err != nil {
78+
if has, err := db.Exist[Session](ctx, builder.Eq{"key": newKey}); err != nil {
8579
return nil, err
8680
} else if has {
8781
return nil, fmt.Errorf("session Key: %s already exists", newKey)
8882
}
8983

90-
if has, err := db.GetByBean(ctx, &Session{
91-
Key: oldKey,
92-
}); err != nil {
84+
if has, err := db.Exist[Session](ctx, builder.Eq{"key": oldKey}); err != nil {
9385
return nil, err
9486
} else if !has {
9587
if err := db.Insert(ctx, &Session{
@@ -104,14 +96,13 @@ func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, er
10496
return nil, err
10597
}
10698

107-
s := Session{
108-
Key: newKey,
109-
}
110-
if _, err := db.GetByBean(ctx, &s); err != nil {
99+
s, _, err := db.Get[Session](ctx, builder.Eq{"key": newKey})
100+
if err != nil {
101+
// is not exist, it should be impossible
111102
return nil, err
112103
}
113104

114-
return &s, committer.Commit()
105+
return s, committer.Commit()
115106
}
116107

117108
// CountSessions returns the number of sessions

models/auth/source.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,10 @@ func IsSSPIEnabled(ctx context.Context) bool {
265265
return false
266266
}
267267

268-
exist, err := db.Exists[Source](ctx, FindSourcesOptions{
268+
exist, err := db.Exist[Source](ctx, FindSourcesOptions{
269269
IsActive: util.OptionalBoolTrue,
270270
LoginType: SSPI,
271-
})
271+
}.ToConds())
272272
if err != nil {
273273
log.Error("Active SSPI Sources: %v", err)
274274
return false

models/db/context.go

+38-8
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,44 @@ func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
173173
return GetEngine(ctx).Exec(sqlAndArgs...)
174174
}
175175

176-
// GetByBean filled empty fields of the bean according non-empty fields to query in database.
177-
func GetByBean(ctx context.Context, bean any) (bool, error) {
178-
return GetEngine(ctx).Get(bean)
176+
func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) {
177+
if !cond.IsValid() {
178+
return nil, false, ErrConditionRequired{}
179+
}
180+
181+
var bean T
182+
has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
183+
if err != nil {
184+
return nil, false, err
185+
} else if !has {
186+
return nil, false, nil
187+
}
188+
return &bean, true, nil
189+
}
190+
191+
func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) {
192+
var bean T
193+
has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean)
194+
if err != nil {
195+
return nil, false, err
196+
} else if !has {
197+
return nil, false, nil
198+
}
199+
return &bean, true, nil
200+
}
201+
202+
func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
203+
if !cond.IsValid() {
204+
return false, ErrConditionRequired{}
205+
}
206+
207+
var bean T
208+
return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
209+
}
210+
211+
func ExistByID[T any](ctx context.Context, id int64) (bool, error) {
212+
var bean T
213+
return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean)
179214
}
180215

181216
// DeleteByBean deletes all records according non-empty fields of the bean as conditions.
@@ -264,8 +299,3 @@ func inTransaction(ctx context.Context) (*xorm.Session, bool) {
264299
return nil, false
265300
}
266301
}
267-
268-
func Exists[T any](ctx context.Context, opts FindOptions) (bool, error) {
269-
var bean T
270-
return GetEngine(ctx).Where(opts.ToConds()).Exist(&bean)
271-
}

models/db/error.go

+18
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,21 @@ func (err ErrNotExist) Error() string {
7272
func (err ErrNotExist) Unwrap() error {
7373
return util.ErrNotExist
7474
}
75+
76+
// ErrConditionRequired represents an error which require condition.
77+
type ErrConditionRequired struct{}
78+
79+
// IsErrConditionRequired checks if an error is an ErrConditionRequired
80+
func IsErrConditionRequired(err error) bool {
81+
_, ok := err.(ErrConditionRequired)
82+
return ok
83+
}
84+
85+
func (err ErrConditionRequired) Error() string {
86+
return "condition is required"
87+
}
88+
89+
// Unwrap unwraps this as a ErrNotExist err
90+
func (err ErrConditionRequired) Unwrap() error {
91+
return util.ErrInvalidArgument
92+
}

models/db/iterate_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ func TestIterate(t *testing.T) {
3131
assert.EqualValues(t, cnt, repoUnitCnt)
3232

3333
err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error {
34-
reopUnit2 := repo_model.RepoUnit{ID: repoUnit.ID}
35-
has, err := db.GetByBean(ctx, &reopUnit2)
34+
has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID)
3635
if err != nil {
3736
return err
38-
} else if !has {
37+
}
38+
if !has {
3939
return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID}
4040
}
4141
assert.EqualValues(t, repoUnit.RepoID, repoUnit.RepoID)

models/git/lfs.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ var ErrLFSObjectNotExist = db.ErrNotExist{Resource: "LFS Meta object"}
135135

136136
// NewLFSMetaObject stores a given populated LFSMetaObject structure in the database
137137
// if it is not already present.
138-
func NewLFSMetaObject(ctx context.Context, m *LFSMetaObject) (*LFSMetaObject, error) {
138+
func NewLFSMetaObject(ctx context.Context, repoID int64, p lfs.Pointer) (*LFSMetaObject, error) {
139139
var err error
140140

141141
ctx, committer, err := db.TxContext(ctx)
@@ -144,16 +144,15 @@ func NewLFSMetaObject(ctx context.Context, m *LFSMetaObject) (*LFSMetaObject, er
144144
}
145145
defer committer.Close()
146146

147-
has, err := db.GetByBean(ctx, m)
147+
m, exist, err := db.Get[LFSMetaObject](ctx, builder.Eq{"repository_id": repoID, "oid": p.Oid})
148148
if err != nil {
149149
return nil, err
150-
}
151-
152-
if has {
150+
} else if exist {
153151
m.Existing = true
154152
return m, committer.Commit()
155153
}
156154

155+
m = &LFSMetaObject{Pointer: p, RepositoryID: repoID}
157156
if err = db.Insert(ctx, m); err != nil {
158157
return nil, err
159158
}

models/git/protected_branch.go

+6-8
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
"github.com/gobwas/glob"
2626
"github.com/gobwas/glob/syntax"
27+
"xorm.io/builder"
2728
)
2829

2930
var ErrBranchIsProtected = errors.New("branch is protected")
@@ -274,25 +275,22 @@ func (protectBranch *ProtectedBranch) IsUnprotectedFile(patterns []glob.Glob, pa
274275

275276
// GetProtectedBranchRuleByName getting protected branch rule by name
276277
func GetProtectedBranchRuleByName(ctx context.Context, repoID int64, ruleName string) (*ProtectedBranch, error) {
277-
rel := &ProtectedBranch{RepoID: repoID, RuleName: ruleName}
278-
has, err := db.GetByBean(ctx, rel)
278+
// branch_name is legacy name, it actually is rule name
279+
rel, exist, err := db.Get[ProtectedBranch](ctx, builder.Eq{"repo_id": repoID, "branch_name": ruleName})
279280
if err != nil {
280281
return nil, err
281-
}
282-
if !has {
282+
} else if !exist {
283283
return nil, nil
284284
}
285285
return rel, nil
286286
}
287287

288288
// GetProtectedBranchRuleByID getting protected branch rule by rule ID
289289
func GetProtectedBranchRuleByID(ctx context.Context, repoID, ruleID int64) (*ProtectedBranch, error) {
290-
rel := &ProtectedBranch{ID: ruleID, RepoID: repoID}
291-
has, err := db.GetByBean(ctx, rel)
290+
rel, exist, err := db.Get[ProtectedBranch](ctx, builder.Eq{"repo_id": repoID, "id": ruleID})
292291
if err != nil {
293292
return nil, err
294-
}
295-
if !has {
293+
} else if !exist {
296294
return nil, nil
297295
}
298296
return rel, nil

models/issues/assignees.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"code.gitea.io/gitea/models/db"
1111
user_model "code.gitea.io/gitea/models/user"
1212
"code.gitea.io/gitea/modules/util"
13+
14+
"xorm.io/builder"
1315
)
1416

1517
// IssueAssignees saves all issue assignees
@@ -59,7 +61,7 @@ func GetAssigneeIDsByIssue(ctx context.Context, issueID int64) ([]int64, error)
5961

6062
// IsUserAssignedToIssue returns true when the user is assigned to the issue
6163
func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.User) (isAssigned bool, err error) {
62-
return db.GetByBean(ctx, &IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID})
64+
return db.Exist[IssueAssignees](ctx, builder.Eq{"assignee_id": user.ID, "issue_id": issue.ID})
6365
}
6466

6567
// ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.

0 commit comments

Comments
 (0)