»Core Development>Code coverage>Lib/sqlite3/test/userfunctions.py

Python code coverage for Lib/sqlite3/test/userfunctions.py

#  count  content
1n/a#-*- coding: iso-8859-1 -*-
2n/a# pysqlite2/test/userfunctions.py: tests for user-defined functions and
3n/a# aggregates.
4n/a#
5n/a# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
6n/a#
7n/a# This file is part of pysqlite.
8n/a#
9n/a# This software is provided 'as-is', without any express or implied
10n/a# warranty. In no event will the authors be held liable for any damages
11n/a# arising from the use of this software.
12n/a#
13n/a# Permission is granted to anyone to use this software for any purpose,
14n/a# including commercial applications, and to alter it and redistribute it
15n/a# freely, subject to the following restrictions:
16n/a#
17n/a# 1. The origin of this software must not be misrepresented; you must not
18n/a# claim that you wrote the original software. If you use this software
19n/a# in a product, an acknowledgment in the product documentation would be
20n/a# appreciated but is not required.
21n/a# 2. Altered source versions must be plainly marked as such, and must not be
22n/a# misrepresented as being the original software.
23n/a# 3. This notice may not be removed or altered from any source distribution.
24n/a
25n/aimport unittest
26n/aimport sqlite3 as sqlite
27n/a
def func_returntext():
    """Return the fixed text value ``"foo"`` (a ``str``)."""
    text = "foo"
    return text
def func_returnunicode():
    """Return the fixed text value ``"bar"`` (a ``str``)."""
    value = "bar"
    return value
def func_returnint():
    """Return the fixed integer 42."""
    answer = 42
    return answer
def func_returnfloat():
    """Return the fixed float 3.14."""
    approx_pi = 3.14
    return approx_pi
def func_returnnull():
    """Return ``None`` so that the SQL call yields NULL."""
    return
def func_returnblob():
    """Return the fixed ``bytes`` value ``b"blob"``."""
    payload = b"blob"
    return payload
def func_returnlonglong():
    """Return 2**31, one past the largest signed 32-bit integer."""
    return 2 ** 31
def func_raiseexception():
    """Always fail by dividing 5 by zero (raises ZeroDivisionError)."""
    numerator, denominator = 5, 0
    return numerator / denominator
44n/a
def func_isstring(v):
    """Return True only if *v* is exactly a ``str`` (subclasses excluded)."""
    actual = type(v)
    return actual is str
def func_isint(v):
    """Return True only if *v* is exactly an ``int`` (``bool`` excluded)."""
    actual = type(v)
    return actual is int
def func_isfloat(v):
    """Return True only if *v* is exactly a ``float``."""
    actual = type(v)
    return actual is float
def func_isnone(v):
    """Return True only if *v* is ``None`` (an SQL NULL argument)."""
    # None is the sole instance of NoneType, so an identity test is equivalent
    # to comparing type(v) against type(None).
    return v is None
def func_isblob(v):
    """Return True if *v* is a ``bytes`` or ``memoryview`` instance."""
    blob_types = (bytes, memoryview)
    return isinstance(v, blob_types)
def func_islonglong(v):
    """Return True if *v* is an ``int`` that is at least 2**31."""
    if not isinstance(v, int):
        return False
    return v >= 2 ** 31
57n/a
def func(*args):
    """Return how many positional arguments were passed."""
    arg_count = len(args)
    return arg_count
60n/a
class AggrNoStep:
    """Aggregate class that deliberately has no ``step`` method.

    Only ``finalize`` is provided; it always returns 1.
    """

    def __init__(self):
        # No state is needed.
        pass

    def finalize(self):
        """Return the constant 1."""
        result = 1
        return result
67n/a
class AggrNoFinalize:
    """Aggregate class that deliberately has no ``finalize`` method.

    ``step`` accepts a value and ignores it.
    """

    def __init__(self):
        # No state is needed.
        pass

    def step(self, x):
        """Ignore *x*."""
        return None
74n/a
class AggrExceptionInInit:
    """Aggregate whose constructor always raises ZeroDivisionError."""

    def __init__(self):
        zero = 0
        5 / zero

    def step(self, x):
        """Ignore *x* (unreachable through normal construction)."""
        return None

    def finalize(self):
        """Return None (unreachable through normal construction)."""
        return None
84n/a
class AggrExceptionInStep:
    """Aggregate whose ``step`` always raises ZeroDivisionError.

    ``finalize`` would return 42 if it were reached.
    """

    def __init__(self):
        # No state is needed.
        pass

    def step(self, x):
        """Fail unconditionally with ZeroDivisionError."""
        zero = 0
        5 / zero

    def finalize(self):
        """Return the constant 42."""
        result = 42
        return result
94n/a
class AggrExceptionInFinalize:
    """Aggregate whose ``finalize`` always raises ZeroDivisionError."""

    def __init__(self):
        # No state is needed.
        pass

    def step(self, x):
        """Ignore *x*."""
        return None

    def finalize(self):
        """Fail unconditionally with ZeroDivisionError."""
        zero = 0
        5 / zero
104n/a
class AggrCheckType:
    """Aggregate reporting whether the most recent value had an expected type.

    ``step(whichType, val)`` looks up *whichType* ("str", "int", "float",
    "None" or "blob"; any other name raises KeyError) and records 1 if
    ``type(val)`` is exactly that type, else 0.  ``finalize`` returns the
    last recorded flag, or None if ``step`` never ran.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        expected = {
            "str": str,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": bytes,
        }[whichType]
        # Exact type match: subclasses (e.g. bool for int) do not count.
        self.val = 1 if type(val) is expected else 0

    def finalize(self):
        return self.val
116n/a
class AggrCheckTypes:
    """Aggregate counting values whose exact type matches a named type.

    Each ``step(whichType, *vals)`` adds the number of *vals* whose type is
    exactly the type named by *whichType* ("str", "int", "float", "None" or
    "blob").  The name is only looked up while counting values, so an
    unknown name raises KeyError only when at least one value is given.
    ``finalize`` returns the running total (starting at 0).
    """

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        type_by_name = {
            "str": str,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": bytes,
        }
        # Summing booleans counts the matches; the lookup stays inside the
        # generator so an empty *vals never touches type_by_name.
        self.val += sum(type(item) is type_by_name[whichType] for item in vals)

    def finalize(self):
        return self.val
129n/a
class AggrSum:
    """Aggregate adding up its inputs, starting from the float 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
139n/a
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions registered with create_function()."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions, each returning one kind of Python value.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument predicates reporting the Python type of their argument.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)
        # An argument count of -1 lets "spam" take any number of arguments
        # (exercised by CheckAnyArguments).
        self.con.create_function("spam", -1, func)

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # -100 is not an acceptable argument count, so registration must fail.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        # The connection must keep its own reference to the callable.  Pass a
        # freshly created closure without retaining any other reference to it
        # (no local, no global), then call it from SQL and check the result.
        def getfunc():
            def f():
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        # Same 3.139..3.141 window as before, expressed with unittest's helper.
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 1<<31)

    def CheckFuncException(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1<<42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAnyArguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)
277n/a
278n/a
class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates registered with create_aggregate()."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        # A single row holding one value of each kind the aggregates inspect.
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, memoryview(b"blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        # -1 lets checkTypes take a variable number of arguments.
        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Release the connection opened in setUp() so it is not leaked.
        self.con.close()

    def CheckAggrErrorOnCreate(self):
        # An invalid argument count must make *aggregate* registration fail,
        # so this goes through create_aggregate(), not create_function().
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        cur = self.con.cursor()
        with self.assertRaises(AttributeError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamsInt(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)
390n/a
class AuthorizerTests(unittest.TestCase):
    """Checks that an authorizer callback blocks access to t2 and to column c2.

    Both tests expect the denial to surface as a DatabaseError whose message
    contains 'prohibited'.  Subclasses swap in other ``authorizer_cb``
    implementations and inherit these tests unchanged.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Deny every action other than SQLITE_SELECT.
        # NOTE(review): per the SQLite authorizer docs, column reads are
        # reported as SQLITE_READ (with table/column in arg1/arg2), which the
        # first check already denies; confirm whether the c2/t2 check below
        # is ever reached.
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: this statement runs before the authorizer is
        # installed (presumably to prime the statement cache — TODO confirm).
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        # Close the connection opened in setUp(); previously it was leaked.
        self.con.close()

    def test_table_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select * from t2")
        self.assertIn('prohibited', str(cm.exception))

    def test_column_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select c2 from t1")
        self.assertIn('prohibited', str(cm.exception))
426n/a
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """AuthorizerTests variant whose callback signals denial by raising ValueError."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Same conditions as the base class; a denial raises instead of
        # returning SQLITE_DENY.
        denied = (action != sqlite.SQLITE_SELECT
                  or arg2 == 'c2' or arg1 == 't2')
        if denied:
            raise ValueError
        return sqlite.SQLITE_OK
435n/a
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """AuthorizerTests variant whose callback returns a float (0.0) to deny.

    The inherited tests expect this non-integer result to be treated as a
    refusal.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        denied = (action != sqlite.SQLITE_SELECT
                  or arg2 == 'c2' or arg1 == 't2')
        return 0.0 if denied else sqlite.SQLITE_OK
444n/a
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """AuthorizerTests variant whose callback returns 2**32 to deny.

    The inherited tests expect this out-of-range integer to be treated as a
    refusal.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        denied = (action != sqlite.SQLITE_SELECT
                  or arg2 == 'c2' or arg1 == 't2')
        return 2**32 if denied else sqlite.SQLITE_OK
453n/a
454n/a
def suite():
    """Build the combined suite for functions, aggregates and authorizers.

    FunctionTests and AggregateTests name their test methods with a "Check"
    prefix, so they are loaded with a loader whose ``testMethodPrefix`` is
    "Check".  The authorizer test cases use the default "test" prefix.
    """
    # unittest.makeSuite() is deprecated since Python 3.11 and removed in
    # 3.13; TestLoader.loadTestsFromTestCase() is the equivalent API.
    check_loader = unittest.TestLoader()
    check_loader.testMethodPrefix = "Check"
    default_loader = unittest.TestLoader()

    function_suite = check_loader.loadTestsFromTestCase(FunctionTests)
    aggregate_suite = check_loader.loadTestsFromTestCase(AggregateTests)
    authorizer_suite = default_loader.loadTestsFromTestCase(AuthorizerTests)
    return unittest.TestSuite((
        function_suite,
        aggregate_suite,
        authorizer_suite,
        default_loader.loadTestsFromTestCase(AuthorizerRaiseExceptionTests),
        default_loader.loadTestsFromTestCase(AuthorizerIllegalTypeTests),
        default_loader.loadTestsFromTestCase(AuthorizerLargeIntegerTests),
    ))
467n/a
def test():
    """Run the tests returned by suite() with unittest.TextTestRunner."""
    unittest.TextTestRunner().run(suite())
471n/a
if "__main__" == __name__:
    # Run the whole suite only when executed as a script, not when imported.
    test()