» Core Development > Code coverage > Lib/test/test_socketserver.py

Python code coverage for Lib/test/test_socketserver.py

# | count | content
"""
Test suite for socketserver: the TCP, UDP and Unix-domain servers with
their threading and forking mixins, handler error routing, and the
writer exposed by StreamRequestHandler.
"""
4n/a
5n/aimport contextlib
6n/aimport io
7n/aimport os
8n/aimport select
9n/aimport signal
10n/aimport socket
11n/aimport tempfile
12n/aimport unittest
13n/aimport socketserver
14n/a
15n/aimport test.support
16n/afrom test.support import reap_children, reap_threads, verbose
17n/atry:
18n/a import threading
19n/aexcept ImportError:
20n/a threading = None
21n/a
# The tests below open real sockets, so declare the "network" resource
# requirement before anything else is set up.
test.support.requires("network")

TEST_STR = b"hello world\n"
HOST = test.support.HOST

# Platform feature probes, followed by the matching skip decorators.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = hasattr(os, "fork")
requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, 'requires Unix sockets')
requires_forking = unittest.skipUnless(
    HAVE_FORKING, 'requires forking')
32n/a
def signal_alarm(n):
    """Forward *n* to signal.alarm() when the platform provides it.

    On platforms without signal.alarm (e.g. Windows) this does nothing.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
37n/a
# Keep a reference to the genuine select() so that tests which mock
# select.select do not interfere with receive().
_real_select = select.select


def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for *sock* to become readable and
    raises RuntimeError if it does not.
    """
    readable, _writable, _exceptional = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
47n/a
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        """socketserver.UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        """socketserver.UnixDatagramServer combined with ForkingMixIn."""
56n/a
57n/a
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Fork a child that exits immediately with status 72, run the body.

    After the with-body completes, the child is reaped and *testcase*
    checks that it was still ours to collect with the expected status.
    This guards forking-server tests against a server that waits on a
    child process it did not create itself (Issue 1540386): if the server
    had reaped it, the waitpid() below could not collect it.

    The child is reaped even when the with-body raises, so a failing test
    does not leave a zombie behind; the status checks only run on success.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    # waitpid() reports the exit code in the high byte of the status word.
    testcase.assertEqual(72 << 8, status)
69n/a
70n/a
@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Best-effort removal of the AF_UNIX socket paths handed out by
        # pickaddr(); a path may not exist if nothing was ever bound to it.
        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address a server of family *proto* can bind to.

        AF_INET gets (HOST, 0) so the OS chooses a free port.  Any other
        family receives a fresh temporary path, which is remembered so
        tearDown() can delete it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.  mktemp() only produces a name
        # and does not create the file, which is what bind() needs here.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create an echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) reads one line and writes it
        back.  handle_error() closes the request and re-raises the
        exception currently being handled, so handler failures are not
        silently logged away.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        server = MyServer(addr, MyHandler)
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a background thread and call *testfunc*
        (family, address) three times against it.

        The server is always shut down and closed, even if *testfunc*
        fails, so a failing client check cannot leak a serving thread or
        the listening socket into later tests.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose: print("server running")
        try:
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
        finally:
            if verbose: print("waiting for server")
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and check it is echoed."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and check it is echoed."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # An unbound AF_UNIX datagram socket has no address the
                # server could reply to, so give it one.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval': 0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
286n/a
287n/a
class ErrorHandlerTest(unittest.TestCase):
    """Check how servers route exceptions raised by a request handler.

    Ordinary exceptions must be passed to handle_error(); exiting
    exceptions like SystemExit and KeyboardInterrupt must not be.
    """

    def tearDown(self):
        test.support.unlink(test.support.TESTFN)

    def check_result(self, handled):
        """Assert the log shows the handler ran, followed by a
        handle_error() entry exactly when *handled* is true."""
        expected = 'Handler called\n'
        if handled:
            expected += 'Error handled\n'
        with open(test.support.TESTFN) as log_file:
            contents = log_file.read()
        self.assertEqual(contents, expected)

    def _exercise(self, server_class, exception, handled):
        # Constructing the server runs one request through BadHandler.
        server_class(exception)
        self.check_result(handled=handled)

    def test_sync_handled(self):
        self._exercise(BaseErrorTestServer, ValueError, True)

    def test_sync_not_handled(self):
        # A synchronous server lets the exiting exception escape to us.
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_handled(self):
        self._exercise(ThreadingErrorTestServer, ValueError, True)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_not_handled(self):
        self._exercise(ThreadingErrorTestServer, SystemExit, False)

    @requires_forking
    def test_forking_handled(self):
        self._exercise(ForkingErrorTestServer, ValueError, True)

    @requires_forking
    def test_forking_not_handled(self):
        self._exercise(ForkingErrorTestServer, SystemExit, False)
329n/a
330n/a
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that processes exactly one request while being built.

    The request goes to BadHandler, which raises the *exception* class
    stored on this server; handle_error() appends a line to the log file.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # Open and immediately close a client connection so that
        # handle_request() has a pending connection to accept.
        with socket.create_connection(self.server_address):
            pass
        try:
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(test.support.TESTFN, 'a') as log_file:
            log_file.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses that process the request asynchronously;
        the synchronous server has nothing to wait for."""
349n/a
350n/a
class BadHandler(socketserver.BaseRequestHandler):
    """Handler that logs its invocation and then raises the exception
    class found on ``self.server.exception``."""

    def handle(self):
        with open(test.support.TESTFN, 'a') as log_file:
            log_file.write('Handler called\n')
        exc_type = self.server.exception
        raise exc_type('Test error')
356n/a
357n/a
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant whose wait_done() blocks until shutdown_request()
    has completed for the single request."""

    def __init__(self, *args, **kwargs):
        # Must exist before the base constructor runs, because that
        # constructor handles the request and then calls wait_done().
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        self.done.set()

    def wait_done(self):
        self.done.wait()
370n/a
371n/a
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn,
                                 BaseErrorTestServer):
        """Forking variant whose wait_done() reaps the one child that
        handled the request."""

        def wait_done(self):
            # Exactly one child is expected; unpacking fails otherwise.
            (child,) = self.active_children
            os.waitpid(child, 0)
            self.active_children.clear()
378n/a
379n/a
class SocketWriterTest(unittest.TestCase):
    """Tests for the ``wfile`` object StreamRequestHandler exposes."""

    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.wfile = self.wfile
                srv.wfile_fileno = self.wfile.fileno()
                srv.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        client = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with client:
            client.connect(server.server_address)
            server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                srv.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                srv.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        previous_handler = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, previous_handler)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            sock = socket.socket(server.address_family, socket.SOCK_STREAM,
                                 socket.IPPROTO_TCP)
            with sock, sock.makefile('rb') as reader:
                sock.connect(server.server_address)
                response1 = reader.readline()
                sock.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=1.0):
                        break
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
456n/a
457n/a
class MiscTestCase(unittest.TestCase):
    """Module-level API checks and small BaseServer regressions."""

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if name.startswith('_'):
                continue
            mod_object = getattr(socketserver, name)
            if getattr(mod_object, '__module__', None) == 'socketserver':
                expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the listening socket even if the assertion below fails.
        self.addCleanup(server.server_close)
        # Connect and disconnect so handle_request() has a pending
        # connection to accept and then reject via verify_request().
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)
490n/a
491n/a
if __name__ == "__main__":
    # Run this test module directly with the unittest runner.
    unittest.main()