» Core Development > Code coverage > Lib/test/test_functools.py

Python code coverage for Lib/test/test_functools.py

# | count | content
1n/aimport abc
2n/aimport builtins
3n/aimport collections
4n/aimport copy
5n/afrom itertools import permutations
6n/aimport pickle
7n/afrom random import choice
8n/aimport sys
9n/afrom test import support
10n/aimport time
11n/aimport unittest
12n/aimport unittest.mock
13n/afrom weakref import proxy
14n/aimport contextlib
15n/atry:
16n/a import threading
17n/aexcept ImportError:
18n/a threading = None
19n/a
20n/aimport functools
21n/a
# Fresh copies of functools: one with the _functools accelerator blocked
# (pure Python), one that re-imports the accelerator.  The C-dependent test
# classes below are guarded with skipUnless(c_functools, ...).
py_functools = support.import_fresh_module(
    'functools', blocked=['_functools'])
c_functools = support.import_fresh_module(
    'functools', fresh=['_functools'])

decimal = support.import_fresh_module(
    'decimal', fresh=['_decimal'])
26n/a
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily register *replacement* as ``sys.modules[name]``.

    The module that was registered before entry is restored on exit, even
    if the body of the ``with`` statement raises.  *name* must already be a
    key of ``sys.modules``; otherwise a KeyError is raised on entry.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always put the original module back.
        sys.modules[name] = saved
35n/a
def capture(*args, **kw):
    """Echo the call back as a ``(positional_tuple, keyword_dict)`` pair."""
    captured = (args, kw)
    return captured
39n/a
40n/a
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``."""
    head = (part.func, part.args, part.keywords)
    return head + (part.__dict__,)
44n/a
class MyTuple(tuple):
    """A ``tuple`` subclass that adds no behavior of its own."""
47n/a
class BadTuple(tuple):
    """A ``tuple`` subclass whose ``+`` misbehaves by returning a ``list``."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
51n/a
class MyDict(dict):
    """A ``dict`` subclass that adds no behavior of its own."""
54n/a
55n/a
class TestPartial:
    """Implementation-independent tests for ``partial`` objects.

    This is a mixin: concrete subclasses also inherit from
    ``unittest.TestCase`` and supply ``partial`` (the implementation under
    test) and ``AllowPickle`` (a context-manager class entered around the
    pickling tests).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is created or at the latest when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0, 1), {'a': 1, 'b': 2}))
        # The second call must not see the arguments of the first one.
        self.assertEqual(p(), ((0,), {'a': 1}))

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
370n/a
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial() tests against the C implementation."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: no module swapping is done for the C
        implementation."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # ... and the instance __dict__ must not be deletable either.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__
396n/a
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial() tests against the pure-Python implementation."""
    partial = py_functools.partial

    class AllowPickle:
        """Context manager that installs the pure-Python functools module as
        ``sys.modules['functools']`` for the duration of the block
        (presumably so pickle resolves ``functools.partial`` to this class)."""

        def __init__(self):
            self._cm = replaced_module("functools", py_functools)

        def __enter__(self):
            manager = self._cm
            return manager.__enter__()

        def __exit__(self, exc_type, exc_value, traceback):
            manager = self._cm
            return manager.__exit__(exc_type, exc_value, traceback)
407n/a
# Only definable when the C accelerator module was importable.
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Subclass of the C ``partial`` type with no overrides."""
411n/a
class PyPartialSubclass(py_functools.partial):
    """Subclass of the pure-Python ``partial`` class with no overrides."""
414n/a
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the C partial() tests against a subclass of the C type."""
    if c_functools:
        partial = CPartialSubclass

    # Subclasses do not get the nested-partial flattening, so the inherited
    # test is disabled by shadowing it with a non-callable.
    test_nested_optimization = None
422n/a
class TestPartialPySubclass(TestPartialPy):
    """Run the pure-Python partial() tests against a subclass."""
    partial = PyPartialSubclass
425n/a
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must *use* ABCMeta as its metaclass.  Subclassing ABCMeta
        # would make Abstract a metaclass itself rather than an abstract class.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # Both members are abstract, so the ABC cannot be instantiated.
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
538n/a
539n/a
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was populated from *wrapped*.

        Each attribute named in *assigned* must be the very same object on
        both functions.  Each mapping attribute named in *updated* must hold
        every entry of the wrapped object's mapping, except the
        ``__wrapped__`` key of ``__dict__``.  ``wrapper.__wrapped__`` must
        refer back to *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            mine = getattr(wrapper, attr_name)
            theirs = getattr(wrapped, attr_name)
            for key in theirs:
                # The wrapper's __wrapped__ is overwritten by update_wrapper,
                # so that single __dict__ entry is expected to differ.
                skip = attr_name == "__dict__" and key == "__wrapped__"
                if not skip:
                    self.assertIs(theirs[key], mine[key])
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        """Return a ``(wrapper, wrapped)`` pair built with default arguments."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A bogus __wrapped__ on the wrapped function; the wrapper must still
        # end up pointing at f itself.
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b:'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        expected = {
            '__name__': 'f',
            '__qualname__': f.__qualname__,
            'attr': 'This is also a test',
        }
        for attr_name, value in expected.items():
            self.assertEqual(getattr(wrapper, attr_name), value)
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        no_attrs = ()
        functools.update_wrapper(wrapper, f, no_attrs, no_attrs)
        self.check_wrapper(wrapper, f, no_attrs, no_attrs)
        # Nothing was copied, so the wrapper keeps its own metadata.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        # Attributes absent from the wrapped function are silently skipped.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper itself must already provide every updated attribute
        # as a usable mapping; otherwise AttributeError is raised.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Regression test for bug #1576241: wrapping a builtin such as max.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
652n/a
653n/a
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper checks through the ``functools.wraps`` decorator."""

    def _default_update(self):
        """Return a ``(wrapper, wrapped)`` pair built via ``functools.wraps(f)``."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        decorate = functools.wraps(f)
        def wrapper():
            pass
        return decorate(wrapper), f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        for attr_name, value in (('__name__', 'f'),
                                 ('__qualname__', f.__qualname__),
                                 ('attr', 'This is also a test')):
            self.assertEqual(getattr(wrapper, attr_name), value)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        wrapper = functools.wraps(f, (), ())(wrapper)
        self.check_wrapper(wrapper, f, (), ())
        # Nothing was copied, so the wrapper keeps its own metadata.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(func):
            # give the wrapper an (empty) mapping for update_wrapper to fill
            func.dict_attr = {}
            return func
        assign = ('attr',)
        update = ('dict_attr',)
        def wrapper():
            pass
        wrapper = functools.wraps(f, assign, update)(add_dict_attr(wrapper))
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
714n/a
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C implementation of ``reduce``."""
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Old-style sequence: item i is i*i for 0 <= i < max.  Items are
            # computed lazily and cached in self.sofar.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n*n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func

        # (positional arguments, expected result)
        expected_results = [
            ((add, ['a', 'b', 'c'], ''), 'abc'),
            ((add, [['a', 'c'], [], ['d', 'w']], []), ['a','c','d','w']),
            ((lambda x, y: x*y, range(2,8), 1), 5040),
            ((lambda x, y: x*y, range(2,21), 1), 2432902008176640000),
            ((add, Squares(10)), 285),
            ((add, Squares(10), 0), 285),
            ((add, Squares(0), 0), 0),
            # func is never called with one item
            ((42, "1"), "1"),
            ((42, "", "1"), "1"),
            # an empty iterable with an initial value yields that value
            ((add, [], None), None),
            ((add, [], 42), 42),
        ]
        for args, expected in expected_results:
            self.assertEqual(reduce(*args), expected)

        bad_argument_lists = [
            (),
            (42, 42),
            (42, 42, 42),
            (42, (42, 42)),
            (add, []),      # arg 2 must not be empty sequence with no initial value
            (add, ""),
            (add, ()),
            (add, object()),
        ]
        for args in bad_argument_lists:
            self.assertRaises(TypeError, reduce, *args)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        from operator import add

        class SequenceClass:
            # Old-style sequence yielding 0 .. n-1 via __getitem__.
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        for args, expected in [
                ((add, SequenceClass(5)), 10),
                ((add, SequenceClass(5), 42), 52),
                ((add, SequenceClass(0), 42), 42),
                ((add, SequenceClass(1)), 0),
                ((add, SequenceClass(1), 42), 42),
        ]:
            self.assertEqual(reduce(*args), expected)
        # An empty sequence without an initial value is an error.
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))

        # Iterating a dict yields its keys, which add() concatenates.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
796n/a
797n/a
798n/aclass TestCmpToKey:
799n/a
800n/a def test_cmp_to_key(self):
801n/a def cmp1(x, y):
802n/a return (x > y) - (x < y)
803n/a key = self.cmp_to_key(cmp1)
804n/a self.assertEqual(key(3), key(3))
805n/a self.assertGreater(key(3), key(1))
806n/a self.assertGreaterEqual(key(3), key(3))
807n/a
808n/a def cmp2(x, y):
809n/a return int(x) - int(y)
810n/a key = self.cmp_to_key(cmp2)
811n/a self.assertEqual(key(4.0), key('4'))
812n/a self.assertLess(key(2), key('35'))
813n/a self.assertLessEqual(key(2), key('35'))
814n/a self.assertNotEqual(key(2), key('35'))
815n/a
816n/a def test_cmp_to_key_arguments(self):
817n/a def cmp1(x, y):
818n/a return (x > y) - (x < y)
819n/a key = self.cmp_to_key(mycmp=cmp1)
820n/a self.assertEqual(key(obj=3), key(obj=3))
821n/a self.assertGreater(key(obj=3), key(obj=1))
822n/a with self.assertRaises((TypeError, AttributeError)):
823n/a key(3) > 1 # rhs is not a K object
824n/a with self.assertRaises((TypeError, AttributeError)):
825n/a 1 < key(3) # lhs is not a K object
826n/a with self.assertRaises(TypeError):
827n/a key = self.cmp_to_key() # too few args
828n/a with self.assertRaises(TypeError):
829n/a key = self.cmp_to_key(cmp1, None) # too many args
830n/a key = self.cmp_to_key(cmp1)
831n/a with self.assertRaises(TypeError):
832n/a key() # too few args
833n/a with self.assertRaises(TypeError):
834n/a key(None, None) # too many args
835n/a
836n/a def test_bad_cmp(self):
837n/a def cmp1(x, y):
838n/a raise ZeroDivisionError
839n/a key = self.cmp_to_key(cmp1)
840n/a with self.assertRaises(ZeroDivisionError):
841n/a key(3) > key(1)
842n/a
843n/a class BadCmp:
844n/a def __lt__(self, other):
845n/a raise ZeroDivisionError
846n/a def cmp1(x, y):
847n/a return BadCmp()
848n/a with self.assertRaises(ZeroDivisionError):
849n/a key(3) > key(1)
850n/a
851n/a def test_obj_field(self):
852n/a def cmp1(x, y):
853n/a return (x > y) - (x < y)
854n/a key = self.cmp_to_key(mycmp=cmp1)
855n/a self.assertEqual(key(50).obj, 50)
856n/a
857n/a def test_sort_int(self):
858n/a def mycmp(x, y):
859n/a return y - x
860n/a self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
861n/a [4, 3, 2, 1, 0])
862n/a
863n/a def test_sort_int_str(self):
864n/a def mycmp(x, y):
865n/a x, y = int(x), int(y)
866n/a return (x > y) - (x < y)
867n/a values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
868n/a values = sorted(values, key=self.cmp_to_key(mycmp))
869n/a self.assertEqual([int(value) for value in values],
870n/a [0, 1, 1, 2, 3, 4, 5, 7, 10])
871n/a
872n/a def test_hash(self):
873n/a def mycmp(x, y):
874n/a return y - x
875n/a key = self.cmp_to_key(mycmp)
876n/a k = key(10)
877n/a self.assertRaises(TypeError, hash, k)
878n/a self.assertNotIsInstance(k, collections.Hashable)
879n/a
880n/a
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""

    # Only bind the attribute when _functools imported; otherwise the
    # skipUnless decorator above skips the whole class.
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
885n/a
886n/a
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python functools."""

    # staticmethod keeps the plain Python function from being bound to the
    # test instance when looked up as ``self.cmp_to_key``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
889n/a
890n/a
class TestTotalOrdering(unittest.TestCase):
    """Tests for functools.total_ordering."""

    def _check_derived_ordering(self, cls):
        # Shared checks: strict and non-strict comparisons between values
        # 1 and 2, plus reflexive <= and >= on equal values. Fresh instances
        # are built for every comparison.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_derived_ordering(A)

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        # Each comparison mixes incompatible types, so every path through
        # the generated methods must end in TypeError.
        mismatched = (
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: ImplementsLessThanEqualTo(3)
                     <= ImplementsGreaterThanEqualTo(3)),
            ("GT > GE",
             lambda: ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)),
            ("GE >= LE",
             lambda: ImplementsGreaterThanEqualTo(7)
                     >= ImplementsLessThanEqualTo(7)),
        )
        for label, compare in mismatched:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Equal values still raise when the root comparison is NotImplemented.
        equal_cases = (
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        )
        for label, value, compare in equal_cases:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    original = getattr(Orderable_LT, name)
                    restored = pickle.loads(pickle.dumps(original, proto))
                    self.assertIs(restored, original)
1093n/a
@functools.total_ordering
class Orderable_LT:
    """Orderable type that defines only ``__eq__`` and ``__lt__``.

    total_ordering derives the remaining comparisons. TestTotalOrdering's
    pickle test round-trips those methods, looking the class up by name.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1102n/a
1103n/a
class TestLRU:
    """lru_cache tests shared by the pure-Python and C implementations.

    Concrete subclasses provide:
      * ``module`` -- the functools module under test;
      * ``cached_func`` -- a 1-tuple holding a module-level cached function;
      * ``cached_meth`` / ``cached_staticmeth`` -- cached callables defined
        on the subclass.
    The pickle/copy tests use those last three attributes.
    """

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # Misses: the first 7, the first 9, the first 8 (evicts 7) and the
        # final 7 (evicts 9); every other call is a hit.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_hash_only_once(self):
        # To protect against weird reentrancy bugs and to improve
        # efficiency when faced with slow __hash__ methods, the
        # LRU cache guarantees that it will call __hash__ only
        # once per use as an argument to the cached function.

        @self.module.lru_cache(maxsize=1)
        def f(x, y):
            return x * 3 + y

        # Simulate the integer 5
        mock_int = unittest.mock.Mock()
        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
        mock_int.__hash__ = unittest.mock.Mock(return_value=999)

        # Add to cache: One use as an argument gives one call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 1)
        self.assertEqual(f.cache_info(), (0, 1, 1, 1))

        # Cache hit: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 1, 1, 1))

        # Cache eviction: No use as an argument gives no additional call
        self.assertEqual(f(6, 2), 20)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 2, 1, 1))

        # Cache miss: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 3)
        self.assertEqual(f.cache_info(), (1, 3, 1, 1))

    def test_lru_reentrancy_with_len(self):
        # Test to make sure the LRU cache code isn't thrown-off by
        # caching the built-in len() function.  Since len() can be
        # cached, we shouldn't use it inside the lru code itself.
        old_len = builtins.len
        try:
            builtins.len = self.module.lru_cache(4)(len)
            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
                self.assertEqual(len('abcdefghijklmn'[:i]), i)
        finally:
            builtins.len = old_len

    def test_lru_star_arg_handling(self):
        # Test regression that arose in ea064ff3c10f.
        # Uses self.module so both implementations are exercised.
        @self.module.lru_cache()
        def f(*args):
            return args

        self.assertEqual(f(1, 2), (1, 2))
        self.assertEqual(f((1, 2)), ((1, 2),))

    def test_lru_type_error(self):
        # Regression test for issue #28653.
        # lru_cache was leaking when one of the arguments
        # wasn't cacheable.

        @self.module.lru_cache(maxsize=None)
        def infinite_cache(o):
            pass

        @self.module.lru_cache(maxsize=10)
        def limited_cache(o):
            pass

        with self.assertRaises(TypeError):
            infinite_cache([])

        with self.assertRaises(TypeError):
            limited_cache([])

    def test_lru_with_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_maxsize_negative(self):
        @self.module.lru_cache(maxsize=-10)
        def eq(n):
            return n
        for _ in (0, 1):
            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
        self.assertEqual(eq.cache_info(),
            self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        @self.module.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_kwargs_order(self):
        # PEP 468: Preserving Keyword Argument Order
        @self.module.lru_cache(maxsize=10)
        def f(**kwargs):
            return list(kwargs.items())
        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
        self.assertEqual(f.cache_info(),
            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))

    def test_lru_cache_decoration(self):
        def f(zomg: 'zomg_annotation'):
            """f doc string"""
            return 42
        g = self.module.lru_cache()(f)
        for attr in self.module.WRAPPER_ASSIGNMENTS:
            self.assertEqual(getattr(g, attr), getattr(f, attr))

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded(self):
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can be not equal?
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        n, m = 5, 7
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for _ in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded3(self):
        @self.module.lru_cache(maxsize=2)
        def f(x):
            time.sleep(.01)
            return 3 * x
        def test(i, x):
            with self.subTest(thread=i):
                self.assertEqual(f(x), 3 * x, i)
        threads = [threading.Thread(target=test, args=(i, v))
                   for i, v in enumerate([1, 2, 2, 3, 2])]
        with support.start_threads(threads):
            pass

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value

    def test_early_detection_of_bad_call(self):
        # Issue #22184: applying lru_cache without calling it first must
        # raise TypeError at decoration time (checked for both implementations).
        with self.assertRaises(TypeError):
            @self.module.lru_cache
            def f():
                pass

    def test_lru_method(self):
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())

    def test_pickle(self):
        cls = self.__class__
        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(proto=proto, func=f):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    self.assertIs(f_copy, f)

    def test_copy(self):
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.copy(f)
                self.assertIs(f_copy, f)

    def test_deepcopy(self):
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.deepcopy(f)
                self.assertIs(f_copy, f)
1557n/a
1558n/a
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Module-level function cached with the pure-Python lru_cache.

    TestLRUPy exposes it as ``cached_func`` for the pickle/copy tests.
    """
    result = 3 * x + y
    return result
1562n/a
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Module-level function cached with the C lru_cache.

    TestLRUC exposes it as ``cached_func`` for the pickle/copy tests.
    """
    result = 3 * x + y
    return result
1566n/a
1567n/a
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared lru_cache tests against the pure-Python functools."""

    module = py_functools
    # Kept as a 1-tuple on purpose: the shared tests read cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        total = 3 * x + y
        return total

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        total = 3 * x + y
        return total
1580n/a
1581n/a
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared lru_cache tests against the C-accelerated functools."""

    module = c_functools
    # Kept as a 1-tuple on purpose: the shared tests read cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        total = 3 * x + y
        return total

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        total = 3 * x + y
        return total
1594n/a
1595n/a
1596n/aclass TestSingleDispatch(unittest.TestCase):
1597n/a def test_simple_overloads(self):
1598n/a @functools.singledispatch
1599n/a def g(obj):
1600n/a return "base"
1601n/a def g_int(i):
1602n/a return "integer"
1603n/a g.register(int, g_int)
1604n/a self.assertEqual(g("str"), "base")
1605n/a self.assertEqual(g(1), "integer")
1606n/a self.assertEqual(g([1,2,3]), "base")
1607n/a
1608n/a def test_mro(self):
1609n/a @functools.singledispatch
1610n/a def g(obj):
1611n/a return "base"
1612n/a class A:
1613n/a pass
1614n/a class C(A):
1615n/a pass
1616n/a class B(A):
1617n/a pass
1618n/a class D(C, B):
1619n/a pass
1620n/a def g_A(a):
1621n/a return "A"
1622n/a def g_B(b):
1623n/a return "B"
1624n/a g.register(A, g_A)
1625n/a g.register(B, g_B)
1626n/a self.assertEqual(g(A()), "A")
1627n/a self.assertEqual(g(B()), "B")
1628n/a self.assertEqual(g(C()), "A")
1629n/a self.assertEqual(g(D()), "B")
1630n/a
1631n/a def test_register_decorator(self):
1632n/a @functools.singledispatch
1633n/a def g(obj):
1634n/a return "base"
1635n/a @g.register(int)
1636n/a def g_int(i):
1637n/a return "int %s" % (i,)
1638n/a self.assertEqual(g(""), "base")
1639n/a self.assertEqual(g(12), "int 12")
1640n/a self.assertIs(g.dispatch(int), g_int)
1641n/a self.assertIs(g.dispatch(object), g.dispatch(str))
1642n/a # Note: in the assert above this is not g.
1643n/a # @singledispatch returns the wrapper.
1644n/a
1645n/a def test_wrapping_attributes(self):
1646n/a @functools.singledispatch
1647n/a def g(obj):
1648n/a "Simple test"
1649n/a return "Test"
1650n/a self.assertEqual(g.__name__, "g")
1651n/a if sys.flags.optimize < 2:
1652n/a self.assertEqual(g.__doc__, "Simple test")
1653n/a
1654n/a @unittest.skipUnless(decimal, 'requires _decimal')
1655n/a @support.cpython_only
1656n/a def test_c_classes(self):
1657n/a @functools.singledispatch
1658n/a def g(obj):
1659n/a return "base"
1660n/a @g.register(decimal.DecimalException)
1661n/a def _(obj):
1662n/a return obj.args
1663n/a subn = decimal.Subnormal("Exponent < Emin")
1664n/a rnd = decimal.Rounded("Number got rounded")
1665n/a self.assertEqual(g(subn), ("Exponent < Emin",))
1666n/a self.assertEqual(g(rnd), ("Number got rounded",))
1667n/a @g.register(decimal.Subnormal)
1668n/a def _(obj):
1669n/a return "Too small to care."
1670n/a self.assertEqual(g(subn), "Too small to care.")
1671n/a self.assertEqual(g(rnd), ("Number got rounded",))
1672n/a
1673n/a def test_compose_mro(self):
1674n/a # None of the examples in this test depend on haystack ordering.
1675n/a c = collections
1676n/a mro = functools._compose_mro
1677n/a bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1678n/a for haystack in permutations(bases):
1679n/a m = mro(dict, haystack)
1680n/a self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1681n/a c.Collection, c.Sized, c.Iterable,
1682n/a c.Container, object])
1683n/a bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1684n/a for haystack in permutations(bases):
1685n/a m = mro(c.ChainMap, haystack)
1686n/a self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1687n/a c.Collection, c.Sized, c.Iterable,
1688n/a c.Container, object])
1689n/a
1690n/a # If there's a generic function with implementations registered for
1691n/a # both Sized and Container, passing a defaultdict to it results in an
1692n/a # ambiguous dispatch which will cause a RuntimeError (see
1693n/a # test_mro_conflicts).
1694n/a bases = [c.Container, c.Sized, str]
1695n/a for haystack in permutations(bases):
1696n/a m = mro(c.defaultdict, [c.Sized, c.Container, str])
1697n/a self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1698n/a object])
1699n/a
1700n/a # MutableSequence below is registered directly on D. In other words, it
1701n/a # precedes MutableMapping which means single dispatch will always
1702n/a # choose MutableSequence here.
1703n/a class D(c.defaultdict):
1704n/a pass
1705n/a c.MutableSequence.register(D)
1706n/a bases = [c.MutableSequence, c.MutableMapping]
1707n/a for haystack in permutations(bases):
1708n/a m = mro(D, bases)
1709n/a self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
1710n/a c.defaultdict, dict, c.MutableMapping, c.Mapping,
1711n/a c.Collection, c.Sized, c.Iterable, c.Container,
1712n/a object])
1713n/a
1714n/a # Container and Callable are registered on different base classes and
1715n/a # a generic function supporting both should always pick the Callable
1716n/a # implementation if a C instance is passed.
1717n/a class C(c.defaultdict):
1718n/a def __call__(self):
1719n/a pass
1720n/a bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1721n/a for haystack in permutations(bases):
1722n/a m = mro(C, haystack)
1723n/a self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1724n/a c.Collection, c.Sized, c.Iterable,
1725n/a c.Container, object])
1726n/a
1727n/a def test_register_abc(self):
1728n/a c = collections
1729n/a d = {"a": "b"}
1730n/a l = [1, 2, 3]
1731n/a s = {object(), None}
1732n/a f = frozenset(s)
1733n/a t = (1, 2, 3)
1734n/a @functools.singledispatch
1735n/a def g(obj):
1736n/a return "base"
1737n/a self.assertEqual(g(d), "base")
1738n/a self.assertEqual(g(l), "base")
1739n/a self.assertEqual(g(s), "base")
1740n/a self.assertEqual(g(f), "base")
1741n/a self.assertEqual(g(t), "base")
1742n/a g.register(c.Sized, lambda obj: "sized")
1743n/a self.assertEqual(g(d), "sized")
1744n/a self.assertEqual(g(l), "sized")
1745n/a self.assertEqual(g(s), "sized")
1746n/a self.assertEqual(g(f), "sized")
1747n/a self.assertEqual(g(t), "sized")
1748n/a g.register(c.MutableMapping, lambda obj: "mutablemapping")
1749n/a self.assertEqual(g(d), "mutablemapping")
1750n/a self.assertEqual(g(l), "sized")
1751n/a self.assertEqual(g(s), "sized")
1752n/a self.assertEqual(g(f), "sized")
1753n/a self.assertEqual(g(t), "sized")
1754n/a g.register(c.ChainMap, lambda obj: "chainmap")
1755n/a self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1756n/a self.assertEqual(g(l), "sized")
1757n/a self.assertEqual(g(s), "sized")
1758n/a self.assertEqual(g(f), "sized")
1759n/a self.assertEqual(g(t), "sized")
1760n/a g.register(c.MutableSequence, lambda obj: "mutablesequence")
1761n/a self.assertEqual(g(d), "mutablemapping")
1762n/a self.assertEqual(g(l), "mutablesequence")
1763n/a self.assertEqual(g(s), "sized")
1764n/a self.assertEqual(g(f), "sized")
1765n/a self.assertEqual(g(t), "sized")
1766n/a g.register(c.MutableSet, lambda obj: "mutableset")
1767n/a self.assertEqual(g(d), "mutablemapping")
1768n/a self.assertEqual(g(l), "mutablesequence")
1769n/a self.assertEqual(g(s), "mutableset")
1770n/a self.assertEqual(g(f), "sized")
1771n/a self.assertEqual(g(t), "sized")
1772n/a g.register(c.Mapping, lambda obj: "mapping")
1773n/a self.assertEqual(g(d), "mutablemapping") # not specific enough
1774n/a self.assertEqual(g(l), "mutablesequence")
1775n/a self.assertEqual(g(s), "mutableset")
1776n/a self.assertEqual(g(f), "sized")
1777n/a self.assertEqual(g(t), "sized")
1778n/a g.register(c.Sequence, lambda obj: "sequence")
1779n/a self.assertEqual(g(d), "mutablemapping")
1780n/a self.assertEqual(g(l), "mutablesequence")
1781n/a self.assertEqual(g(s), "mutableset")
1782n/a self.assertEqual(g(f), "sized")
1783n/a self.assertEqual(g(t), "sequence")
1784n/a g.register(c.Set, lambda obj: "set")
1785n/a self.assertEqual(g(d), "mutablemapping")
1786n/a self.assertEqual(g(l), "mutablesequence")
1787n/a self.assertEqual(g(s), "mutableset")
1788n/a self.assertEqual(g(f), "set")
1789n/a self.assertEqual(g(t), "sequence")
1790n/a g.register(dict, lambda obj: "dict")
1791n/a self.assertEqual(g(d), "dict")
1792n/a self.assertEqual(g(l), "mutablesequence")
1793n/a self.assertEqual(g(s), "mutableset")
1794n/a self.assertEqual(g(f), "set")
1795n/a self.assertEqual(g(t), "sequence")
1796n/a g.register(list, lambda obj: "list")
1797n/a self.assertEqual(g(d), "dict")
1798n/a self.assertEqual(g(l), "list")
1799n/a self.assertEqual(g(s), "mutableset")
1800n/a self.assertEqual(g(f), "set")
1801n/a self.assertEqual(g(t), "sequence")
1802n/a g.register(set, lambda obj: "concrete-set")
1803n/a self.assertEqual(g(d), "dict")
1804n/a self.assertEqual(g(l), "list")
1805n/a self.assertEqual(g(s), "concrete-set")
1806n/a self.assertEqual(g(f), "set")
1807n/a self.assertEqual(g(t), "sequence")
1808n/a g.register(frozenset, lambda obj: "frozen-set")
1809n/a self.assertEqual(g(d), "dict")
1810n/a self.assertEqual(g(l), "list")
1811n/a self.assertEqual(g(s), "concrete-set")
1812n/a self.assertEqual(g(f), "frozen-set")
1813n/a self.assertEqual(g(t), "sequence")
1814n/a g.register(tuple, lambda obj: "tuple")
1815n/a self.assertEqual(g(d), "dict")
1816n/a self.assertEqual(g(l), "list")
1817n/a self.assertEqual(g(s), "concrete-set")
1818n/a self.assertEqual(g(f), "frozen-set")
1819n/a self.assertEqual(g(t), "tuple")
1820n/a
1821n/a def test_c3_abc(self):
1822n/a c = collections
1823n/a mro = functools._c3_mro
1824n/a class A(object):
1825n/a pass
1826n/a class B(A):
1827n/a def __len__(self):
1828n/a return 0 # implies Sized
1829n/a @c.Container.register
1830n/a class C(object):
1831n/a pass
1832n/a class D(object):
1833n/a pass # unrelated
1834n/a class X(D, C, B):
1835n/a def __call__(self):
1836n/a pass # implies Callable
1837n/a expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1838n/a for abcs in permutations([c.Sized, c.Callable, c.Container]):
1839n/a self.assertEqual(mro(X, abcs=abcs), expected)
1840n/a # unrelated ABCs don't appear in the resulting MRO
1841n/a many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1842n/a self.assertEqual(mro(X, abcs=many_abcs), expected)
1843n/a
1844n/a def test_false_meta(self):
1845n/a # see issue23572
1846n/a class MetaA(type):
1847n/a def __len__(self):
1848n/a return 0
1849n/a class A(metaclass=MetaA):
1850n/a pass
1851n/a class AA(A):
1852n/a pass
1853n/a @functools.singledispatch
1854n/a def fun(a):
1855n/a return 'base A'
1856n/a @fun.register(A)
1857n/a def _(a):
1858n/a return 'fun A'
1859n/a aa = AA()
1860n/a self.assertEqual(fun(aa), 'fun A')
1861n/a
1862n/a def test_mro_conflicts(self):
1863n/a c = collections
1864n/a @functools.singledispatch
1865n/a def g(arg):
1866n/a return "base"
1867n/a class O(c.Sized):
1868n/a def __len__(self):
1869n/a return 0
1870n/a o = O()
1871n/a self.assertEqual(g(o), "base")
1872n/a g.register(c.Iterable, lambda arg: "iterable")
1873n/a g.register(c.Container, lambda arg: "container")
1874n/a g.register(c.Sized, lambda arg: "sized")
1875n/a g.register(c.Set, lambda arg: "set")
1876n/a self.assertEqual(g(o), "sized")
1877n/a c.Iterable.register(O)
1878n/a self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1879n/a c.Container.register(O)
1880n/a self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
1881n/a c.Set.register(O)
1882n/a self.assertEqual(g(o), "set") # because c.Set is a subclass of
1883n/a # c.Sized and c.Container
1884n/a class P:
1885n/a pass
1886n/a p = P()
1887n/a self.assertEqual(g(p), "base")
1888n/a c.Iterable.register(P)
1889n/a self.assertEqual(g(p), "iterable")
1890n/a c.Container.register(P)
1891n/a with self.assertRaises(RuntimeError) as re_one:
1892n/a g(p)
1893n/a self.assertIn(
1894n/a str(re_one.exception),
1895n/a (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1896n/a "or <class 'collections.abc.Iterable'>"),
1897n/a ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1898n/a "or <class 'collections.abc.Container'>")),
1899n/a )
1900n/a class Q(c.Sized):
1901n/a def __len__(self):
1902n/a return 0
1903n/a q = Q()
1904n/a self.assertEqual(g(q), "sized")
1905n/a c.Iterable.register(Q)
1906n/a self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1907n/a c.Set.register(Q)
1908n/a self.assertEqual(g(q), "set") # because c.Set is a subclass of
1909n/a # c.Sized and c.Iterable
1910n/a @functools.singledispatch
1911n/a def h(arg):
1912n/a return "base"
1913n/a @h.register(c.Sized)
1914n/a def _(arg):
1915n/a return "sized"
1916n/a @h.register(c.Container)
1917n/a def _(arg):
1918n/a return "container"
1919n/a # Even though Sized and Container are explicit bases of MutableMapping,
1920n/a # this ABC is implicitly registered on defaultdict which makes all of
1921n/a # MutableMapping's bases implicit as well from defaultdict's
1922n/a # perspective.
1923n/a with self.assertRaises(RuntimeError) as re_two:
1924n/a h(c.defaultdict(lambda: 0))
1925n/a self.assertIn(
1926n/a str(re_two.exception),
1927n/a (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1928n/a "or <class 'collections.abc.Sized'>"),
1929n/a ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1930n/a "or <class 'collections.abc.Container'>")),
1931n/a )
1932n/a class R(c.defaultdict):
1933n/a pass
1934n/a c.MutableSequence.register(R)
1935n/a @functools.singledispatch
1936n/a def i(arg):
1937n/a return "base"
1938n/a @i.register(c.MutableMapping)
1939n/a def _(arg):
1940n/a return "mapping"
1941n/a @i.register(c.MutableSequence)
1942n/a def _(arg):
1943n/a return "sequence"
1944n/a r = R()
1945n/a self.assertEqual(i(r), "sequence")
1946n/a class S:
1947n/a pass
1948n/a class T(S, c.Sized):
1949n/a def __len__(self):
1950n/a return 0
1951n/a t = T()
1952n/a self.assertEqual(h(t), "sized")
1953n/a c.Container.register(T)
1954n/a self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1955n/a class U:
1956n/a def __len__(self):
1957n/a return 0
1958n/a u = U()
1959n/a self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1960n/a # from the existence of __len__()
1961n/a c.Container.register(U)
1962n/a # There is no preference for registered versus inferred ABCs.
1963n/a with self.assertRaises(RuntimeError) as re_three:
1964n/a h(u)
1965n/a self.assertIn(
1966n/a str(re_three.exception),
1967n/a (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1968n/a "or <class 'collections.abc.Sized'>"),
1969n/a ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1970n/a "or <class 'collections.abc.Container'>")),
1971n/a )
1972n/a class V(c.Sized, S):
1973n/a def __len__(self):
1974n/a return 0
1975n/a @functools.singledispatch
1976n/a def j(arg):
1977n/a return "base"
1978n/a @j.register(S)
1979n/a def _(arg):
1980n/a return "s"
1981n/a @j.register(c.Container)
1982n/a def _(arg):
1983n/a return "container"
1984n/a v = V()
1985n/a self.assertEqual(j(v), "s")
1986n/a c.Container.register(V)
1987n/a self.assertEqual(j(v), "container") # because it ends up right after
1988n/a # Sized in the MRO
1989n/a
1990n/a def test_cache_invalidation(self):
1991n/a from collections import UserDict
1992n/a class TracingDict(UserDict):
1993n/a def __init__(self, *args, **kwargs):
1994n/a super(TracingDict, self).__init__(*args, **kwargs)
1995n/a self.set_ops = []
1996n/a self.get_ops = []
1997n/a def __getitem__(self, key):
1998n/a result = self.data[key]
1999n/a self.get_ops.append(key)
2000n/a return result
2001n/a def __setitem__(self, key, value):
2002n/a self.set_ops.append(key)
2003n/a self.data[key] = value
2004n/a def clear(self):
2005n/a self.data.clear()
2006n/a _orig_wkd = functools.WeakKeyDictionary
2007n/a td = TracingDict()
2008n/a functools.WeakKeyDictionary = lambda: td
2009n/a c = collections
2010n/a @functools.singledispatch
2011n/a def g(arg):
2012n/a return "base"
2013n/a d = {}
2014n/a l = []
2015n/a self.assertEqual(len(td), 0)
2016n/a self.assertEqual(g(d), "base")
2017n/a self.assertEqual(len(td), 1)
2018n/a self.assertEqual(td.get_ops, [])
2019n/a self.assertEqual(td.set_ops, [dict])
2020n/a self.assertEqual(td.data[dict], g.registry[object])
2021n/a self.assertEqual(g(l), "base")
2022n/a self.assertEqual(len(td), 2)
2023n/a self.assertEqual(td.get_ops, [])
2024n/a self.assertEqual(td.set_ops, [dict, list])
2025n/a self.assertEqual(td.data[dict], g.registry[object])
2026n/a self.assertEqual(td.data[list], g.registry[object])
2027n/a self.assertEqual(td.data[dict], td.data[list])
2028n/a self.assertEqual(g(l), "base")
2029n/a self.assertEqual(g(d), "base")
2030n/a self.assertEqual(td.get_ops, [list, dict])
2031n/a self.assertEqual(td.set_ops, [dict, list])
2032n/a g.register(list, lambda arg: "list")
2033n/a self.assertEqual(td.get_ops, [list, dict])
2034n/a self.assertEqual(len(td), 0)
2035n/a self.assertEqual(g(d), "base")
2036n/a self.assertEqual(len(td), 1)
2037n/a self.assertEqual(td.get_ops, [list, dict])
2038n/a self.assertEqual(td.set_ops, [dict, list, dict])
2039n/a self.assertEqual(td.data[dict],
2040n/a functools._find_impl(dict, g.registry))
2041n/a self.assertEqual(g(l), "list")
2042n/a self.assertEqual(len(td), 2)
2043n/a self.assertEqual(td.get_ops, [list, dict])
2044n/a self.assertEqual(td.set_ops, [dict, list, dict, list])
2045n/a self.assertEqual(td.data[list],
2046n/a functools._find_impl(list, g.registry))
2047n/a class X:
2048n/a pass
2049n/a c.MutableMapping.register(X) # Will not invalidate the cache,
2050n/a # not using ABCs yet.
2051n/a self.assertEqual(g(d), "base")
2052n/a self.assertEqual(g(l), "list")
2053n/a self.assertEqual(td.get_ops, [list, dict, dict, list])
2054n/a self.assertEqual(td.set_ops, [dict, list, dict, list])
2055n/a g.register(c.Sized, lambda arg: "sized")
2056n/a self.assertEqual(len(td), 0)
2057n/a self.assertEqual(g(d), "sized")
2058n/a self.assertEqual(len(td), 1)
2059n/a self.assertEqual(td.get_ops, [list, dict, dict, list])
2060n/a self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
2061n/a self.assertEqual(g(l), "list")
2062n/a self.assertEqual(len(td), 2)
2063n/a self.assertEqual(td.get_ops, [list, dict, dict, list])
2064n/a self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2065n/a self.assertEqual(g(l), "list")
2066n/a self.assertEqual(g(d), "sized")
2067n/a self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
2068n/a self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2069n/a g.dispatch(list)
2070n/a g.dispatch(dict)
2071n/a self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
2072n/a list, dict])
2073n/a self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2074n/a c.MutableSet.register(X) # Will invalidate the cache.
2075n/a self.assertEqual(len(td), 2) # Stale cache.
2076n/a self.assertEqual(g(l), "list")
2077n/a self.assertEqual(len(td), 1)
2078n/a g.register(c.MutableMapping, lambda arg: "mutablemapping")
2079n/a self.assertEqual(len(td), 0)
2080n/a self.assertEqual(g(d), "mutablemapping")
2081n/a self.assertEqual(len(td), 1)
2082n/a self.assertEqual(g(l), "list")
2083n/a self.assertEqual(len(td), 2)
2084n/a g.register(dict, lambda arg: "dict")
2085n/a self.assertEqual(g(d), "dict")
2086n/a self.assertEqual(g(l), "list")
2087n/a g._clear_cache()
2088n/a self.assertEqual(len(td), 0)
2089n/a functools.WeakKeyDictionary = _orig_wkd
2090n/a
2091n/a
# Run this module's tests when it is executed directly as a script.
if '__main__' == __name__:
    unittest.main()