diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 3b4026cb73869a..2eac64c8165fcc 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -145,6 +145,27 @@ def test_ipaddr_info_no_inet_pton(self, m_socket): socket.SOCK_STREAM, socket.IPPROTO_TCP)) + def test_interleave_addrinfos(self): + SIX_A = (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)) + SIX_B = (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)) + SIX_C = (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)) + SIX_D = (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)) + FOUR_A = (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)) + FOUR_B = (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)) + FOUR_C = (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)) + FOUR_D = (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)) + + addrinfos = [SIX_A, SIX_B, SIX_C, SIX_D, FOUR_A, FOUR_B, FOUR_C, FOUR_D] + expected = [SIX_A, FOUR_A, SIX_B, FOUR_B, SIX_C, FOUR_C, SIX_D, FOUR_D] + + self.assertEqual(expected, base_events._interleave_addrinfos(addrinfos)) + + expected_fafc_2 = [SIX_A, SIX_B, FOUR_A, SIX_C, FOUR_B, SIX_D, FOUR_C, FOUR_D] + self.assertEqual( + expected_fafc_2, + base_events._interleave_addrinfos(addrinfos, first_address_family_count=2), + ) + class BaseEventLoopTests(test_utils.TestCase): @@ -1426,6 +1447,65 @@ def getaddrinfo_task(*args, **kwds): self.assertRaises( OSError, self.loop.run_until_complete, coro) + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'no IPv6 support') + @patch_socket + def test_create_connection_happy_eyeballs(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + async def getaddrinfo(*args, **kw): + return [(socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), + (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5))] + + async def sock_connect(sock, address): + if address[0] == '2001:db8::1': + await asyncio.sleep(1) + sock.connect(address) + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = sock_connect + + coro = self.loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3) + transport, protocol = self.loop.run_until_complete(coro) + try: + sock = transport._sock + sock.connect.assert_called_with(('192.0.2.1', 5)) + finally: + transport.close() + test_utils.run_briefly(self.loop) # allow transport to close + + @patch_socket + def test_create_connection_happy_eyeballs_ipv4_only(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + async def getaddrinfo(*args, **kw): + return [(socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), + (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6))] + + async def sock_connect(sock, address): + if address[0] == '192.0.2.1': + await asyncio.sleep(1) + sock.connect(address) + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = sock_connect + + coro = self.loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3) + transport, protocol = self.loop.run_until_complete(coro) + try: + sock = transport._sock + sock.connect.assert_called_with(('192.0.2.2', 6)) + finally: + transport.close() + test_utils.run_briefly(self.loop) # allow transport to close + @patch_socket def test_create_connection_bluetooth(self, m_socket): # See http://bugs.python.org/issue27136, fallback to getaddrinfo when diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py new file mode 100644 index 00000000000000..775f6f0901fa59 --- /dev/null +++ b/Lib/test/test_asyncio/test_staggered.py @@ -0,0 +1,115 @@ +import asyncio +import functools +import unittest +from asyncio.staggered import staggered_race + + +# To prevent a warning "test altered the execution environment" +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class TestStaggered(unittest.IsolatedAsyncioTestCase): + @staticmethod + async def waiting_coroutine(return_value, wait_seconds, success): + await asyncio.sleep(wait_seconds) + if success: + return return_value + raise RuntimeError(str(return_value)) + + def get_waiting_coroutine_factory(self, return_value, wait_seconds, success): + return functools.partial(self.waiting_coroutine, return_value, wait_seconds, success) + + async def test_single_success(self): + winner_result, winner_idx, exceptions = await staggered_race( + (self.get_waiting_coroutine_factory(0, 0.1, True),), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 1) + self.assertIsNone(exceptions[0]) + + async def test_single_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + (self.get_waiting_coroutine_factory(0, 0.1, False),), + 0.1, + ) + self.assertEqual(winner_result, None) + self.assertEqual(winner_idx, None) + self.assertEqual(len(exceptions), 1) + self.assertIsInstance(exceptions[0], RuntimeError) + + async def test_first_win(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, True), + self.get_waiting_coroutine_factory(1, 0.2, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 2) + self.assertIsNone(exceptions[0]) + self.assertIsInstance(exceptions[1], asyncio.CancelledError) + + async def test_second_win(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.3, True), + self.get_waiting_coroutine_factory(1, 0.1, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 1) + self.assertEqual(winner_idx, 1) + self.assertEqual(len(exceptions), 2) + self.assertIsInstance(exceptions[0], asyncio.CancelledError) + self.assertIsNone(exceptions[1]) + + async def test_first_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, False), + self.get_waiting_coroutine_factory(1, 0.2, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 1) + self.assertEqual(winner_idx, 1) + self.assertEqual(len(exceptions), 2) + self.assertIsInstance(exceptions[0], RuntimeError) + self.assertIsNone(exceptions[1]) + + async def test_second_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, True), + self.get_waiting_coroutine_factory(1, 0, False), + ), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 2) + self.assertIsNone(exceptions[0]) + self.assertIsInstance(exceptions[1], RuntimeError) + + async def test_simultaneous_success_fail(self): + # There's a potential race condition here: + # https://github.com/python/cpython/issues/86296 + # As with any race condition, it can be difficult to reproduce. + # This test may not fail every time. + for i in range(201): + time_unit = 0.0001 * i + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, time_unit*2, True), + self.get_waiting_coroutine_factory(1, time_unit, False), + self.get_waiting_coroutine_factory(2, 0.05, True) + ), + time_unit, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: