diff --git a/asyncio/events.py b/asyncio/events.py index 176a8466..7b600431 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -457,9 +457,15 @@ def remove_writer(self, fd): def sock_recv(self, sock, nbytes): raise NotImplementedError + def sock_recvfrom(self, sock, nbytes): + raise NotImplementedError + def sock_sendall(self, sock, data): raise NotImplementedError + def sock_sendto(self, sock, data, address): + raise NotImplementedError + def sock_connect(self, sock, address): raise NotImplementedError diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 812fac19..229579f1 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -315,10 +315,27 @@ def sock_recv(self, sock, n): if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) - self._sock_recv(fut, False, sock, n) + self._sock_recv(sock.recv, fut, False, sock, n) return fut - def _sock_recv(self, fut, registered, sock, n): + def sock_recvfrom(self, sock, n): + """Receive data from the socket. + + Receive data from the socket. The return value is a pair (bytes, address) + where bytes is a bytes object representing the data received and address + is the address of the socket sending the data. + The maximum amount of data to be received at once is specified by + nbytes. + + This method is a coroutine. + """ + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + self._sock_recv(sock.recvfrom, fut, False, sock, n) + return fut + + def _sock_recv(self, method, fut, registered, sock, n): # _sock_recv() can add itself as an I/O callback if the operation can't # be done immediately. Don't use it directly, call sock_recv(). fd = sock.fileno() @@ -331,9 +348,9 @@ def _sock_recv(self, fut, registered, sock, n): if fut.cancelled(): return try: - data = sock.recv(n) + data = method(n) except (BlockingIOError, InterruptedError): - self.add_reader(fd, self._sock_recv, fut, True, sock, n) + self.add_reader(fd, self._sock_recv, method, fut, True, sock, n) except Exception as exc: fut.set_exception(exc) else: @@ -382,6 +399,42 @@ def _sock_sendall(self, fut, registered, sock, data): data = data[n:] self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + def sock_sendto(self, sock, data, address): + """Send data to the socket. + + The socket should not be connected to a remote socket, since the + destination socket is specified by address. + + This method is a coroutine. + """ + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + if data: + self._sock_sendto(fut, False, sock, data, address) + else: + fut.set_result(None) + return fut + + def _sock_sendto(self, fut, registered, sock, data, address): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + sock.sendto(data, address) + except (BlockingIOError, InterruptedError): + # if sendto blocks, wait for it to be ready + self.add_writer(fd, self._sock_sendto, fut, True, sock, data, address) + except Exception as exc: + fut.set_exception(exc) + return + else: + fut.set_result(None) + def sock_connect(self, sock, address): """Connect to a remote socket at address. diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 135b5abf..9bc03c8a 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -183,7 +183,15 @@ def test_sock_recv(self): f = self.loop.sock_recv(sock, 1024) self.assertIsInstance(f, asyncio.Future) - self.loop._sock_recv.assert_called_with(f, False, sock, 1024) + self.loop._sock_recv.assert_called_with(sock.recv, f, False, sock, 1024) + + def test_sock_recfrom(self): + sock = test_utils.mock_nonblocking_socket() + self.loop._sock_recv = mock.Mock() + + f = self.loop.sock_recvfrom(sock, 1024) + self.assertIsInstance(f, asyncio.Future) + self.loop._sock_recv.assert_called_with(sock.recvfrom, f, False, sock, 1024) def test__sock_recv_canceled_fut(self): sock = mock.Mock() @@ -191,7 +199,7 @@ def test__sock_recv_canceled_fut(self): f = asyncio.Future(loop=self.loop) f.cancel() - self.loop._sock_recv(f, False, sock, 1024) + self.loop._sock_recv(sock.recv, f, False, sock, 1024) self.assertFalse(sock.recv.called) def test__sock_recv_unregister(self): @@ -202,7 +210,7 @@ def test__sock_recv_unregister(self): f.cancel() self.loop.remove_reader = mock.Mock() - self.loop._sock_recv(f, True, sock, 1024) + self.loop._sock_recv(sock.recv, f, True, sock, 1024) self.assertEqual((10,), self.loop.remove_reader.call_args[0]) def test__sock_recv_tryagain(self): @@ -212,8 +220,8 @@ def test__sock_recv_tryagain(self): sock.recv.side_effect = BlockingIOError self.loop.add_reader = mock.Mock() - self.loop._sock_recv(f, False, sock, 1024) - self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024), + self.loop._sock_recv(sock.recv, f, False, sock, 1024) + self.assertEqual((10, self.loop._sock_recv, sock.recv, f, True, sock, 1024), self.loop.add_reader.call_args[0]) def test__sock_recv_exception(self): @@ -222,7 +230,7 @@ def test__sock_recv_exception(self): sock.fileno.return_value = 10 err = sock.recv.side_effect = OSError() - self.loop._sock_recv(f, False, sock, 1024) + self.loop._sock_recv(sock.recv, f, False, sock, 1024) self.assertIs(err, f.exception()) def test_sock_sendall(self): @@ -337,6 +345,90 @@ def test__sock_sendall_none(self): (10, self.loop._sock_sendall, f, True, sock, b'data'), self.loop.add_writer.call_args[0]) + def test_sock_sendto(self): + sock = test_utils.mock_nonblocking_socket() + self.loop._sock_sendto = mock.Mock() + + f = self.loop.sock_sendto(sock, b'data', ('localhost', 80)) + self.assertIsInstance(f, asyncio.Future) + self.assertEqual( + (f, False, sock, b'data', ('localhost', 80)), + self.loop._sock_sendto.call_args[0]) + + def test_sock_sendto_nodata(self): + sock = test_utils.mock_nonblocking_socket() + self.loop._sock_sendto = mock.Mock() + + f = self.loop.sock_sendto(sock, b'', ('localhost', 80)) + self.assertIsInstance(f, asyncio.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + self.assertFalse(self.loop._sock_sendto.called) + + def test__sock_sendto_canceled_fut(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop._sock_sendto(f, False, sock, b'data', ('localhost', 80)) + self.assertFalse(sock.sendto.called) + + def test__sock_sendto_unregister(self): + sock = mock.Mock() + sock.fileno.return_value = 10 + + f = asyncio.Future(loop=self.loop) + f.cancel() + + self.loop.remove_writer = mock.Mock() + self.loop._sock_sendto(f, True, sock, b'data', ('localhost', 80)) + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) + + def test__sock_sendto_tryagain(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.sendto.side_effect = BlockingIOError + + self.loop.add_writer = mock.Mock() + self.loop._sock_sendto(f, False, sock, b'data', ('localhost', 80)) + self.assertEqual( + (10, self.loop._sock_sendto, f, True, sock, b'data', ('localhost', 80)), + self.loop.add_writer.call_args[0]) + + def test__sock_sendto_interrupted(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.sendto.side_effect = InterruptedError + + self.loop.add_writer = mock.Mock() + self.loop._sock_sendto(f, False, sock, b'data', ('localhost', 80)) + self.assertEqual( + (10, self.loop._sock_sendto, f, True, sock, b'data', ('localhost', 80)), + self.loop.add_writer.call_args[0]) + + def test__sock_sendto_exception(self): + f = asyncio.Future(loop=self.loop) + sock = mock.Mock() + sock.fileno.return_value = 10 + err = sock.sendto.side_effect = OSError() + + self.loop._sock_sendto(f, False, sock, b'data', ('localhost', 80)) + self.assertIs(f.exception(), err) + + def test__sock_sendto(self): + sock = mock.Mock() + + f = asyncio.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.sendto.return_value = 4 + + self.loop._sock_sendto(f, False, sock, b'data', ('localhost', 80)) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + def test_sock_connect(self): sock = test_utils.mock_nonblocking_socket() self.loop._sock_connect = mock.Mock()