» Core Development > Code coverage > Lib/asyncio/test_utils.py

Python code coverage for Lib/asyncio/test_utils.py

# | count | content
1n/a"""Utilities shared by tests."""
2n/a
3n/aimport collections
4n/aimport contextlib
5n/aimport io
6n/aimport logging
7n/aimport os
8n/aimport re
9n/aimport socket
10n/aimport socketserver
11n/aimport sys
12n/aimport tempfile
13n/aimport threading
14n/aimport time
15n/aimport unittest
16n/aimport weakref
17n/a
18n/afrom unittest import mock
19n/a
20n/afrom http.server import HTTPServer
21n/afrom wsgiref.simple_server import WSGIRequestHandler, WSGIServer
22n/a
23n/atry:
24n/a import ssl
25n/aexcept ImportError: # pragma: no cover
26n/a ssl = None
27n/a
28n/afrom . import base_events
29n/afrom . import compat
30n/afrom . import events
31n/afrom . import futures
32n/afrom . import selectors
33n/afrom . import tasks
34n/afrom .coroutines import coroutine
35n/afrom .log import logger
36n/a
37n/a
38n/aif sys.platform == 'win32': # pragma: no cover
39n/a from .windows_utils import socketpair
40n/aelse:
41n/a from socket import socketpair # pragma: no cover
42n/a
43n/a
def dummy_ssl_context():
    """Return a fresh SSLContext with nothing loaded, or None without ssl.

    The context is built with ``ssl.PROTOCOL_SSLv23``; when the ``ssl``
    module could not be imported the function returns None instead.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
49n/a
50n/a
def run_briefly(loop):
    """Give *loop* one short spin by running a do-nothing coroutine on it.

    A no-op coroutine is wrapped in a task and driven with
    ``run_until_complete()``. The coroutine object is always closed
    afterwards, even if running the loop raised.
    """
    @coroutine
    def noop():
        pass

    coro_obj = noop()
    task = loop.create_task(coro_obj)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro_obj.close()
64n/a
65n/a
def run_until(loop, pred, timeout=30):
    """Run *loop* in short slices until ``pred()`` returns a true value.

    Each slice runs the loop until a 1 ms ``tasks.sleep`` completes, then
    re-checks *pred*.

    :param loop: event loop to drive.
    :param pred: zero-argument callable polled between slices.
    :param timeout: maximum number of seconds to wait, or None to wait
        forever (default 30).
    :raises futures.TimeoutError: if *timeout* seconds elapse before
        ``pred()`` becomes true.
    """
    # Only compute a deadline when there is a timeout: ``None + float``
    # would raise TypeError. Use the monotonic clock so wall-clock
    # adjustments cannot shorten or extend the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
74n/a
75n/a
def run_once(loop):
    """Drive *loop* through a single iteration (legacy test helper).

    This is the recommended pattern for test code: it polls the selector
    once and runs every callback scheduled in response to I/O events.
    """
    # Queue stop() behind the callbacks that are already pending, so that
    # run_forever() returns once they have been processed.
    stop = loop.stop
    loop.call_soon(stop)
    loop.run_forever()
85n/a
86n/a
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that produces no console output."""

    def get_stderr(self):
        # Hand the WSGI layer a throwaway in-memory buffer as its stderr.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Deliberately discard per-request log lines.
        return None
94n/a
95n/a
class SilentWSGIServer(WSGIServer):
    """WSGIServer that time-limits accepted sockets and hides errors."""

    # Seconds applied via settimeout() to every accepted connection.
    request_timeout = 2

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Deliberately ignore errors raised while handling a request.
        return None
107n/a
108n/a
class SSLWSGIServerMixin:
    """Mixin that serves each accepted connection over server-side TLS.

    ``finish_request`` loads the test key/certificate pair, wraps the
    accepted socket with it and hands the TLS socket to
    ``self.RequestHandlerClass``. The TLS socket is always closed
    afterwards.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Pass the protocol explicitly: SSLContext() with no argument is a
        # TypeError before Python 3.6, and PROTOCOL_SSLv23 is the same
        # value newer versions use as their default.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close the TLS socket on every path. Previously it was only
            # closed on success, leaking it whenever the handler failed.
            try:
                ssock.close()
            except OSError:
                pass
132n/a
133n/a
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI server that wraps every accepted connection in TLS."""
136n/a
137n/a
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that serves a canned WSGI app from a background thread.

    Builds ``server_ssl_cls`` (when *use_ssl*) or ``server_cls`` bound to
    *address*, starts ``serve_forever`` in a new thread and yields the
    server. On resumption or close it shuts the server down, closes it
    and joins the thread.
    """
    def app(environ, start_response):
        # Minimal WSGI application: always "200 OK" with a fixed body.
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    chosen_cls = server_ssl_cls if use_ssl else server_cls
    server = chosen_cls(address, SilentWSGIRequestHandler)
    server.set_app(app)
    # Mirror server_address under the ``address`` attribute for callers.
    server.address = server.server_address

    def serve():
        server.serve_forever(poll_interval=0.05)

    worker = threading.Thread(target=serve)
    worker.start()
    try:
        yield server
    finally:
        server.shutdown()
        server.server_close()
        worker.join()
161n/a
162n/a
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer listening on a UNIX-domain socket path."""

        def server_bind(self):
            # Bind through UnixStreamServer explicitly rather than the
            # HTTPServer implementation (presumably because that one
            # expects a (host, port) address, not a path), then fill in
            # placeholder host/port values.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer variant that accepts connections on a UNIX socket."""

        # Seconds applied via settimeout() to every accepted connection.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects get_request() to return a socket
            # and a (host, port) tuple. For UNIX sockets the second value
            # is a path, so return fake data that is sufficient for the
            # tests instead.
            return conn, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores errors raised by request handling."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server that wraps connections in TLS."""


    def gen_unix_socket_path():
        # Borrow a unique name from NamedTemporaryFile; the file itself is
        # removed when the with-block exits, so only the name survives.
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name


    @contextlib.contextmanager
    def unix_socket_path():
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            # Best-effort cleanup: the path may never have been created.
            try:
                os.unlink(sock_path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        with unix_socket_path() as sock_path:
            yield from _run_test_server(address=sock_path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
226n/a
227n/a
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running the canned WSGI test server over TCP.

    Yields the running server bound to ``(host, port)``; with *use_ssl*
    the TLS-wrapping server class is used instead of the plain one.
    """
    bind_to = (host, port)
    yield from _run_test_server(address=bind_to, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
233n/a
234n/a
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose non-magic attributes are mocks.

    Every name in ``dir(base)`` except ``__dunder__`` names is replaced by a
    ``MockCallback(return_value=None)``.
    """
    def is_magic(attr):
        return attr.startswith('__') and attr.endswith('__')

    namespace = {attr: MockCallback(return_value=None)
                 for attr in dir(base) if not is_magic(attr)}
    bases = (base,) + base.__bases__
    protocol_cls = type('TestProtocol', bases, namespace)
    return protocol_cls()
243n/a
244n/a
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and reports no events."""

    def __init__(self):
        # Registered file object -> its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        entry = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = entry
        return entry

    def unregister(self, fileobj):
        removed = self.keys.pop(fileobj)
        return removed

    def select(self, timeout):
        # Nothing is ever ready.
        return []

    def get_map(self):
        return self.keys
263n/a
264n/a
class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests with a manually driven clock.

    ``time()`` returns an internal counter that starts at 0 and only moves
    forward when ``advance_time()`` is called.

    Each time a handle is scheduled through ``call_at()``, its absolute
    time is recorded. After every loop iteration (``_run_once()``), each
    recorded time is sent into the generator passed to ``__init__``, and
    the value the generator yields back is used to advance the clock.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handle; the value passed to ``yield`` is how far to move the
    loop's time forward. The generator is primed with ``next()`` in
    ``__init__``. When a generator is supplied, ``close()`` sends it a
    final ``0`` and requires it to finish; if it yields again instead,
    ``close()`` raises AssertionError.

    Reader and writer callbacks are kept in the ``readers`` / ``writers``
    dicts (not a real selector) so tests can check them with
    ``assert_reader()`` / ``assert_writer()``.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator so it is suspended at its first yield.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        # Absolute times passed to call_at() during the current iteration.
        self._timers = []
        self._selector = TestSelector()

        # fd -> events.Handle for registered reader / writer callbacks.
        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    @staticmethod
    def _assert_handle(handles, fd, callback, args):
        # Shared check for assert_reader()/assert_writer(). Uses explicit
        # raises instead of ``assert`` so the check still runs when Python
        # is started with -O (which strips assert statements).
        if fd not in handles:
            raise AssertionError('fd {} is not registered'.format(fd))
        handle = handles[fd]
        if not handle._callback == callback:
            raise AssertionError('{!r} != {!r}'.format(
                handle._callback, callback))
        if not handle._args == args:
            raise AssertionError('{!r} != {!r}'.format(
                handle._args, args))

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def _remove_reader(self, fd):
        # Count every removal attempt, even for fds that are not registered.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this reader callback/args."""
        self._assert_handle(self.readers, fd, callback, args)

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def _remove_writer(self, fd):
        # Count every removal attempt, even for fds that are not registered.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this writer callback/args."""
        self._assert_handle(self.writers, fd, callback, args)

    def _ensure_fd_no_transport(self, fd):
        # Refuse to touch an fd that a live transport has claimed.
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        # fd -> number of _remove_reader/_remove_writer calls.
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed each time scheduled during this iteration to the generator
        # and advance the clock by whatever it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # Intentionally a no-op for the test loop.
        return

    def _write_to_self(self):
        # Intentionally a no-op for the test loop.
        pass
408n/a def _process_events(self, event_list):
409n/a return
410n/a
411n/a def _write_to_self(self):
412n/a pass
413n/a
414n/a
def MockCallback(**kwargs):
    """Build a Mock whose spec allows only calling it; kwargs go to Mock."""
    call_only = ['__call__']
    return mock.Mock(spec=call_only, **kwargs)
417n/a
418n/a
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    The pattern is applied with ``re.search`` and ``re.S``, so it may match
    anywhere in the other string and '.' also matches newlines. ``!=`` is
    the exact negation of ``==``. Comparing with a non-str returns
    NotImplemented, so Python falls back to its default (unequal) result
    instead of raising TypeError. Instances are unhashable because
    ``__eq__`` is overridden.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            # re.search() would raise TypeError on bytes/None/ints; let
            # Python's default comparison decide instead.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # Without this override, str.__ne__ compares literally, so p == s
        # and p != s could both be True.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
430n/a
431n/a
def get_function_source(func):
    """Return what ``events._get_function_source`` reports for *func*.

    Raises ValueError when that helper returns None.
    """
    location = events._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
437n/a
438n/a
class TestCase(unittest.TestCase):
    """Base class for asyncio tests.

    setUp() stubs out ``events._get_running_loop`` to always return None and
    tearDown() restores it, clears the global event loop and checks that no
    exception state has leaked.
    """

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the global loop; optionally close *loop* at cleanup time."""
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def setUp(self):
        # Remember the real helper so tearDown() can put it back.
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None

    def tearDown(self):
        events._get_running_loop = self._get_running_loop

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

    if not compat.PY34:
        # Python 3.3 compatibility: unittest there has no subTest(), so
        # provide a context manager that does nothing.
        def subTest(self, *args, **kwargs):
            class EmptyCM:
                def __enter__(self):
                    pass

                def __exit__(self, *exc):
                    pass

            return EmptyCM()
474n/a
475n/a
@contextlib.contextmanager
def disable_logger():
    """Temporarily silence the asyncio logger.

    Useful, for example, to ignore warnings in debug mode. The previous
    level is restored on exit, even if the body raised.
    """
    saved_level = logger.level
    silenced = logging.CRITICAL + 1  # above every standard logging level
    try:
        logger.setLevel(silenced)
        yield
    finally:
        logger.setLevel(saved_level)
488n/a
489n/a
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock is spec'd on ``socket.socket``, carries the given family, type
    and proto, and its ``gettimeout()`` returns 0.0.
    """
    fake = mock.MagicMock(socket.socket)
    fake.family = family
    fake.type = type
    fake.proto = proto
    fake.gettimeout.return_value = 0.0
    return fake
499n/a
500n/a
def force_legacy_ssl_support():
    """Return an (unstarted) patcher that stubs ``_is_sslproto_available``.

    While active, ``asyncio.sslproto._is_sslproto_available`` is replaced by
    a mock that returns False.
    """
    target = 'asyncio.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)