gh-105539: Explict resource management for connection objects in sqlite3 tests (#108017)

- Use memory_database() helper
- Move test utility functions to util.py
- Add convenience memory database mixin
- Add check() helper for closed connection tests
This commit is contained in:
Erlend E. Aasland 2023-08-17 08:45:48 +02:00 committed by GitHub
parent c9d83f93d8
commit 1344cfac43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 371 additions and 385 deletions

View File

@ -1,6 +1,8 @@
import sqlite3 as sqlite import sqlite3 as sqlite
import unittest import unittest
from .util import memory_database
class BackupTests(unittest.TestCase): class BackupTests(unittest.TestCase):
def setUp(self): def setUp(self):
@ -32,20 +34,20 @@ def test_bad_target_same_connection(self):
self.cx.backup(self.cx) self.cx.backup(self.cx)
def test_bad_target_closed_connection(self): def test_bad_target_closed_connection(self):
bck = sqlite.connect(':memory:') with memory_database() as bck:
bck.close() bck.close()
with self.assertRaises(sqlite.ProgrammingError): with self.assertRaises(sqlite.ProgrammingError):
self.cx.backup(bck) self.cx.backup(bck)
def test_bad_source_closed_connection(self): def test_bad_source_closed_connection(self):
bck = sqlite.connect(':memory:') with memory_database() as bck:
source = sqlite.connect(":memory:") source = sqlite.connect(":memory:")
source.close() source.close()
with self.assertRaises(sqlite.ProgrammingError): with self.assertRaises(sqlite.ProgrammingError):
source.backup(bck) source.backup(bck)
def test_bad_target_in_transaction(self): def test_bad_target_in_transaction(self):
bck = sqlite.connect(':memory:') with memory_database() as bck:
bck.execute('CREATE TABLE bar (key INTEGER)') bck.execute('CREATE TABLE bar (key INTEGER)')
bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)]) bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -53,11 +55,11 @@ def test_bad_target_in_transaction(self):
def test_keyword_only_args(self): def test_keyword_only_args(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, 1) self.cx.backup(bck, 1)
def test_simple(self): def test_simple(self):
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck) self.cx.backup(bck)
self.verify_backup(bck) self.verify_backup(bck)
@ -67,7 +69,7 @@ def test_progress(self):
def progress(status, remaining, total): def progress(status, remaining, total):
journal.append(status) journal.append(status)
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, pages=1, progress=progress) self.cx.backup(bck, pages=1, progress=progress)
self.verify_backup(bck) self.verify_backup(bck)
@ -81,7 +83,7 @@ def test_progress_all_pages_at_once_1(self):
def progress(status, remaining, total): def progress(status, remaining, total):
journal.append(remaining) journal.append(remaining)
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, progress=progress) self.cx.backup(bck, progress=progress)
self.verify_backup(bck) self.verify_backup(bck)
@ -94,7 +96,7 @@ def test_progress_all_pages_at_once_2(self):
def progress(status, remaining, total): def progress(status, remaining, total):
journal.append(remaining) journal.append(remaining)
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, pages=-1, progress=progress) self.cx.backup(bck, pages=-1, progress=progress)
self.verify_backup(bck) self.verify_backup(bck)
@ -103,7 +105,7 @@ def progress(status, remaining, total):
def test_non_callable_progress(self): def test_non_callable_progress(self):
with self.assertRaises(TypeError) as cm: with self.assertRaises(TypeError) as cm:
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, pages=1, progress='bar') self.cx.backup(bck, pages=1, progress='bar')
self.assertEqual(str(cm.exception), 'progress argument must be a callable') self.assertEqual(str(cm.exception), 'progress argument must be a callable')
@ -116,7 +118,7 @@ def progress(status, remaining, total):
self.cx.commit() self.cx.commit()
journal.append(remaining) journal.append(remaining)
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, pages=1, progress=progress) self.cx.backup(bck, pages=1, progress=progress)
self.verify_backup(bck) self.verify_backup(bck)
@ -140,12 +142,12 @@ def progress(status, remaining, total):
self.assertEqual(str(err.exception), 'nearly out of space') self.assertEqual(str(err.exception), 'nearly out of space')
def test_database_source_name(self): def test_database_source_name(self):
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, name='main') self.cx.backup(bck, name='main')
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, name='temp') self.cx.backup(bck, name='temp')
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, name='non-existing') self.cx.backup(bck, name='non-existing')
self.assertIn("unknown database", str(cm.exception)) self.assertIn("unknown database", str(cm.exception))
@ -153,7 +155,7 @@ def test_database_source_name(self):
self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)') self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)')
self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)]) self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)])
self.cx.commit() self.cx.commit()
with sqlite.connect(':memory:') as bck: with memory_database() as bck:
self.cx.backup(bck, name='attached_db') self.cx.backup(bck, name='attached_db')
self.verify_backup(bck) self.verify_backup(bck)

View File

@ -33,26 +33,13 @@
SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess, SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess,
is_emscripten, is_wasi is_emscripten, is_wasi
) )
from test.support import gc_collect
from test.support import threading_helper from test.support import threading_helper
from _testcapi import INT_MAX, ULLONG_MAX from _testcapi import INT_MAX, ULLONG_MAX
from os import SEEK_SET, SEEK_CUR, SEEK_END from os import SEEK_SET, SEEK_CUR, SEEK_END
from test.support.os_helper import TESTFN, TESTFN_UNDECODABLE, unlink, temp_dir, FakePath from test.support.os_helper import TESTFN, TESTFN_UNDECODABLE, unlink, temp_dir, FakePath
from .util import memory_database, cx_limit
# Helper for temporary memory databases
def memory_database(*args, **kwargs):
cx = sqlite.connect(":memory:", *args, **kwargs)
return contextlib.closing(cx)
# Temporarily limit a database connection parameter
@contextlib.contextmanager
def cx_limit(cx, category=sqlite.SQLITE_LIMIT_SQL_LENGTH, limit=128):
try:
_prev = cx.setlimit(category, limit)
yield limit
finally:
cx.setlimit(category, _prev)
class ModuleTests(unittest.TestCase): class ModuleTests(unittest.TestCase):
@ -326,7 +313,7 @@ def test_extended_error_code_on_exception(self):
self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK") self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK")
def test_disallow_instantiation(self): def test_disallow_instantiation(self):
cx = sqlite.connect(":memory:") with memory_database() as cx:
check_disallow_instantiation(self, type(cx("select 1"))) check_disallow_instantiation(self, type(cx("select 1")))
check_disallow_instantiation(self, sqlite.Blob) check_disallow_instantiation(self, sqlite.Blob)
@ -342,6 +329,7 @@ def setUp(self):
cu = self.cx.cursor() cu = self.cx.cursor()
cu.execute("create table test(id integer primary key, name text)") cu.execute("create table test(id integer primary key, name text)")
cu.execute("insert into test(name) values (?)", ("foo",)) cu.execute("insert into test(name) values (?)", ("foo",))
cu.close()
def tearDown(self): def tearDown(self):
self.cx.close() self.cx.close()
@ -412,7 +400,7 @@ def test_exceptions(self):
def test_in_transaction(self): def test_in_transaction(self):
# Can't use db from setUp because we want to test initial state. # Can't use db from setUp because we want to test initial state.
cx = sqlite.connect(":memory:") with memory_database() as cx:
cu = cx.cursor() cu = cx.cursor()
self.assertEqual(cx.in_transaction, False) self.assertEqual(cx.in_transaction, False)
cu.execute("create table transactiontest(id integer primary key, name text)") cu.execute("create table transactiontest(id integer primary key, name text)")
@ -427,6 +415,7 @@ def test_in_transaction(self):
cu.execute("select name from transactiontest where name=?", ["foo"]) cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone() row = cu.fetchone()
self.assertEqual(cx.in_transaction, False) self.assertEqual(cx.in_transaction, False)
cu.close()
def test_in_transaction_ro(self): def test_in_transaction_ro(self):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
@ -450,10 +439,9 @@ def test_connection_exceptions(self):
self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc)) self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc))
def test_interrupt_on_closed_db(self): def test_interrupt_on_closed_db(self):
cx = sqlite.connect(":memory:") self.cx.close()
cx.close()
with self.assertRaises(sqlite.ProgrammingError): with self.assertRaises(sqlite.ProgrammingError):
cx.interrupt() self.cx.interrupt()
def test_interrupt(self): def test_interrupt(self):
self.assertIsNone(self.cx.interrupt()) self.assertIsNone(self.cx.interrupt())
@ -521,8 +509,7 @@ def test_connection_init_good_isolation_levels(self):
self.assertEqual(cx.isolation_level, level) self.assertEqual(cx.isolation_level, level)
def test_connection_reinit(self): def test_connection_reinit(self):
db = ":memory:" with memory_database() as cx:
cx = sqlite.connect(db)
cx.text_factory = bytes cx.text_factory = bytes
cx.row_factory = sqlite.Row cx.row_factory = sqlite.Row
cu = cx.cursor() cu = cx.cursor()
@ -535,7 +522,7 @@ def test_connection_reinit(self):
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], [b"0", b"1"]) self.assertEqual([r[0] for r in rows], [b"0", b"1"])
cx.__init__(db) cx.__init__(":memory:")
cx.execute("create table foo (bar)") cx.execute("create table foo (bar)")
cx.executemany("insert into foo (bar) values (?)", cx.executemany("insert into foo (bar) values (?)",
((v,) for v in ("a", "b", "c", "d"))) ((v,) for v in ("a", "b", "c", "d")))
@ -544,6 +531,7 @@ def test_connection_reinit(self):
rows = [r for r in cu.fetchall()] rows = [r for r in cu.fetchall()]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], ["2", "3"]) self.assertEqual([r[0] for r in rows], ["2", "3"])
cu.close()
def test_connection_bad_reinit(self): def test_connection_bad_reinit(self):
cx = sqlite.connect(":memory:") cx = sqlite.connect(":memory:")
@ -591,11 +579,11 @@ def test_connect_positional_arguments(self):
"parameters in Python 3.15." "parameters in Python 3.15."
) )
with self.assertWarnsRegex(DeprecationWarning, regex) as cm: with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
sqlite.connect(":memory:", 1.0) cx = sqlite.connect(":memory:", 1.0)
cx.close()
self.assertEqual(cm.filename, __file__) self.assertEqual(cm.filename, __file__)
class UninitialisedConnectionTests(unittest.TestCase): class UninitialisedConnectionTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.cx = sqlite.Connection.__new__(sqlite.Connection) self.cx = sqlite.Connection.__new__(sqlite.Connection)
@ -1571,7 +1559,7 @@ def run(con, err):
except sqlite.Error: except sqlite.Error:
err.append("multi-threading not allowed") err.append("multi-threading not allowed")
con = sqlite.connect(":memory:", check_same_thread=False) with memory_database(check_same_thread=False) as con:
err = [] err = []
t = threading.Thread(target=run, kwargs={"con": con, "err": err}) t = threading.Thread(target=run, kwargs={"con": con, "err": err})
t.start() t.start()
@ -1602,9 +1590,16 @@ def test_binary(self):
b = sqlite.Binary(b"\0'") b = sqlite.Binary(b"\0'")
class ExtensionTests(unittest.TestCase): class ExtensionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
def tearDown(self):
self.cur.close()
self.con.close()
def test_script_string_sql(self): def test_script_string_sql(self):
con = sqlite.connect(":memory:") cur = self.cur
cur = con.cursor()
cur.executescript(""" cur.executescript("""
-- bla bla -- bla bla
/* a stupid comment */ /* a stupid comment */
@ -1616,39 +1611,39 @@ def test_script_string_sql(self):
self.assertEqual(res, 5) self.assertEqual(res, 5)
def test_script_syntax_error(self): def test_script_syntax_error(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
cur.executescript("create table test(x); asdf; create table test2(x)") self.cur.executescript("""
CREATE TABLE test(x);
asdf;
CREATE TABLE test2(x)
""")
def test_script_error_normal(self): def test_script_error_normal(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
cur.executescript("create table test(sadfsadfdsa); select foo from hurz;") self.cur.executescript("""
CREATE TABLE test(sadfsadfdsa);
SELECT foo FROM hurz;
""")
def test_cursor_executescript_as_bytes(self): def test_cursor_executescript_as_bytes(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
cur.executescript(b"create table test(foo); insert into test(foo) values (5);") self.cur.executescript(b"""
CREATE TABLE test(foo);
INSERT INTO test(foo) VALUES (5);
""")
def test_cursor_executescript_with_null_characters(self): def test_cursor_executescript_with_null_characters(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
cur.executescript(""" self.cur.executescript("""
create table a(i);\0 CREATE TABLE a(i);\0
insert into a(i) values (5); INSERT INTO a(i) VALUES (5);
""") """)
def test_cursor_executescript_with_surrogates(self): def test_cursor_executescript_with_surrogates(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(UnicodeEncodeError): with self.assertRaises(UnicodeEncodeError):
cur.executescript(""" self.cur.executescript("""
create table a(s); CREATE TABLE a(s);
insert into a(s) values ('\ud8ff'); INSERT INTO a(s) VALUES ('\ud8ff');
""") """)
def test_cursor_executescript_too_large_script(self): def test_cursor_executescript_too_large_script(self):
@ -1659,19 +1654,18 @@ def test_cursor_executescript_too_large_script(self):
cx.executescript("select 'too large'".ljust(lim+1)) cx.executescript("select 'too large'".ljust(lim+1))
def test_cursor_executescript_tx_control(self): def test_cursor_executescript_tx_control(self):
con = sqlite.connect(":memory:") con = self.con
con.execute("begin") con.execute("begin")
self.assertTrue(con.in_transaction) self.assertTrue(con.in_transaction)
con.executescript("select 1") con.executescript("select 1")
self.assertFalse(con.in_transaction) self.assertFalse(con.in_transaction)
def test_connection_execute(self): def test_connection_execute(self):
con = sqlite.connect(":memory:") result = self.con.execute("select 5").fetchone()[0]
result = con.execute("select 5").fetchone()[0]
self.assertEqual(result, 5, "Basic test of Connection.execute") self.assertEqual(result, 5, "Basic test of Connection.execute")
def test_connection_executemany(self): def test_connection_executemany(self):
con = sqlite.connect(":memory:") con = self.con
con.execute("create table test(foo)") con.execute("create table test(foo)")
con.executemany("insert into test(foo) values (?)", [(3,), (4,)]) con.executemany("insert into test(foo) values (?)", [(3,), (4,)])
result = con.execute("select foo from test order by foo").fetchall() result = con.execute("select foo from test order by foo").fetchall()
@ -1679,47 +1673,44 @@ def test_connection_executemany(self):
self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany") self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany")
def test_connection_executescript(self): def test_connection_executescript(self):
con = sqlite.connect(":memory:") con = self.con
con.executescript("create table test(foo); insert into test(foo) values (5);") con.executescript("""
CREATE TABLE test(foo);
INSERT INTO test(foo) VALUES (5);
""")
result = con.execute("select foo from test").fetchone()[0] result = con.execute("select foo from test").fetchone()[0]
self.assertEqual(result, 5, "Basic test of Connection.executescript") self.assertEqual(result, 5, "Basic test of Connection.executescript")
class ClosedConTests(unittest.TestCase): class ClosedConTests(unittest.TestCase):
def check(self, fn, *args, **kwds):
regex = "Cannot operate on a closed database."
with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
fn(*args, **kwds)
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
self.con.close()
def test_closed_con_cursor(self): def test_closed_con_cursor(self):
con = sqlite.connect(":memory:") self.check(self.con.cursor)
con.close()
with self.assertRaises(sqlite.ProgrammingError):
cur = con.cursor()
def test_closed_con_commit(self): def test_closed_con_commit(self):
con = sqlite.connect(":memory:") self.check(self.con.commit)
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con.commit()
def test_closed_con_rollback(self): def test_closed_con_rollback(self):
con = sqlite.connect(":memory:") self.check(self.con.rollback)
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con.rollback()
def test_closed_cur_execute(self): def test_closed_cur_execute(self):
con = sqlite.connect(":memory:") self.check(self.cur.execute, "select 4")
cur = con.cursor()
con.close()
with self.assertRaises(sqlite.ProgrammingError):
cur.execute("select 4")
def test_closed_create_function(self): def test_closed_create_function(self):
con = sqlite.connect(":memory:") def f(x):
con.close() return 17
def f(x): return 17 self.check(self.con.create_function, "foo", 1, f)
with self.assertRaises(sqlite.ProgrammingError):
con.create_function("foo", 1, f)
def test_closed_create_aggregate(self): def test_closed_create_aggregate(self):
con = sqlite.connect(":memory:")
con.close()
class Agg: class Agg:
def __init__(self): def __init__(self):
pass pass
@ -1727,29 +1718,21 @@ def step(self, x):
pass pass
def finalize(self): def finalize(self):
return 17 return 17
with self.assertRaises(sqlite.ProgrammingError): self.check(self.con.create_aggregate, "foo", 1, Agg)
con.create_aggregate("foo", 1, Agg)
def test_closed_set_authorizer(self): def test_closed_set_authorizer(self):
con = sqlite.connect(":memory:")
con.close()
def authorizer(*args): def authorizer(*args):
return sqlite.DENY return sqlite.DENY
with self.assertRaises(sqlite.ProgrammingError): self.check(self.con.set_authorizer, authorizer)
con.set_authorizer(authorizer)
def test_closed_set_progress_callback(self): def test_closed_set_progress_callback(self):
con = sqlite.connect(":memory:") def progress():
con.close() pass
def progress(): pass self.check(self.con.set_progress_handler, progress, 100)
with self.assertRaises(sqlite.ProgrammingError):
con.set_progress_handler(progress, 100)
def test_closed_call(self): def test_closed_call(self):
con = sqlite.connect(":memory:") self.check(self.con)
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con()
class ClosedCurTests(unittest.TestCase): class ClosedCurTests(unittest.TestCase):
def test_closed(self): def test_closed(self):

View File

@ -2,16 +2,12 @@
import unittest import unittest
import sqlite3 as sqlite import sqlite3 as sqlite
from .test_dbapi import memory_database
from .util import memory_database
from .util import MemoryDatabaseMixin
class DumpTests(unittest.TestCase): class DumpTests(MemoryDatabaseMixin, unittest.TestCase):
def setUp(self):
self.cx = sqlite.connect(":memory:")
self.cu = self.cx.cursor()
def tearDown(self):
self.cx.close()
def test_table_dump(self): def test_table_dump(self):
expected_sqls = [ expected_sqls = [

View File

@ -24,6 +24,9 @@
import sqlite3 as sqlite import sqlite3 as sqlite
from collections.abc import Sequence from collections.abc import Sequence
from .util import memory_database
from .util import MemoryDatabaseMixin
def dict_factory(cursor, row): def dict_factory(cursor, row):
d = {} d = {}
@ -45,10 +48,12 @@ class OkFactory(sqlite.Connection):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
sqlite.Connection.__init__(self, *args, **kwargs) sqlite.Connection.__init__(self, *args, **kwargs)
for factory in DefectFactory, OkFactory: with memory_database(factory=OkFactory) as con:
with self.subTest(factory=factory): self.assertIsInstance(con, OkFactory)
con = sqlite.connect(":memory:", factory=factory) regex = "Base Connection.__init__ not called."
self.assertIsInstance(con, factory) with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
with memory_database(factory=DefectFactory) as con:
self.assertIsInstance(con, DefectFactory)
def test_connection_factory_relayed_call(self): def test_connection_factory_relayed_call(self):
# gh-95132: keyword args must not be passed as positional args # gh-95132: keyword args must not be passed as positional args
@ -57,7 +62,7 @@ def __init__(self, *args, **kwargs):
kwargs["isolation_level"] = None kwargs["isolation_level"] = None
super(Factory, self).__init__(*args, **kwargs) super(Factory, self).__init__(*args, **kwargs)
con = sqlite.connect(":memory:", factory=Factory) with memory_database(factory=Factory) as con:
self.assertIsNone(con.isolation_level) self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory) self.assertIsInstance(con, Factory)
@ -74,18 +79,13 @@ def __init__(self, *args, **kwargs):
r"parameters in Python 3.15." r"parameters in Python 3.15."
) )
with self.assertWarnsRegex(DeprecationWarning, regex) as cm: with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
con = sqlite.connect(":memory:", 5.0, 0, None, True, Factory) with memory_database(5.0, 0, None, True, Factory) as con:
self.assertEqual(cm.filename, __file__)
self.assertIsNone(con.isolation_level) self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory) self.assertIsInstance(con, Factory)
self.assertEqual(cm.filename, __file__)
class CursorFactoryTests(unittest.TestCase): class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def tearDown(self):
self.con.close()
def test_is_instance(self): def test_is_instance(self):
cur = self.con.cursor() cur = self.con.cursor()
@ -103,9 +103,8 @@ def test_invalid_factory(self):
# invalid callable returning non-cursor # invalid callable returning non-cursor
self.assertRaises(TypeError, self.con.cursor, lambda con: None) self.assertRaises(TypeError, self.con.cursor, lambda con: None)
class RowFactoryTestsBackwardsCompat(unittest.TestCase):
def setUp(self): class RowFactoryTestsBackwardsCompat(MemoryDatabaseMixin, unittest.TestCase):
self.con = sqlite.connect(":memory:")
def test_is_produced_by_factory(self): def test_is_produced_by_factory(self):
cur = self.con.cursor(factory=MyCursor) cur = self.con.cursor(factory=MyCursor)
@ -114,12 +113,8 @@ def test_is_produced_by_factory(self):
self.assertIsInstance(row, dict) self.assertIsInstance(row, dict)
cur.close() cur.close()
def tearDown(self):
self.con.close()
class RowFactoryTests(unittest.TestCase): class RowFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def test_custom_factory(self): def test_custom_factory(self):
self.con.row_factory = lambda cur, row: list(row) self.con.row_factory = lambda cur, row: list(row)
@ -265,12 +260,8 @@ class FakeCursor(str):
self.assertRaises(TypeError, self.con.cursor, FakeCursor) self.assertRaises(TypeError, self.con.cursor, FakeCursor)
self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ()) self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
def tearDown(self):
self.con.close()
class TextFactoryTests(unittest.TestCase): class TextFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def test_unicode(self): def test_unicode(self):
austria = "Österreich" austria = "Österreich"
@ -291,15 +282,17 @@ def test_custom(self):
self.assertEqual(type(row[0]), str, "type of row[0] must be unicode") self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
self.assertTrue(row[0].endswith("reich"), "column must contain original data") self.assertTrue(row[0].endswith("reich"), "column must contain original data")
def tearDown(self):
self.con.close()
class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase): class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
def setUp(self): def setUp(self):
self.con = sqlite.connect(":memory:") self.con = sqlite.connect(":memory:")
self.con.execute("create table test (value text)") self.con.execute("create table test (value text)")
self.con.execute("insert into test (value) values (?)", ("a\x00b",)) self.con.execute("insert into test (value) values (?)", ("a\x00b",))
def tearDown(self):
self.con.close()
def test_string(self): def test_string(self):
# text_factory defaults to str # text_factory defaults to str
row = self.con.execute("select value from test").fetchone() row = self.con.execute("select value from test").fetchone()
@ -325,9 +318,6 @@ def test_custom(self):
self.assertIs(type(row[0]), bytes) self.assertIs(type(row[0]), bytes)
self.assertEqual(row[0], b"a\x00b") self.assertEqual(row[0], b"a\x00b")
def tearDown(self):
self.con.close()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -26,34 +26,31 @@
from test.support.os_helper import TESTFN, unlink from test.support.os_helper import TESTFN, unlink
from test.test_sqlite3.test_dbapi import memory_database, cx_limit from .util import memory_database, cx_limit, with_tracebacks
from test.test_sqlite3.test_userfunctions import with_tracebacks from .util import MemoryDatabaseMixin
class CollationTests(unittest.TestCase): class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
def test_create_collation_not_string(self): def test_create_collation_not_string(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
con.create_collation(None, lambda x, y: (x > y) - (x < y)) self.con.create_collation(None, lambda x, y: (x > y) - (x < y))
def test_create_collation_not_callable(self): def test_create_collation_not_callable(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError) as cm: with self.assertRaises(TypeError) as cm:
con.create_collation("X", 42) self.con.create_collation("X", 42)
self.assertEqual(str(cm.exception), 'parameter must be callable') self.assertEqual(str(cm.exception), 'parameter must be callable')
def test_create_collation_not_ascii(self): def test_create_collation_not_ascii(self):
con = sqlite.connect(":memory:") self.con.create_collation("collä", lambda x, y: (x > y) - (x < y))
con.create_collation("collä", lambda x, y: (x > y) - (x < y))
def test_create_collation_bad_upper(self): def test_create_collation_bad_upper(self):
class BadUpperStr(str): class BadUpperStr(str):
def upper(self): def upper(self):
return None return None
con = sqlite.connect(":memory:")
mycoll = lambda x, y: -((x > y) - (x < y)) mycoll = lambda x, y: -((x > y) - (x < y))
con.create_collation(BadUpperStr("mycoll"), mycoll) self.con.create_collation(BadUpperStr("mycoll"), mycoll)
result = con.execute(""" result = self.con.execute("""
select x from ( select x from (
select 'a' as x select 'a' as x
union union
@ -68,8 +65,7 @@ def mycoll(x, y):
# reverse order # reverse order
return -((x > y) - (x < y)) return -((x > y) - (x < y))
con = sqlite.connect(":memory:") self.con.create_collation("mycoll", mycoll)
con.create_collation("mycoll", mycoll)
sql = """ sql = """
select x from ( select x from (
select 'a' as x select 'a' as x
@ -79,21 +75,20 @@ def mycoll(x, y):
select 'c' as x select 'c' as x
) order by x collate mycoll ) order by x collate mycoll
""" """
result = con.execute(sql).fetchall() result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)], self.assertEqual(result, [('c',), ('b',), ('a',)],
msg='the expected order was not returned') msg='the expected order was not returned')
con.create_collation("mycoll", None) self.con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
result = con.execute(sql).fetchall() result = self.con.execute(sql).fetchall()
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
def test_collation_returns_large_integer(self): def test_collation_returns_large_integer(self):
def mycoll(x, y): def mycoll(x, y):
# reverse order # reverse order
return -((x > y) - (x < y)) * 2**32 return -((x > y) - (x < y)) * 2**32
con = sqlite.connect(":memory:") self.con.create_collation("mycoll", mycoll)
con.create_collation("mycoll", mycoll)
sql = """ sql = """
select x from ( select x from (
select 'a' as x select 'a' as x
@ -103,7 +98,7 @@ def mycoll(x, y):
select 'c' as x select 'c' as x
) order by x collate mycoll ) order by x collate mycoll
""" """
result = con.execute(sql).fetchall() result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)], self.assertEqual(result, [('c',), ('b',), ('a',)],
msg="the expected order was not returned") msg="the expected order was not returned")
@ -112,7 +107,7 @@ def test_collation_register_twice(self):
Register two different collation functions under the same name. Register two different collation functions under the same name.
Verify that the last one is actually used. Verify that the last one is actually used.
""" """
con = sqlite.connect(":memory:") con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y))) con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
result = con.execute(""" result = con.execute("""
@ -126,25 +121,26 @@ def test_deregister_collation(self):
Register a collation, then deregister it. Make sure an error is raised if we try Register a collation, then deregister it. Make sure an error is raised if we try
to use it. to use it.
""" """
con = sqlite.connect(":memory:") con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", None) con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
con.execute("select 'a' as x union select 'b' as x order by x collate mycoll") con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
class ProgressTests(unittest.TestCase):
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
def test_progress_handler_used(self): def test_progress_handler_used(self):
""" """
Test that the progress handler is invoked once it is set. Test that the progress handler is invoked once it is set.
""" """
con = sqlite.connect(":memory:")
progress_calls = [] progress_calls = []
def progress(): def progress():
progress_calls.append(None) progress_calls.append(None)
return 0 return 0
con.set_progress_handler(progress, 1) self.con.set_progress_handler(progress, 1)
con.execute(""" self.con.execute("""
create table foo(a, b) create table foo(a, b)
""") """)
self.assertTrue(progress_calls) self.assertTrue(progress_calls)
@ -153,7 +149,7 @@ def test_opcode_count(self):
""" """
Test that the opcode argument is respected. Test that the opcode argument is respected.
""" """
con = sqlite.connect(":memory:") con = self.con
progress_calls = [] progress_calls = []
def progress(): def progress():
progress_calls.append(None) progress_calls.append(None)
@ -176,11 +172,10 @@ def test_cancel_operation(self):
""" """
Test that returning a non-zero value stops the operation in progress. Test that returning a non-zero value stops the operation in progress.
""" """
con = sqlite.connect(":memory:")
def progress(): def progress():
return 1 return 1
con.set_progress_handler(progress, 1) self.con.set_progress_handler(progress, 1)
curs = con.cursor() curs = self.con.cursor()
self.assertRaises( self.assertRaises(
sqlite.OperationalError, sqlite.OperationalError,
curs.execute, curs.execute,
@ -190,7 +185,7 @@ def test_clear_handler(self):
""" """
Test that setting the progress handler to None clears the previously set handler. Test that setting the progress handler to None clears the previously set handler.
""" """
con = sqlite.connect(":memory:") con = self.con
action = 0 action = 0
def progress(): def progress():
nonlocal action nonlocal action
@ -203,31 +198,30 @@ def progress():
@with_tracebacks(ZeroDivisionError, name="bad_progress") @with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler(self): def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:")
def bad_progress(): def bad_progress():
1 / 0 1 / 0
con.set_progress_handler(bad_progress, 1) self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
con.execute(""" self.con.execute("""
create table foo(a, b) create table foo(a, b)
""") """)
@with_tracebacks(ZeroDivisionError, name="bad_progress") @with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler_result(self): def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:")
class BadBool: class BadBool:
def __bool__(self): def __bool__(self):
1 / 0 1 / 0
def bad_progress(): def bad_progress():
return BadBool() return BadBool()
con.set_progress_handler(bad_progress, 1) self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
con.execute(""" self.con.execute("""
create table foo(a, b) create table foo(a, b)
""") """)
class TraceCallbackTests(unittest.TestCase): class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
@contextlib.contextmanager @contextlib.contextmanager
def check_stmt_trace(self, cx, expected): def check_stmt_trace(self, cx, expected):
try: try:
@ -242,12 +236,11 @@ def test_trace_callback_used(self):
""" """
Test that the trace callback is invoked once it is set. Test that the trace callback is invoked once it is set.
""" """
con = sqlite.connect(":memory:")
traced_statements = [] traced_statements = []
def trace(statement): def trace(statement):
traced_statements.append(statement) traced_statements.append(statement)
con.set_trace_callback(trace) self.con.set_trace_callback(trace)
con.execute("create table foo(a, b)") self.con.execute("create table foo(a, b)")
self.assertTrue(traced_statements) self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
@ -255,7 +248,7 @@ def test_clear_trace_callback(self):
""" """
Test that setting the trace callback to None clears the previously set callback. Test that setting the trace callback to None clears the previously set callback.
""" """
con = sqlite.connect(":memory:") con = self.con
traced_statements = [] traced_statements = []
def trace(statement): def trace(statement):
traced_statements.append(statement) traced_statements.append(statement)
@ -269,7 +262,7 @@ def test_unicode_content(self):
Test that the statement can contain unicode literals. Test that the statement can contain unicode literals.
""" """
unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac' unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
con = sqlite.connect(":memory:") con = self.con
traced_statements = [] traced_statements = []
def trace(statement): def trace(statement):
traced_statements.append(statement) traced_statements.append(statement)

View File

@ -28,15 +28,12 @@
from test import support from test import support
from unittest.mock import patch from unittest.mock import patch
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from .util import memory_database, cx_limit
from .util import MemoryDatabaseMixin
class RegressionTests(unittest.TestCase): class RegressionTests(MemoryDatabaseMixin, unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def tearDown(self):
self.con.close()
def test_pragma_user_version(self): def test_pragma_user_version(self):
# This used to crash pysqlite because this pragma command returns NULL for the column name # This used to crash pysqlite because this pragma command returns NULL for the column name
@ -45,19 +42,15 @@ def test_pragma_user_version(self):
def test_pragma_schema_version(self): def test_pragma_schema_version(self):
# This still crashed pysqlite <= 2.2.1 # This still crashed pysqlite <= 2.2.1
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES) with memory_database(detect_types=sqlite.PARSE_COLNAMES) as con:
try:
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("pragma schema_version") cur.execute("pragma schema_version")
finally:
cur.close()
con.close()
def test_statement_reset(self): def test_statement_reset(self):
# pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are # pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are
# reset before a rollback, but only those that are still in the # reset before a rollback, but only those that are still in the
# statement cache. The others are not accessible from the connection object. # statement cache. The others are not accessible from the connection object.
con = sqlite.connect(":memory:", cached_statements=5) with memory_database(cached_statements=5) as con:
cursors = [con.cursor() for x in range(5)] cursors = [con.cursor() for x in range(5)]
cursors[0].execute("create table test(x)") cursors[0].execute("create table test(x)")
for i in range(10): for i in range(10):
@ -81,17 +74,15 @@ def test_statement_finalization_on_close_db(self):
# cache when closing the database. statements that were still # cache when closing the database. statements that were still
# referenced in cursors weren't closed and could provoke " # referenced in cursors weren't closed and could provoke "
# "OperationalError: Unable to close due to unfinalised statements". # "OperationalError: Unable to close due to unfinalised statements".
con = sqlite.connect(":memory:")
cursors = [] cursors = []
# default statement cache size is 100 # default statement cache size is 100
for i in range(105): for i in range(105):
cur = con.cursor() cur = self.con.cursor()
cursors.append(cur) cursors.append(cur)
cur.execute("select 1 x union select " + str(i)) cur.execute("select 1 x union select " + str(i))
con.close()
def test_on_conflict_rollback(self): def test_on_conflict_rollback(self):
con = sqlite.connect(":memory:") con = self.con
con.execute("create table foo(x, unique(x) on conflict rollback)") con.execute("create table foo(x, unique(x) on conflict rollback)")
con.execute("insert into foo(x) values (1)") con.execute("insert into foo(x) values (1)")
try: try:
@ -126,7 +117,7 @@ def test_type_map_usage(self):
a statement. This test exhibits the problem. a statement. This test exhibits the problem.
""" """
SELECT = "select * from foo" SELECT = "select * from foo"
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
cur = con.cursor() cur = con.cursor()
cur.execute("create table foo(bar timestamp)") cur.execute("create table foo(bar timestamp)")
with self.assertWarnsRegex(DeprecationWarning, "adapter"): with self.assertWarnsRegex(DeprecationWarning, "adapter"):
@ -144,7 +135,7 @@ def __conform__(self, protocol):
parameters.clear() parameters.clear()
return "..." return "..."
parameters = [X(), 0] parameters = [X(), 0]
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
con.execute("create table foo(bar X, baz integer)") con.execute("create table foo(bar X, baz integer)")
# Should not crash # Should not crash
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
@ -173,7 +164,7 @@ def upper(self):
def __del__(self): def __del__(self):
con.isolation_level = "" con.isolation_level = ""
con = sqlite.connect(":memory:") con = self.con
con.isolation_level = None con.isolation_level = None
for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE": for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE":
with self.subTest(level=level): with self.subTest(level=level):
@ -204,8 +195,7 @@ class Cursor(sqlite.Cursor):
def __init__(self, con): def __init__(self, con):
pass pass
con = sqlite.connect(":memory:") cur = Cursor(self.con)
cur = Cursor(con)
with self.assertRaises(sqlite.ProgrammingError): with self.assertRaises(sqlite.ProgrammingError):
cur.execute("select 4+5").fetchall() cur.execute("select 4+5").fetchall()
with self.assertRaisesRegex(sqlite.ProgrammingError, with self.assertRaisesRegex(sqlite.ProgrammingError,
@ -238,7 +228,9 @@ def test_auto_commit(self):
2.5.3 introduced a regression so that these could no longer 2.5.3 introduced a regression so that these could no longer
be created. be created.
""" """
con = sqlite.connect(":memory:", isolation_level=None) with memory_database(isolation_level=None) as con:
self.assertIsNone(con.isolation_level)
self.assertFalse(con.in_transaction)
def test_pragma_autocommit(self): def test_pragma_autocommit(self):
""" """
@ -273,9 +265,7 @@ def test_recursive_cursor_use(self):
Recursively using a cursor, such as when reusing it from a generator led to segfaults. Recursively using a cursor, such as when reusing it from a generator led to segfaults.
Now we catch recursive cursor usage and raise a ProgrammingError. Now we catch recursive cursor usage and raise a ProgrammingError.
""" """
con = sqlite.connect(":memory:") cur = self.con.cursor()
cur = con.cursor()
cur.execute("create table a (bar)") cur.execute("create table a (bar)")
cur.execute("create table b (baz)") cur.execute("create table b (baz)")
@ -295,7 +285,7 @@ def test_convert_timestamp_microsecond_padding(self):
since the microsecond string "456" actually represents "456000". since the microsecond string "456" actually represents "456000".
""" """
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES) with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
cur = con.cursor() cur = con.cursor()
cur.execute("CREATE TABLE t (x TIMESTAMP)") cur.execute("CREATE TABLE t (x TIMESTAMP)")
@ -316,8 +306,9 @@ def test_convert_timestamp_microsecond_padding(self):
def test_invalid_isolation_level_type(self): def test_invalid_isolation_level_type(self):
# isolation level is a string, not an integer # isolation level is a string, not an integer
self.assertRaises(TypeError, regex = "isolation_level must be str or None"
sqlite.connect, ":memory:", isolation_level=123) with self.assertRaisesRegex(TypeError, regex):
memory_database(isolation_level=123).__enter__()
def test_null_character(self): def test_null_character(self):
@ -333,7 +324,7 @@ def test_null_character(self):
cur.execute, query) cur.execute, query)
def test_surrogates(self): def test_surrogates(self):
con = sqlite.connect(":memory:") con = self.con
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'") self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, con, "select '\udcff'") self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
cur = con.cursor() cur = con.cursor()
@ -359,7 +350,7 @@ def test_commit_cursor_reset(self):
to return rows multiple times when fetched from cursors to return rows multiple times when fetched from cursors
after commit. See issues 10513 and 23129 for details. after commit. See issues 10513 and 23129 for details.
""" """
con = sqlite.connect(":memory:") con = self.con
con.executescript(""" con.executescript("""
create table t(c); create table t(c);
create table t2(c); create table t2(c);
@ -391,10 +382,9 @@ def test_bpo31770(self):
""" """
def callback(*args): def callback(*args):
pass pass
con = sqlite.connect(":memory:") cur = sqlite.Cursor(self.con)
cur = sqlite.Cursor(con)
ref = weakref.ref(cur, callback) ref = weakref.ref(cur, callback)
cur.__init__(con) cur.__init__(self.con)
del cur del cur
# The interpreter shouldn't crash when ref is collected. # The interpreter shouldn't crash when ref is collected.
del ref del ref
@ -425,6 +415,7 @@ def test_return_empty_bytestring(self):
def test_table_lock_cursor_replace_stmt(self): def test_table_lock_cursor_replace_stmt(self):
with memory_database() as con: with memory_database() as con:
con = self.con
cur = con.cursor() cur = con.cursor()
cur.execute("create table t(t)") cur.execute("create table t(t)")
cur.executemany("insert into t values(?)", cur.executemany("insert into t values(?)",

View File

@ -28,7 +28,8 @@
from test.support.os_helper import TESTFN, unlink from test.support.os_helper import TESTFN, unlink
from test.support.script_helper import assert_python_ok from test.support.script_helper import assert_python_ok
from test.test_sqlite3.test_dbapi import memory_database from .util import memory_database
from .util import MemoryDatabaseMixin
TIMEOUT = LOOPBACK_TIMEOUT / 10 TIMEOUT = LOOPBACK_TIMEOUT / 10
@ -132,7 +133,7 @@ def test_locking(self):
def test_rollback_cursor_consistency(self): def test_rollback_cursor_consistency(self):
"""Check that cursors behave correctly after rollback.""" """Check that cursors behave correctly after rollback."""
con = sqlite.connect(":memory:") with memory_database() as con:
cur = con.cursor() cur = con.cursor()
cur.execute("create table test(x)") cur.execute("create table test(x)")
cur.execute("insert into test(x) values (5)") cur.execute("insert into test(x) values (5)")
@ -218,10 +219,7 @@ def test_no_duplicate_rows_after_rollback_new_query(self):
class SpecialCommandTests(unittest.TestCase): class SpecialCommandTests(MemoryDatabaseMixin, unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
def test_drop_table(self): def test_drop_table(self):
self.cur.execute("create table test(i)") self.cur.execute("create table test(i)")
@ -233,14 +231,8 @@ def test_pragma(self):
self.cur.execute("insert into test(i) values (5)") self.cur.execute("insert into test(i) values (5)")
self.cur.execute("pragma count_changes=1") self.cur.execute("pragma count_changes=1")
def tearDown(self):
self.cur.close()
self.con.close()
class TransactionalDDL(MemoryDatabaseMixin, unittest.TestCase):
class TransactionalDDL(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def test_ddl_does_not_autostart_transaction(self): def test_ddl_does_not_autostart_transaction(self):
# For backwards compatibility reasons, DDL statements should not # For backwards compatibility reasons, DDL statements should not
@ -268,9 +260,6 @@ def test_transactional_ddl(self):
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
self.con.execute("select * from test") self.con.execute("select * from test")
def tearDown(self):
self.con.close()
class IsolationLevelFromInit(unittest.TestCase): class IsolationLevelFromInit(unittest.TestCase):
CREATE = "create table t(t)" CREATE = "create table t(t)"

View File

@ -21,54 +21,15 @@
# misrepresented as being the original software. # misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution. # 3. This notice may not be removed or altered from any source distribution.
import contextlib
import functools
import io
import re
import sys import sys
import unittest import unittest
import sqlite3 as sqlite import sqlite3 as sqlite
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from test.support import bigmemtest, catch_unraisable_exception, gc_collect from test.support import bigmemtest, gc_collect
from test.test_sqlite3.test_dbapi import cx_limit from .util import cx_limit, memory_database
from .util import with_tracebacks, check_tracebacks
def with_tracebacks(exc, regex="", name=""):
"""Convenience decorator for testing callback tracebacks."""
def decorator(func):
_regex = re.compile(regex) if regex else None
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with catch_unraisable_exception() as cm:
# First, run the test with traceback enabled.
with check_tracebacks(self, cm, exc, _regex, name):
func(self, *args, **kwargs)
# Then run the test with traceback disabled.
func(self, *args, **kwargs)
return wrapper
return decorator
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
"""Convenience context manager for testing callback tracebacks."""
sqlite.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
self.assertEqual(cm.unraisable.exc_type, exc)
if regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(regex.search(msg))
if obj_name:
self.assertEqual(cm.unraisable.object.__name__, obj_name)
finally:
sqlite.enable_callback_tracebacks(False)
def func_returntext(): def func_returntext():
@ -405,10 +366,10 @@ def test_func_deterministic_keyword_only(self):
def test_function_destructor_via_gc(self): def test_function_destructor_via_gc(self):
# See bpo-44304: The destructor of the user function can # See bpo-44304: The destructor of the user function can
# crash if is called without the GIL from the gc functions # crash if is called without the GIL from the gc functions
dest = sqlite.connect(':memory:')
def md5sum(t): def md5sum(t):
return return
with memory_database() as dest:
dest.create_function("md5", 1, md5sum) dest.create_function("md5", 1, md5sum)
x = dest("create table lang (name, first_appeared)") x = dest("create table lang (name, first_appeared)")
del md5sum, dest del md5sum, dest
@ -514,6 +475,10 @@ def setUp(self):
""" """
self.con.create_window_function("sumint", 1, WindowSumInt) self.con.create_window_function("sumint", 1, WindowSumInt)
def tearDown(self):
self.cur.close()
self.con.close()
def test_win_sum_int(self): def test_win_sum_int(self):
self.cur.execute(self.query % "sumint") self.cur.execute(self.query % "sumint")
self.assertEqual(self.cur.fetchall(), self.expected) self.assertEqual(self.cur.fetchall(), self.expected)
@ -634,6 +599,7 @@ def setUp(self):
""") """)
cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)", cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
("foo", 5, 3.14, None, memoryview(b"blob"),)) ("foo", 5, 3.14, None, memoryview(b"blob"),))
cur.close()
self.con.create_aggregate("nostep", 1, AggrNoStep) self.con.create_aggregate("nostep", 1, AggrNoStep)
self.con.create_aggregate("nofinalize", 1, AggrNoFinalize) self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
@ -646,9 +612,7 @@ def setUp(self):
self.con.create_aggregate("aggtxt", 1, AggrText) self.con.create_aggregate("aggtxt", 1, AggrText)
def tearDown(self): def tearDown(self):
#self.cur.close() self.con.close()
#self.con.close()
pass
def test_aggr_error_on_create(self): def test_aggr_error_on_create(self):
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
@ -775,7 +739,7 @@ def setUp(self):
self.con.set_authorizer(self.authorizer_cb) self.con.set_authorizer(self.authorizer_cb)
def tearDown(self): def tearDown(self):
pass self.con.close()
def test_table_access(self): def test_table_access(self):
with self.assertRaises(sqlite.DatabaseError) as cm: with self.assertRaises(sqlite.DatabaseError) as cm:

View File

@ -0,0 +1,78 @@
import contextlib
import functools
import io
import re
import sqlite3
import test.support
import unittest
# Helper for temporary memory databases
def memory_database(*args, **kwargs):
cx = sqlite3.connect(":memory:", *args, **kwargs)
return contextlib.closing(cx)
# Temporarily limit a database connection parameter
@contextlib.contextmanager
def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
try:
_prev = cx.setlimit(category, limit)
yield limit
finally:
cx.setlimit(category, _prev)
def with_tracebacks(exc, regex="", name=""):
"""Convenience decorator for testing callback tracebacks."""
def decorator(func):
_regex = re.compile(regex) if regex else None
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with test.support.catch_unraisable_exception() as cm:
# First, run the test with traceback enabled.
with check_tracebacks(self, cm, exc, _regex, name):
func(self, *args, **kwargs)
# Then run the test with traceback disabled.
func(self, *args, **kwargs)
return wrapper
return decorator
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
"""Convenience context manager for testing callback tracebacks."""
sqlite3.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
self.assertEqual(cm.unraisable.exc_type, exc)
if regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(regex.search(msg))
if obj_name:
self.assertEqual(cm.unraisable.object.__name__, obj_name)
finally:
sqlite3.enable_callback_tracebacks(False)
class MemoryDatabaseMixin:
def setUp(self):
self.con = sqlite3.connect(":memory:")
self.cur = self.con.cursor()
def tearDown(self):
self.cur.close()
self.con.close()
@property
def cx(self):
return self.con
@property
def cu(self):
return self.cur