Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
__version__ = '1.0'


import io
import inspect
import pprint
import sys
Expand Down Expand Up @@ -2379,20 +2380,27 @@ def mock_open(mock=None, read_data=''):
`read_data` is a string for the `read`, `readline` and `readlines` of the
file handle to return. This is an empty string by default.
"""
if isinstance(read_data, bytes):
_read_data = io.BytesIO(read_data)
else:
_read_data = io.StringIO(read_data)

_state = [_read_data, None]

def _readlines_side_effect(*args, **kwargs):
if handle.readlines.return_value is not None:
return handle.readlines.return_value
return list(_state[0])
return _state[0].readlines(*args, **kwargs)

def _read_side_effect(*args, **kwargs):
if handle.read.return_value is not None:
return handle.read.return_value
return type(read_data)().join(_state[0])
return _state[0].read(*args, **kwargs)

def _readline_side_effect():
def _readline_side_effect(*args, **kwargs):
yield from _iter_side_effect()
while True:
yield type(read_data)()
yield _state[0].readline(*args, **kwargs)

def _iter_side_effect():
if handle.readline.return_value is not None:
Expand All @@ -2412,8 +2420,6 @@ def _iter_side_effect():
handle = MagicMock(spec=file_spec)
handle.__enter__.return_value = handle

_state = [_iterate_read_data(read_data), None]

handle.write.return_value = None
handle.read.return_value = None
handle.readline.return_value = None
Expand All @@ -2426,7 +2432,10 @@ def _iter_side_effect():
handle.__iter__.side_effect = _iter_side_effect

def reset_data(*args, **kwargs):
_state[0] = _iterate_read_data(read_data)
if isinstance(read_data, bytes):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to get a function for this. Something like "_read_data_to_stream" to remove duplication.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks :)

_state[0] = io.BytesIO(read_data)
else:
_state[0] = io.StringIO(read_data)
if handle.readline.side_effect == _state[1]:
# Only reset the side effect if the user hasn't overridden it.
_state[1] = _readline_side_effect()
Expand Down
2 changes: 1 addition & 1 deletion Lib/unittest/test/testmock/testwith.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def test_mock_open_read_with_argument(self):
# for mocks returned by mock_open
some_data = 'foo\nbar\nbaz'
mock = mock_open(read_data=some_data)
self.assertEqual(mock().read(10), some_data)
self.assertEqual(mock().read(10), some_data[:10])


def test_interleaved_reads(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
unittest.mock.mock_open() results now respects the argument of read([size]).
Patch contributed by Rémi Lapeyre.