gh-105539: Explict resource management for connection objects in sqlite3 tests (#108017)

- Use memory_database() helper
- Move test utility functions to util.py
- Add convenience memory database mixin
- Add check() helper for closed connection tests
This commit is contained in:
Erlend E. Aasland 2023-08-17 08:45:48 +02:00 committed by GitHub
parent c9d83f93d8
commit 1344cfac43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 371 additions and 385 deletions

View File

@ -1,6 +1,8 @@
import sqlite3 as sqlite
import unittest
from .util import memory_database
class BackupTests(unittest.TestCase):
def setUp(self):
@ -32,32 +34,32 @@ def test_bad_target_same_connection(self):
self.cx.backup(self.cx)
def test_bad_target_closed_connection(self):
bck = sqlite.connect(':memory:')
bck.close()
with self.assertRaises(sqlite.ProgrammingError):
self.cx.backup(bck)
with memory_database() as bck:
bck.close()
with self.assertRaises(sqlite.ProgrammingError):
self.cx.backup(bck)
def test_bad_source_closed_connection(self):
bck = sqlite.connect(':memory:')
source = sqlite.connect(":memory:")
source.close()
with self.assertRaises(sqlite.ProgrammingError):
source.backup(bck)
with memory_database() as bck:
source = sqlite.connect(":memory:")
source.close()
with self.assertRaises(sqlite.ProgrammingError):
source.backup(bck)
def test_bad_target_in_transaction(self):
bck = sqlite.connect(':memory:')
bck.execute('CREATE TABLE bar (key INTEGER)')
bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
with self.assertRaises(sqlite.OperationalError) as cm:
self.cx.backup(bck)
with memory_database() as bck:
bck.execute('CREATE TABLE bar (key INTEGER)')
bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
with self.assertRaises(sqlite.OperationalError) as cm:
self.cx.backup(bck)
def test_keyword_only_args(self):
with self.assertRaises(TypeError):
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, 1)
def test_simple(self):
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck)
self.verify_backup(bck)
@ -67,7 +69,7 @@ def test_progress(self):
def progress(status, remaining, total):
journal.append(status)
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, pages=1, progress=progress)
self.verify_backup(bck)
@ -81,7 +83,7 @@ def test_progress_all_pages_at_once_1(self):
def progress(status, remaining, total):
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, progress=progress)
self.verify_backup(bck)
@ -94,7 +96,7 @@ def test_progress_all_pages_at_once_2(self):
def progress(status, remaining, total):
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, pages=-1, progress=progress)
self.verify_backup(bck)
@ -103,7 +105,7 @@ def progress(status, remaining, total):
def test_non_callable_progress(self):
with self.assertRaises(TypeError) as cm:
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, pages=1, progress='bar')
self.assertEqual(str(cm.exception), 'progress argument must be a callable')
@ -116,7 +118,7 @@ def progress(status, remaining, total):
self.cx.commit()
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, pages=1, progress=progress)
self.verify_backup(bck)
@ -140,12 +142,12 @@ def progress(status, remaining, total):
self.assertEqual(str(err.exception), 'nearly out of space')
def test_database_source_name(self):
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, name='main')
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, name='temp')
with self.assertRaises(sqlite.OperationalError) as cm:
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, name='non-existing')
self.assertIn("unknown database", str(cm.exception))
@ -153,7 +155,7 @@ def test_database_source_name(self):
self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)')
self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)])
self.cx.commit()
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, name='attached_db')
self.verify_backup(bck)

View File

@ -33,26 +33,13 @@
SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess,
is_emscripten, is_wasi
)
from test.support import gc_collect
from test.support import threading_helper
from _testcapi import INT_MAX, ULLONG_MAX
from os import SEEK_SET, SEEK_CUR, SEEK_END
from test.support.os_helper import TESTFN, TESTFN_UNDECODABLE, unlink, temp_dir, FakePath
# Helper for temporary memory databases
def memory_database(*args, **kwargs):
cx = sqlite.connect(":memory:", *args, **kwargs)
return contextlib.closing(cx)
# Temporarily limit a database connection parameter
@contextlib.contextmanager
def cx_limit(cx, category=sqlite.SQLITE_LIMIT_SQL_LENGTH, limit=128):
try:
_prev = cx.setlimit(category, limit)
yield limit
finally:
cx.setlimit(category, _prev)
from .util import memory_database, cx_limit
class ModuleTests(unittest.TestCase):
@ -326,9 +313,9 @@ def test_extended_error_code_on_exception(self):
self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK")
def test_disallow_instantiation(self):
cx = sqlite.connect(":memory:")
check_disallow_instantiation(self, type(cx("select 1")))
check_disallow_instantiation(self, sqlite.Blob)
with memory_database() as cx:
check_disallow_instantiation(self, type(cx("select 1")))
check_disallow_instantiation(self, sqlite.Blob)
def test_complete_statement(self):
self.assertFalse(sqlite.complete_statement("select t"))
@ -342,6 +329,7 @@ def setUp(self):
cu = self.cx.cursor()
cu.execute("create table test(id integer primary key, name text)")
cu.execute("insert into test(name) values (?)", ("foo",))
cu.close()
def tearDown(self):
self.cx.close()
@ -412,21 +400,22 @@ def test_exceptions(self):
def test_in_transaction(self):
# Can't use db from setUp because we want to test initial state.
cx = sqlite.connect(":memory:")
cu = cx.cursor()
self.assertEqual(cx.in_transaction, False)
cu.execute("create table transactiontest(id integer primary key, name text)")
self.assertEqual(cx.in_transaction, False)
cu.execute("insert into transactiontest(name) values (?)", ("foo",))
self.assertEqual(cx.in_transaction, True)
cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone()
self.assertEqual(cx.in_transaction, True)
cx.commit()
self.assertEqual(cx.in_transaction, False)
cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone()
self.assertEqual(cx.in_transaction, False)
with memory_database() as cx:
cu = cx.cursor()
self.assertEqual(cx.in_transaction, False)
cu.execute("create table transactiontest(id integer primary key, name text)")
self.assertEqual(cx.in_transaction, False)
cu.execute("insert into transactiontest(name) values (?)", ("foo",))
self.assertEqual(cx.in_transaction, True)
cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone()
self.assertEqual(cx.in_transaction, True)
cx.commit()
self.assertEqual(cx.in_transaction, False)
cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone()
self.assertEqual(cx.in_transaction, False)
cu.close()
def test_in_transaction_ro(self):
with self.assertRaises(AttributeError):
@ -450,10 +439,9 @@ def test_connection_exceptions(self):
self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc))
def test_interrupt_on_closed_db(self):
cx = sqlite.connect(":memory:")
cx.close()
self.cx.close()
with self.assertRaises(sqlite.ProgrammingError):
cx.interrupt()
self.cx.interrupt()
def test_interrupt(self):
self.assertIsNone(self.cx.interrupt())
@ -521,29 +509,29 @@ def test_connection_init_good_isolation_levels(self):
self.assertEqual(cx.isolation_level, level)
def test_connection_reinit(self):
db = ":memory:"
cx = sqlite.connect(db)
cx.text_factory = bytes
cx.row_factory = sqlite.Row
cu = cx.cursor()
cu.execute("create table foo (bar)")
cu.executemany("insert into foo (bar) values (?)",
((str(v),) for v in range(4)))
cu.execute("select bar from foo")
with memory_database() as cx:
cx.text_factory = bytes
cx.row_factory = sqlite.Row
cu = cx.cursor()
cu.execute("create table foo (bar)")
cu.executemany("insert into foo (bar) values (?)",
((str(v),) for v in range(4)))
cu.execute("select bar from foo")
rows = [r for r in cu.fetchmany(2)]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
rows = [r for r in cu.fetchmany(2)]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
cx.__init__(db)
cx.execute("create table foo (bar)")
cx.executemany("insert into foo (bar) values (?)",
((v,) for v in ("a", "b", "c", "d")))
cx.__init__(":memory:")
cx.execute("create table foo (bar)")
cx.executemany("insert into foo (bar) values (?)",
((v,) for v in ("a", "b", "c", "d")))
# This uses the old database, old row factory, but new text factory
rows = [r for r in cu.fetchall()]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], ["2", "3"])
# This uses the old database, old row factory, but new text factory
rows = [r for r in cu.fetchall()]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], ["2", "3"])
cu.close()
def test_connection_bad_reinit(self):
cx = sqlite.connect(":memory:")
@ -591,11 +579,11 @@ def test_connect_positional_arguments(self):
"parameters in Python 3.15."
)
with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
sqlite.connect(":memory:", 1.0)
cx = sqlite.connect(":memory:", 1.0)
cx.close()
self.assertEqual(cm.filename, __file__)
class UninitialisedConnectionTests(unittest.TestCase):
def setUp(self):
self.cx = sqlite.Connection.__new__(sqlite.Connection)
@ -1571,12 +1559,12 @@ def run(con, err):
except sqlite.Error:
err.append("multi-threading not allowed")
con = sqlite.connect(":memory:", check_same_thread=False)
err = []
t = threading.Thread(target=run, kwargs={"con": con, "err": err})
t.start()
t.join()
self.assertEqual(len(err), 0, "\n".join(err))
with memory_database(check_same_thread=False) as con:
err = []
t = threading.Thread(target=run, kwargs={"con": con, "err": err})
t.start()
t.join()
self.assertEqual(len(err), 0, "\n".join(err))
class ConstructorTests(unittest.TestCase):
@ -1602,9 +1590,16 @@ def test_binary(self):
b = sqlite.Binary(b"\0'")
class ExtensionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
def tearDown(self):
self.cur.close()
self.con.close()
def test_script_string_sql(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
cur = self.cur
cur.executescript("""
-- bla bla
/* a stupid comment */
@ -1616,40 +1611,40 @@ def test_script_string_sql(self):
self.assertEqual(res, 5)
def test_script_syntax_error(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(sqlite.OperationalError):
cur.executescript("create table test(x); asdf; create table test2(x)")
self.cur.executescript("""
CREATE TABLE test(x);
asdf;
CREATE TABLE test2(x)
""")
def test_script_error_normal(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(sqlite.OperationalError):
cur.executescript("create table test(sadfsadfdsa); select foo from hurz;")
self.cur.executescript("""
CREATE TABLE test(sadfsadfdsa);
SELECT foo FROM hurz;
""")
def test_cursor_executescript_as_bytes(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(TypeError):
cur.executescript(b"create table test(foo); insert into test(foo) values (5);")
self.cur.executescript(b"""
CREATE TABLE test(foo);
INSERT INTO test(foo) VALUES (5);
""")
def test_cursor_executescript_with_null_characters(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(ValueError):
cur.executescript("""
create table a(i);\0
insert into a(i) values (5);
""")
self.cur.executescript("""
CREATE TABLE a(i);\0
INSERT INTO a(i) VALUES (5);
""")
def test_cursor_executescript_with_surrogates(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(UnicodeEncodeError):
cur.executescript("""
create table a(s);
insert into a(s) values ('\ud8ff');
""")
self.cur.executescript("""
CREATE TABLE a(s);
INSERT INTO a(s) VALUES ('\ud8ff');
""")
def test_cursor_executescript_too_large_script(self):
msg = "query string is too large"
@ -1659,19 +1654,18 @@ def test_cursor_executescript_too_large_script(self):
cx.executescript("select 'too large'".ljust(lim+1))
def test_cursor_executescript_tx_control(self):
con = sqlite.connect(":memory:")
con = self.con
con.execute("begin")
self.assertTrue(con.in_transaction)
con.executescript("select 1")
self.assertFalse(con.in_transaction)
def test_connection_execute(self):
con = sqlite.connect(":memory:")
result = con.execute("select 5").fetchone()[0]
result = self.con.execute("select 5").fetchone()[0]
self.assertEqual(result, 5, "Basic test of Connection.execute")
def test_connection_executemany(self):
con = sqlite.connect(":memory:")
con = self.con
con.execute("create table test(foo)")
con.executemany("insert into test(foo) values (?)", [(3,), (4,)])
result = con.execute("select foo from test order by foo").fetchall()
@ -1679,47 +1673,44 @@ def test_connection_executemany(self):
self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany")
def test_connection_executescript(self):
con = sqlite.connect(":memory:")
con.executescript("create table test(foo); insert into test(foo) values (5);")
con = self.con
con.executescript("""
CREATE TABLE test(foo);
INSERT INTO test(foo) VALUES (5);
""")
result = con.execute("select foo from test").fetchone()[0]
self.assertEqual(result, 5, "Basic test of Connection.executescript")
class ClosedConTests(unittest.TestCase):
def check(self, fn, *args, **kwds):
regex = "Cannot operate on a closed database."
with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
fn(*args, **kwds)
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
self.con.close()
def test_closed_con_cursor(self):
con = sqlite.connect(":memory:")
con.close()
with self.assertRaises(sqlite.ProgrammingError):
cur = con.cursor()
self.check(self.con.cursor)
def test_closed_con_commit(self):
con = sqlite.connect(":memory:")
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con.commit()
self.check(self.con.commit)
def test_closed_con_rollback(self):
con = sqlite.connect(":memory:")
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con.rollback()
self.check(self.con.rollback)
def test_closed_cur_execute(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
con.close()
with self.assertRaises(sqlite.ProgrammingError):
cur.execute("select 4")
self.check(self.cur.execute, "select 4")
def test_closed_create_function(self):
con = sqlite.connect(":memory:")
con.close()
def f(x): return 17
with self.assertRaises(sqlite.ProgrammingError):
con.create_function("foo", 1, f)
def f(x):
return 17
self.check(self.con.create_function, "foo", 1, f)
def test_closed_create_aggregate(self):
con = sqlite.connect(":memory:")
con.close()
class Agg:
def __init__(self):
pass
@ -1727,29 +1718,21 @@ def step(self, x):
pass
def finalize(self):
return 17
with self.assertRaises(sqlite.ProgrammingError):
con.create_aggregate("foo", 1, Agg)
self.check(self.con.create_aggregate, "foo", 1, Agg)
def test_closed_set_authorizer(self):
con = sqlite.connect(":memory:")
con.close()
def authorizer(*args):
return sqlite.DENY
with self.assertRaises(sqlite.ProgrammingError):
con.set_authorizer(authorizer)
self.check(self.con.set_authorizer, authorizer)
def test_closed_set_progress_callback(self):
con = sqlite.connect(":memory:")
con.close()
def progress(): pass
with self.assertRaises(sqlite.ProgrammingError):
con.set_progress_handler(progress, 100)
def progress():
pass
self.check(self.con.set_progress_handler, progress, 100)
def test_closed_call(self):
con = sqlite.connect(":memory:")
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con()
self.check(self.con)
class ClosedCurTests(unittest.TestCase):
def test_closed(self):

View File

@ -2,16 +2,12 @@
import unittest
import sqlite3 as sqlite
from .test_dbapi import memory_database
from .util import memory_database
from .util import MemoryDatabaseMixin
class DumpTests(unittest.TestCase):
def setUp(self):
self.cx = sqlite.connect(":memory:")
self.cu = self.cx.cursor()
def tearDown(self):
self.cx.close()
class DumpTests(MemoryDatabaseMixin, unittest.TestCase):
def test_table_dump(self):
expected_sqls = [

View File

@ -24,6 +24,9 @@
import sqlite3 as sqlite
from collections.abc import Sequence
from .util import memory_database
from .util import MemoryDatabaseMixin
def dict_factory(cursor, row):
d = {}
@ -45,10 +48,12 @@ class OkFactory(sqlite.Connection):
def __init__(self, *args, **kwargs):
sqlite.Connection.__init__(self, *args, **kwargs)
for factory in DefectFactory, OkFactory:
with self.subTest(factory=factory):
con = sqlite.connect(":memory:", factory=factory)
self.assertIsInstance(con, factory)
with memory_database(factory=OkFactory) as con:
self.assertIsInstance(con, OkFactory)
regex = "Base Connection.__init__ not called."
with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
with memory_database(factory=DefectFactory) as con:
self.assertIsInstance(con, DefectFactory)
def test_connection_factory_relayed_call(self):
# gh-95132: keyword args must not be passed as positional args
@ -57,9 +62,9 @@ def __init__(self, *args, **kwargs):
kwargs["isolation_level"] = None
super(Factory, self).__init__(*args, **kwargs)
con = sqlite.connect(":memory:", factory=Factory)
self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory)
with memory_database(factory=Factory) as con:
self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory)
def test_connection_factory_as_positional_arg(self):
class Factory(sqlite.Connection):
@ -74,18 +79,13 @@ def __init__(self, *args, **kwargs):
r"parameters in Python 3.15."
)
with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
con = sqlite.connect(":memory:", 5.0, 0, None, True, Factory)
with memory_database(5.0, 0, None, True, Factory) as con:
self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory)
self.assertEqual(cm.filename, __file__)
self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory)
class CursorFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def tearDown(self):
self.con.close()
class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def test_is_instance(self):
cur = self.con.cursor()
@ -103,9 +103,8 @@ def test_invalid_factory(self):
# invalid callable returning non-cursor
self.assertRaises(TypeError, self.con.cursor, lambda con: None)
class RowFactoryTestsBackwardsCompat(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
class RowFactoryTestsBackwardsCompat(MemoryDatabaseMixin, unittest.TestCase):
def test_is_produced_by_factory(self):
cur = self.con.cursor(factory=MyCursor)
@ -114,12 +113,8 @@ def test_is_produced_by_factory(self):
self.assertIsInstance(row, dict)
cur.close()
def tearDown(self):
self.con.close()
class RowFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
class RowFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def test_custom_factory(self):
self.con.row_factory = lambda cur, row: list(row)
@ -265,12 +260,8 @@ class FakeCursor(str):
self.assertRaises(TypeError, self.con.cursor, FakeCursor)
self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
def tearDown(self):
self.con.close()
class TextFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
class TextFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def test_unicode(self):
austria = "Österreich"
@ -291,15 +282,17 @@ def test_custom(self):
self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
self.assertTrue(row[0].endswith("reich"), "column must contain original data")
def tearDown(self):
self.con.close()
class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.con.execute("create table test (value text)")
self.con.execute("insert into test (value) values (?)", ("a\x00b",))
def tearDown(self):
self.con.close()
def test_string(self):
# text_factory defaults to str
row = self.con.execute("select value from test").fetchone()
@ -325,9 +318,6 @@ def test_custom(self):
self.assertIs(type(row[0]), bytes)
self.assertEqual(row[0], b"a\x00b")
def tearDown(self):
self.con.close()
if __name__ == "__main__":
unittest.main()

View File

@ -26,34 +26,31 @@
from test.support.os_helper import TESTFN, unlink
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from test.test_sqlite3.test_userfunctions import with_tracebacks
from .util import memory_database, cx_limit, with_tracebacks
from .util import MemoryDatabaseMixin
class CollationTests(unittest.TestCase):
class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
def test_create_collation_not_string(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError):
con.create_collation(None, lambda x, y: (x > y) - (x < y))
self.con.create_collation(None, lambda x, y: (x > y) - (x < y))
def test_create_collation_not_callable(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError) as cm:
con.create_collation("X", 42)
self.con.create_collation("X", 42)
self.assertEqual(str(cm.exception), 'parameter must be callable')
def test_create_collation_not_ascii(self):
con = sqlite.connect(":memory:")
con.create_collation("collä", lambda x, y: (x > y) - (x < y))
self.con.create_collation("collä", lambda x, y: (x > y) - (x < y))
def test_create_collation_bad_upper(self):
class BadUpperStr(str):
def upper(self):
return None
con = sqlite.connect(":memory:")
mycoll = lambda x, y: -((x > y) - (x < y))
con.create_collation(BadUpperStr("mycoll"), mycoll)
result = con.execute("""
self.con.create_collation(BadUpperStr("mycoll"), mycoll)
result = self.con.execute("""
select x from (
select 'a' as x
union
@ -68,8 +65,7 @@ def mycoll(x, y):
# reverse order
return -((x > y) - (x < y))
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
self.con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
@ -79,21 +75,20 @@ def mycoll(x, y):
select 'c' as x
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg='the expected order was not returned')
con.create_collation("mycoll", None)
self.con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
def test_collation_returns_large_integer(self):
def mycoll(x, y):
# reverse order
return -((x > y) - (x < y)) * 2**32
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
self.con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
@ -103,7 +98,7 @@ def mycoll(x, y):
select 'c' as x
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg="the expected order was not returned")
@ -112,7 +107,7 @@ def test_collation_register_twice(self):
Register two different collation functions under the same name.
Verify that the last one is actually used.
"""
con = sqlite.connect(":memory:")
con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
result = con.execute("""
@ -126,25 +121,26 @@ def test_deregister_collation(self):
Register a collation, then deregister it. Make sure an error is raised if we try
to use it.
"""
con = sqlite.connect(":memory:")
con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
class ProgressTests(unittest.TestCase):
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
def test_progress_handler_used(self):
"""
Test that the progress handler is invoked once it is set.
"""
con = sqlite.connect(":memory:")
progress_calls = []
def progress():
progress_calls.append(None)
return 0
con.set_progress_handler(progress, 1)
con.execute("""
self.con.set_progress_handler(progress, 1)
self.con.execute("""
create table foo(a, b)
""")
self.assertTrue(progress_calls)
@ -153,7 +149,7 @@ def test_opcode_count(self):
"""
Test that the opcode argument is respected.
"""
con = sqlite.connect(":memory:")
con = self.con
progress_calls = []
def progress():
progress_calls.append(None)
@ -176,11 +172,10 @@ def test_cancel_operation(self):
"""
Test that returning a non-zero value stops the operation in progress.
"""
con = sqlite.connect(":memory:")
def progress():
return 1
con.set_progress_handler(progress, 1)
curs = con.cursor()
self.con.set_progress_handler(progress, 1)
curs = self.con.cursor()
self.assertRaises(
sqlite.OperationalError,
curs.execute,
@ -190,7 +185,7 @@ def test_clear_handler(self):
"""
Test that setting the progress handler to None clears the previously set handler.
"""
con = sqlite.connect(":memory:")
con = self.con
action = 0
def progress():
nonlocal action
@ -203,31 +198,30 @@ def progress():
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:")
def bad_progress():
1 / 0
con.set_progress_handler(bad_progress, 1)
self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
self.con.execute("""
create table foo(a, b)
""")
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:")
class BadBool:
def __bool__(self):
1 / 0
def bad_progress():
return BadBool()
con.set_progress_handler(bad_progress, 1)
self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
self.con.execute("""
create table foo(a, b)
""")
class TraceCallbackTests(unittest.TestCase):
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
@contextlib.contextmanager
def check_stmt_trace(self, cx, expected):
try:
@ -242,12 +236,11 @@ def test_trace_callback_used(self):
"""
Test that the trace callback is invoked once it is set.
"""
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.execute("create table foo(a, b)")
self.con.set_trace_callback(trace)
self.con.execute("create table foo(a, b)")
self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
@ -255,7 +248,7 @@ def test_clear_trace_callback(self):
"""
Test that setting the trace callback to None clears the previously set callback.
"""
con = sqlite.connect(":memory:")
con = self.con
traced_statements = []
def trace(statement):
traced_statements.append(statement)
@ -269,7 +262,7 @@ def test_unicode_content(self):
Test that the statement can contain unicode literals.
"""
unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
con = sqlite.connect(":memory:")
con = self.con
traced_statements = []
def trace(statement):
traced_statements.append(statement)

View File

@ -28,15 +28,12 @@
from test import support
from unittest.mock import patch
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from .util import memory_database, cx_limit
from .util import MemoryDatabaseMixin
class RegressionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def tearDown(self):
self.con.close()
class RegressionTests(MemoryDatabaseMixin, unittest.TestCase):
def test_pragma_user_version(self):
# This used to crash pysqlite because this pragma command returns NULL for the column name
@ -45,28 +42,24 @@ def test_pragma_user_version(self):
def test_pragma_schema_version(self):
# This still crashed pysqlite <= 2.2.1
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
try:
with memory_database(detect_types=sqlite.PARSE_COLNAMES) as con:
cur = self.con.cursor()
cur.execute("pragma schema_version")
finally:
cur.close()
con.close()
def test_statement_reset(self):
# pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are
# reset before a rollback, but only those that are still in the
# statement cache. The others are not accessible from the connection object.
con = sqlite.connect(":memory:", cached_statements=5)
cursors = [con.cursor() for x in range(5)]
cursors[0].execute("create table test(x)")
for i in range(10):
cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
with memory_database(cached_statements=5) as con:
cursors = [con.cursor() for x in range(5)]
cursors[0].execute("create table test(x)")
for i in range(10):
cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
for i in range(5):
cursors[i].execute(" " * i + "select x from test")
for i in range(5):
cursors[i].execute(" " * i + "select x from test")
con.rollback()
con.rollback()
def test_column_name_with_spaces(self):
cur = self.con.cursor()
@ -81,17 +74,15 @@ def test_statement_finalization_on_close_db(self):
# cache when closing the database. statements that were still
# referenced in cursors weren't closed and could provoke "
# "OperationalError: Unable to close due to unfinalised statements".
con = sqlite.connect(":memory:")
cursors = []
# default statement cache size is 100
for i in range(105):
cur = con.cursor()
cur = self.con.cursor()
cursors.append(cur)
cur.execute("select 1 x union select " + str(i))
con.close()
def test_on_conflict_rollback(self):
con = sqlite.connect(":memory:")
con = self.con
con.execute("create table foo(x, unique(x) on conflict rollback)")
con.execute("insert into foo(x) values (1)")
try:
@ -126,16 +117,16 @@ def test_type_map_usage(self):
a statement. This test exhibits the problem.
"""
SELECT = "select * from foo"
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
cur = con.cursor()
cur.execute("create table foo(bar timestamp)")
with self.assertWarnsRegex(DeprecationWarning, "adapter"):
cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
cur.execute(SELECT)
cur.execute("drop table foo")
cur.execute("create table foo(bar integer)")
cur.execute("insert into foo(bar) values (5)")
cur.execute(SELECT)
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
cur = con.cursor()
cur.execute("create table foo(bar timestamp)")
with self.assertWarnsRegex(DeprecationWarning, "adapter"):
cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
cur.execute(SELECT)
cur.execute("drop table foo")
cur.execute("create table foo(bar integer)")
cur.execute("insert into foo(bar) values (5)")
cur.execute(SELECT)
def test_bind_mutating_list(self):
# Issue41662: Crash when mutate a list of parameters during iteration.
@ -144,11 +135,11 @@ def __conform__(self, protocol):
parameters.clear()
return "..."
parameters = [X(), 0]
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
con.execute("create table foo(bar X, baz integer)")
# Should not crash
with self.assertRaises(IndexError):
con.execute("insert into foo(bar, baz) values (?, ?)", parameters)
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
con.execute("create table foo(bar X, baz integer)")
# Should not crash
with self.assertRaises(IndexError):
con.execute("insert into foo(bar, baz) values (?, ?)", parameters)
def test_error_msg_decode_error(self):
# When porting the module to Python 3.0, the error message about
@ -173,7 +164,7 @@ def upper(self):
def __del__(self):
con.isolation_level = ""
con = sqlite.connect(":memory:")
con = self.con
con.isolation_level = None
for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE":
with self.subTest(level=level):
@ -204,8 +195,7 @@ class Cursor(sqlite.Cursor):
def __init__(self, con):
pass
con = sqlite.connect(":memory:")
cur = Cursor(con)
cur = Cursor(self.con)
with self.assertRaises(sqlite.ProgrammingError):
cur.execute("select 4+5").fetchall()
with self.assertRaisesRegex(sqlite.ProgrammingError,
@ -238,7 +228,9 @@ def test_auto_commit(self):
2.5.3 introduced a regression so that these could no longer
be created.
"""
con = sqlite.connect(":memory:", isolation_level=None)
with memory_database(isolation_level=None) as con:
self.assertIsNone(con.isolation_level)
self.assertFalse(con.in_transaction)
def test_pragma_autocommit(self):
"""
@ -273,9 +265,7 @@ def test_recursive_cursor_use(self):
Recursively using a cursor, such as when reusing it from a generator led to segfaults.
Now we catch recursive cursor usage and raise a ProgrammingError.
"""
con = sqlite.connect(":memory:")
cur = con.cursor()
cur = self.con.cursor()
cur.execute("create table a (bar)")
cur.execute("create table b (baz)")
@ -295,29 +285,30 @@ def test_convert_timestamp_microsecond_padding(self):
since the microsecond string "456" actually represents "456000".
"""
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
cur = con.cursor()
cur.execute("CREATE TABLE t (x TIMESTAMP)")
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
cur = con.cursor()
cur.execute("CREATE TABLE t (x TIMESTAMP)")
# Microseconds should be 456000
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
# Microseconds should be 456000
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
# Microseconds should be truncated to 123456
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
# Microseconds should be truncated to 123456
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
cur.execute("SELECT * FROM t")
with self.assertWarnsRegex(DeprecationWarning, "converter"):
values = [x[0] for x in cur.fetchall()]
cur.execute("SELECT * FROM t")
with self.assertWarnsRegex(DeprecationWarning, "converter"):
values = [x[0] for x in cur.fetchall()]
self.assertEqual(values, [
datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
])
self.assertEqual(values, [
datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
])
def test_invalid_isolation_level_type(self):
# isolation level is a string, not an integer
self.assertRaises(TypeError,
sqlite.connect, ":memory:", isolation_level=123)
regex = "isolation_level must be str or None"
with self.assertRaisesRegex(TypeError, regex):
memory_database(isolation_level=123).__enter__()
def test_null_character(self):
@ -333,7 +324,7 @@ def test_null_character(self):
cur.execute, query)
def test_surrogates(self):
con = sqlite.connect(":memory:")
con = self.con
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
cur = con.cursor()
@ -359,7 +350,7 @@ def test_commit_cursor_reset(self):
to return rows multiple times when fetched from cursors
after commit. See issues 10513 and 23129 for details.
"""
con = sqlite.connect(":memory:")
con = self.con
con.executescript("""
create table t(c);
create table t2(c);
@ -391,10 +382,9 @@ def test_bpo31770(self):
"""
def callback(*args):
pass
con = sqlite.connect(":memory:")
cur = sqlite.Cursor(con)
cur = sqlite.Cursor(self.con)
ref = weakref.ref(cur, callback)
cur.__init__(con)
cur.__init__(self.con)
del cur
# The interpreter shouldn't crash when ref is collected.
del ref
@ -425,6 +415,7 @@ def test_return_empty_bytestring(self):
def test_table_lock_cursor_replace_stmt(self):
with memory_database() as con:
con = self.con
cur = con.cursor()
cur.execute("create table t(t)")
cur.executemany("insert into t values(?)",

View File

@ -28,7 +28,8 @@
from test.support.os_helper import TESTFN, unlink
from test.support.script_helper import assert_python_ok
from test.test_sqlite3.test_dbapi import memory_database
from .util import memory_database
from .util import MemoryDatabaseMixin
TIMEOUT = LOOPBACK_TIMEOUT / 10
@ -132,14 +133,14 @@ def test_locking(self):
def test_rollback_cursor_consistency(self):
"""Check that cursors behave correctly after rollback."""
con = sqlite.connect(":memory:")
cur = con.cursor()
cur.execute("create table test(x)")
cur.execute("insert into test(x) values (5)")
cur.execute("select 1 union select 2 union select 3")
with memory_database() as con:
cur = con.cursor()
cur.execute("create table test(x)")
cur.execute("insert into test(x) values (5)")
cur.execute("select 1 union select 2 union select 3")
con.rollback()
self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)])
con.rollback()
self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)])
def test_multiple_cursors_and_iternext(self):
# gh-94028: statements are cleared and reset in cursor iternext.
@ -218,10 +219,7 @@ def test_no_duplicate_rows_after_rollback_new_query(self):
class SpecialCommandTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
class SpecialCommandTests(MemoryDatabaseMixin, unittest.TestCase):
def test_drop_table(self):
self.cur.execute("create table test(i)")
@ -233,14 +231,8 @@ def test_pragma(self):
self.cur.execute("insert into test(i) values (5)")
self.cur.execute("pragma count_changes=1")
def tearDown(self):
self.cur.close()
self.con.close()
class TransactionalDDL(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
class TransactionalDDL(MemoryDatabaseMixin, unittest.TestCase):
def test_ddl_does_not_autostart_transaction(self):
# For backwards compatibility reasons, DDL statements should not
@ -268,9 +260,6 @@ def test_transactional_ddl(self):
with self.assertRaises(sqlite.OperationalError):
self.con.execute("select * from test")
def tearDown(self):
self.con.close()
class IsolationLevelFromInit(unittest.TestCase):
CREATE = "create table t(t)"

View File

@ -21,54 +21,15 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import contextlib
import functools
import io
import re
import sys
import unittest
import sqlite3 as sqlite
from unittest.mock import Mock, patch
from test.support import bigmemtest, catch_unraisable_exception, gc_collect
from test.support import bigmemtest, gc_collect
from test.test_sqlite3.test_dbapi import cx_limit
def with_tracebacks(exc, regex="", name=""):
"""Convenience decorator for testing callback tracebacks."""
def decorator(func):
_regex = re.compile(regex) if regex else None
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with catch_unraisable_exception() as cm:
# First, run the test with traceback enabled.
with check_tracebacks(self, cm, exc, _regex, name):
func(self, *args, **kwargs)
# Then run the test with traceback disabled.
func(self, *args, **kwargs)
return wrapper
return decorator
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
"""Convenience context manager for testing callback tracebacks."""
sqlite.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
self.assertEqual(cm.unraisable.exc_type, exc)
if regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(regex.search(msg))
if obj_name:
self.assertEqual(cm.unraisable.object.__name__, obj_name)
finally:
sqlite.enable_callback_tracebacks(False)
from .util import cx_limit, memory_database
from .util import with_tracebacks, check_tracebacks
def func_returntext():
@ -405,19 +366,19 @@ def test_func_deterministic_keyword_only(self):
def test_function_destructor_via_gc(self):
# See bpo-44304: The destructor of the user function can
# crash if is called without the GIL from the gc functions
dest = sqlite.connect(':memory:')
def md5sum(t):
return
dest.create_function("md5", 1, md5sum)
x = dest("create table lang (name, first_appeared)")
del md5sum, dest
with memory_database() as dest:
dest.create_function("md5", 1, md5sum)
x = dest("create table lang (name, first_appeared)")
del md5sum, dest
y = [x]
y.append(y)
y = [x]
y.append(y)
del x,y
gc_collect()
del x,y
gc_collect()
@with_tracebacks(OverflowError)
def test_func_return_too_large_int(self):
@ -514,6 +475,10 @@ def setUp(self):
"""
self.con.create_window_function("sumint", 1, WindowSumInt)
def tearDown(self):
self.cur.close()
self.con.close()
def test_win_sum_int(self):
self.cur.execute(self.query % "sumint")
self.assertEqual(self.cur.fetchall(), self.expected)
@ -634,6 +599,7 @@ def setUp(self):
""")
cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
("foo", 5, 3.14, None, memoryview(b"blob"),))
cur.close()
self.con.create_aggregate("nostep", 1, AggrNoStep)
self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
@ -646,9 +612,7 @@ def setUp(self):
self.con.create_aggregate("aggtxt", 1, AggrText)
def tearDown(self):
#self.cur.close()
#self.con.close()
pass
self.con.close()
def test_aggr_error_on_create(self):
with self.assertRaises(sqlite.OperationalError):
@ -775,7 +739,7 @@ def setUp(self):
self.con.set_authorizer(self.authorizer_cb)
def tearDown(self):
pass
self.con.close()
def test_table_access(self):
with self.assertRaises(sqlite.DatabaseError) as cm:

View File

@ -0,0 +1,78 @@
import contextlib
import functools
import io
import re
import sqlite3
import test.support
import unittest
# Helper for temporary memory databases
def memory_database(*args, **kwargs):
cx = sqlite3.connect(":memory:", *args, **kwargs)
return contextlib.closing(cx)
# Temporarily limit a database connection parameter
@contextlib.contextmanager
def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
try:
_prev = cx.setlimit(category, limit)
yield limit
finally:
cx.setlimit(category, _prev)
def with_tracebacks(exc, regex="", name=""):
"""Convenience decorator for testing callback tracebacks."""
def decorator(func):
_regex = re.compile(regex) if regex else None
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with test.support.catch_unraisable_exception() as cm:
# First, run the test with traceback enabled.
with check_tracebacks(self, cm, exc, _regex, name):
func(self, *args, **kwargs)
# Then run the test with traceback disabled.
func(self, *args, **kwargs)
return wrapper
return decorator
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
"""Convenience context manager for testing callback tracebacks."""
sqlite3.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
self.assertEqual(cm.unraisable.exc_type, exc)
if regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(regex.search(msg))
if obj_name:
self.assertEqual(cm.unraisable.object.__name__, obj_name)
finally:
sqlite3.enable_callback_tracebacks(False)
class MemoryDatabaseMixin:
def setUp(self):
self.con = sqlite3.connect(":memory:")
self.cur = self.con.cursor()
def tearDown(self):
self.cur.close()
self.con.close()
@property
def cx(self):
return self.con
@property
def cu(self):
return self.cur