Skip to content

Commit b9c07f4

Browse files
committed
Refactor duplicated transaction management code
The code for transactions is duplicated between SQLite and PostgreSQL. MySQL would have also used identical code. However, the SQL being executed is not universal across all backends. Oracle appears to use the same SQL, but SQL Server has its own special syntax for this. As such, I'm not comfortable promoting this to a default impl on the trait. Instead I've moved the code out into a shared trait/struct, and operate on that instead. I had wanted to make `TransactionManager` be generic over the backend, not the connection itself, since constraints for it will always be about the backend, but I ran into rust-lang/rust#39532 when attempting to do so.
1 parent 2c41bf9 commit b9c07f4

File tree

9 files changed

+138
-114
lines changed

9 files changed

+138
-114
lines changed

diesel/src/backend.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub trait TypeMetadata {
2727

2828
pub trait SupportsReturningClause {}
2929
pub trait SupportsDefaultKeyword {}
30+
pub trait UsesAnsiSavepointSyntax {}
3031

3132
#[derive(Debug, Copy, Clone)]
3233
pub struct Debug;
@@ -50,3 +51,4 @@ impl TypeMetadata for Debug {
5051

5152
impl SupportsReturningClause for Debug {}
5253
impl SupportsDefaultKeyword for Debug {}
54+
impl UsesAnsiSavepointSyntax for Debug {}

diesel/src/connection/mod.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1+
mod transaction_manager;
2+
13
use backend::Backend;
24
use query_builder::{AsQuery, QueryFragment, QueryId};
35
use query_source::Queryable;
46
use result::*;
57
use types::HasSqlType;
68

9+
pub use self::transaction_manager::{TransactionManager, AnsiTransactionManager};
10+
711
pub trait SimpleConnection {
812
#[doc(hidden)]
913
fn batch_execute(&self, query: &str) -> QueryResult<()>;
1014
}
1115

1216
pub trait Connection: SimpleConnection + Sized {
1317
type Backend: Backend;
18+
#[doc(hidden)]
19+
type TransactionManager: TransactionManager<Self>;
1420

1521
/// Establishes a new connection to the database at the given URL. The URL
1622
/// should be a valid connection string for a given backend. See the
@@ -28,14 +34,15 @@ pub trait Connection: SimpleConnection + Sized {
2834
fn transaction<T, E, F>(&self, f: F) -> TransactionResult<T, E> where
2935
F: FnOnce() -> Result<T, E>,
3036
{
31-
try!(self.begin_transaction());
37+
let transaction_manager = self.transaction_manager();
38+
try!(transaction_manager.begin_transaction(self));
3239
match f() {
3340
Ok(value) => {
34-
try!(self.commit_transaction());
41+
try!(transaction_manager.commit_transaction(self));
3542
Ok(value)
3643
},
3744
Err(e) => {
38-
try!(self.rollback_transaction());
45+
try!(transaction_manager.rollback_transaction(self));
3946
Err(TransactionError::UserReturnedError(e))
4047
},
4148
}
@@ -44,8 +51,9 @@ pub trait Connection: SimpleConnection + Sized {
4451
/// Creates a transaction that will never be committed. This is useful for
4552
/// tests. Panics if called while inside of a transaction.
4653
fn begin_test_transaction(&self) -> QueryResult<()> {
47-
assert_eq!(self.get_transaction_depth(), 0);
48-
self.begin_transaction()
54+
let transaction_manager = self.transaction_manager();
55+
assert_eq!(transaction_manager.get_transaction_depth(), 0);
56+
transaction_manager.begin_transaction(self)
4957
}
5058

5159
/// Executes the given function inside a transaction, but does not commit
@@ -86,11 +94,9 @@ pub trait Connection: SimpleConnection + Sized {
8694
fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize> where
8795
T: QueryFragment<Self::Backend> + QueryId;
8896

97+
8998
#[doc(hidden)] fn silence_notices<F: FnOnce() -> T, T>(&self, f: F) -> T;
90-
#[doc(hidden)] fn begin_transaction(&self) -> QueryResult<()>;
91-
#[doc(hidden)] fn rollback_transaction(&self) -> QueryResult<()>;
92-
#[doc(hidden)] fn commit_transaction(&self) -> QueryResult<()>;
93-
#[doc(hidden)] fn get_transaction_depth(&self) -> i32;
99+
#[doc(hidden)] fn transaction_manager(&self) -> &Self::TransactionManager;
94100

95101
#[doc(hidden)] fn setup_helper_functions(&self);
96102
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
use backend::UsesAnsiSavepointSyntax;
2+
use connection::Connection;
3+
use result::QueryResult;
4+
5+
/// Manages the internal transaction state for a connection. You should not
6+
/// interface with this trait unless you are implementing a new connection
7+
/// adapter. You should use [`Connection::transaction`][transaction],
8+
/// [`Connection::test_transaction`][test_transaction], or
9+
/// [`Connection::begin_test_transaction`][begin_test_transaction] instead.
10+
pub trait TransactionManager<Conn: Connection> {
11+
/// Begin a new transaction. If the transaction depth is greater than 0,
12+
/// this should create a savepoint instead. This function is expected to
13+
/// increment the transaction depth by 1.
14+
fn begin_transaction(&self, conn: &Conn) -> QueryResult<()>;
15+
16+
/// Rollback the inner-most transcation. If the transaction depth is greater
17+
/// than 1, this should rollback to the most recent savepoint. This function
18+
/// is expected to decrement the transaction depth by 1.
19+
fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()>;
20+
21+
/// Commit the inner-most transcation. If the transaction depth is greater
22+
/// than 1, this should release the most recent savepoint. This function is
23+
/// expected to decrement the transaction depth by 1.
24+
fn commit_transaction(&self, conn: &Conn) -> QueryResult<()>;
25+
26+
/// Fetch the current transaction depth. Used to ensure that
27+
/// `begin_test_transaction` is not called when already inside of a
28+
/// transaction.
29+
fn get_transaction_depth(&self) -> u32;
30+
}
31+
32+
use std::cell::Cell;
33+
34+
/// An implementation of TransactionManager which can be used for backends
35+
/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
36+
#[allow(missing_debug_implementations)]
37+
pub struct AnsiTransactionManager {
38+
transaction_depth: Cell<i32>,
39+
}
40+
41+
impl AnsiTransactionManager {
42+
pub fn new() -> Self {
43+
AnsiTransactionManager {
44+
transaction_depth: Cell::new(0),
45+
}
46+
}
47+
48+
fn change_transaction_depth(&self, by: i32, query: QueryResult<()>) -> QueryResult<()> {
49+
if query.is_ok() {
50+
self.transaction_depth.set(self.transaction_depth.get() + by)
51+
}
52+
query
53+
}
54+
}
55+
56+
impl<Conn> TransactionManager<Conn> for AnsiTransactionManager where
57+
Conn: Connection,
58+
Conn::Backend: UsesAnsiSavepointSyntax,
59+
{
60+
fn begin_transaction(&self, conn: &Conn) -> QueryResult<()> {
61+
let transaction_depth = self.transaction_depth.get();
62+
self.change_transaction_depth(1, if transaction_depth == 0 {
63+
conn.batch_execute("BEGIN")
64+
} else {
65+
conn.batch_execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth))
66+
})
67+
}
68+
69+
fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()> {
70+
let transaction_depth = self.transaction_depth.get();
71+
self.change_transaction_depth(-1, if transaction_depth == 1 {
72+
conn.batch_execute("ROLLBACK")
73+
} else {
74+
conn.batch_execute(&format!("ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
75+
transaction_depth - 1))
76+
})
77+
}
78+
79+
fn commit_transaction(&self, conn: &Conn) -> QueryResult<()> {
80+
let transaction_depth = self.transaction_depth.get();
81+
self.change_transaction_depth(-1, if transaction_depth <= 1 {
82+
conn.batch_execute("COMMIT")
83+
} else {
84+
conn.batch_execute(&format!("RELEASE SAVEPOINT diesel_savepoint_{}",
85+
transaction_depth - 1))
86+
})
87+
}
88+
89+
fn get_transaction_depth(&self) -> u32 {
90+
self.transaction_depth.get() as u32
91+
}
92+
}

diesel/src/mysql/backend.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ impl TypeMetadata for Mysql {
3939

4040
impl SupportsReturningClause for Mysql {}
4141
impl SupportsDefaultKeyword for Mysql {}
42+
impl UsesAnsiSavepointSyntax for Mysql {}
4243

4344
// FIXME: Move this out of this module
4445
use types::HasSqlType;

diesel/src/mysql/connection/mod.rs

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mod raw;
22
mod url;
33

4-
use connection::{Connection, SimpleConnection};
4+
use connection::{Connection, SimpleConnection, AnsiTransactionManager};
55
use query_builder::*;
66
use query_source::Queryable;
77
use result::*;
@@ -13,6 +13,7 @@ use types::HasSqlType;
1313
#[allow(missing_debug_implementations, missing_copy_implementations)]
1414
pub struct MysqlConnection {
1515
_raw_connection: RawConnection,
16+
transaction_manager: AnsiTransactionManager,
1617
}
1718

1819
impl SimpleConnection for MysqlConnection {
@@ -23,20 +24,24 @@ impl SimpleConnection for MysqlConnection {
2324

2425
impl Connection for MysqlConnection {
2526
type Backend = Mysql;
27+
type TransactionManager = AnsiTransactionManager;
2628

2729
fn establish(database_url: &str) -> ConnectionResult<Self> {
2830
let raw_connection = RawConnection::new();
2931
let connection_options = try!(ConnectionOptions::parse(database_url));
3032
try!(raw_connection.connect(connection_options));
3133
Ok(MysqlConnection {
3234
_raw_connection: raw_connection,
35+
transaction_manager: AnsiTransactionManager::new(),
3336
})
3437
}
3538

39+
#[doc(hidden)]
3640
fn execute(&self, _query: &str) -> QueryResult<usize> {
3741
unimplemented!()
3842
}
3943

44+
#[doc(hidden)]
4045
fn query_all<T, U>(&self, _source: T) -> QueryResult<Vec<U>> where
4146
T: AsQuery,
4247
T::Query: QueryFragment<Self::Backend> + QueryId,
@@ -46,30 +51,22 @@ impl Connection for MysqlConnection {
4651
unimplemented!()
4752
}
4853

54+
#[doc(hidden)]
4955
fn silence_notices<F: FnOnce() -> T, T>(&self, _f: F) -> T {
5056
unimplemented!()
5157
}
5258

59+
#[doc(hidden)]
5360
fn execute_returning_count<T>(&self, _source: &T) -> QueryResult<usize> {
5461
unimplemented!()
5562
}
5663

57-
fn begin_transaction(&self) -> QueryResult<()> {
58-
unimplemented!()
59-
}
60-
61-
fn rollback_transaction(&self) -> QueryResult<()> {
62-
unimplemented!()
63-
}
64-
65-
fn commit_transaction(&self) -> QueryResult<()> {
66-
unimplemented!()
67-
}
68-
69-
fn get_transaction_depth(&self) -> i32 {
70-
unimplemented!()
64+
#[doc(hidden)]
65+
fn transaction_manager(&self) -> &Self::TransactionManager {
66+
&self.transaction_manager
7167
}
7268

69+
#[doc(hidden)]
7370
fn setup_helper_functions(&self) {
7471
unimplemented!()
7572
}

diesel/src/pg/backend.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ impl TypeMetadata for Pg {
2323

2424
impl SupportsReturningClause for Pg {}
2525
impl SupportsDefaultKeyword for Pg {}
26+
impl UsesAnsiSavepointSyntax for Pg {}

diesel/src/pg/connection/mod.rs

Lines changed: 6 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@ mod row;
77
pub mod result;
88
mod stmt;
99

10-
use std::cell::Cell;
1110
use std::ffi::{CString, CStr};
1211
use std::rc::Rc;
1312

14-
use connection::{SimpleConnection, Connection};
13+
use connection::{SimpleConnection, Connection, AnsiTransactionManager};
1514
use pg::{Pg, PgQueryBuilder};
1615
use query_builder::{AsQuery, QueryFragment, QueryId};
1716
use query_builder::bind_collector::RawBytesBindCollector;
@@ -29,7 +28,7 @@ use types::HasSqlType;
2928
#[allow(missing_debug_implementations)]
3029
pub struct PgConnection {
3130
raw_connection: RawConnection,
32-
transaction_depth: Cell<i32>,
31+
transaction_manager: AnsiTransactionManager,
3332
statement_cache: StatementCache,
3433
}
3534

@@ -48,12 +47,13 @@ impl SimpleConnection for PgConnection {
4847

4948
impl Connection for PgConnection {
5049
type Backend = Pg;
50+
type TransactionManager = AnsiTransactionManager;
5151

5252
fn establish(database_url: &str) -> ConnectionResult<PgConnection> {
5353
RawConnection::establish(database_url).map(|raw_conn| {
5454
PgConnection {
5555
raw_connection: raw_conn,
56-
transaction_depth: Cell::new(0),
56+
transaction_manager: AnsiTransactionManager::new(),
5757
statement_cache: StatementCache::new(),
5858
}
5959
})
@@ -94,40 +94,8 @@ impl Connection for PgConnection {
9494
}
9595

9696
#[doc(hidden)]
97-
fn begin_transaction(&self) -> QueryResult<()> {
98-
let transaction_depth = self.transaction_depth.get();
99-
self.change_transaction_depth(1, if transaction_depth == 0 {
100-
self.execute("BEGIN")
101-
} else {
102-
self.execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth))
103-
})
104-
}
105-
106-
#[doc(hidden)]
107-
fn rollback_transaction(&self) -> QueryResult<()> {
108-
let transaction_depth = self.transaction_depth.get();
109-
self.change_transaction_depth(-1, if transaction_depth == 1 {
110-
self.execute("ROLLBACK")
111-
} else {
112-
self.execute(&format!("ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
113-
transaction_depth - 1))
114-
})
115-
}
116-
117-
#[doc(hidden)]
118-
fn commit_transaction(&self) -> QueryResult<()> {
119-
let transaction_depth = self.transaction_depth.get();
120-
self.change_transaction_depth(-1, if transaction_depth <= 1 {
121-
self.execute("COMMIT")
122-
} else {
123-
self.execute(&format!("RELEASE SAVEPOINT diesel_savepoint_{}",
124-
transaction_depth - 1))
125-
})
126-
}
127-
128-
#[doc(hidden)]
129-
fn get_transaction_depth(&self) -> i32 {
130-
self.transaction_depth.get()
97+
fn transaction_manager(&self) -> &Self::TransactionManager {
98+
&self.transaction_manager
13199
}
132100

133101
#[doc(hidden)]
@@ -166,13 +134,6 @@ impl PgConnection {
166134
let query = try!(Query::sql(query, None));
167135
query.execute(&self.raw_connection, &Vec::new())
168136
}
169-
170-
fn change_transaction_depth(&self, by: i32, query: QueryResult<usize>) -> QueryResult<()> {
171-
if query.is_ok() {
172-
self.transaction_depth.set(self.transaction_depth.get() + by);
173-
}
174-
query.map(|_| ())
175-
}
176137
}
177138

178139
extern "C" fn noop_notice_processor(_: *mut libc::c_void, _message: *const libc::c_char) {

diesel/src/sqlite/backend.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,5 @@ impl Backend for Sqlite {
2626
impl TypeMetadata for Sqlite {
2727
type TypeMetadata = SqliteType;
2828
}
29+
30+
impl UsesAnsiSavepointSyntax for Sqlite {}

0 commit comments

Comments
 (0)