From 8f6257b37f071d670e4d3750afd1d7b63ddf5149 Mon Sep 17 00:00:00 2001 From: Roel Bruggink Date: Sun, 10 Mar 2019 03:11:28 +0100 Subject: [PATCH] Make open a context manager --- tests/snippets/builtin_open.py | 4 ++++ vm/src/stdlib/io.rs | 25 ++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/snippets/builtin_open.py b/tests/snippets/builtin_open.py index 55eafcd6ce..41bcf2b800 100644 --- a/tests/snippets/builtin_open.py +++ b/tests/snippets/builtin_open.py @@ -4,3 +4,7 @@ assert 'RustPython' in fd.read() assert_raises(FileNotFoundError, lambda: open('DoesNotExist')) + +# Use open as a context manager +with open('README.md') as fp: + fp.read() diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index f35d5eb041..9d5600498a 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -60,6 +60,26 @@ fn bytes_io_getvalue(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.get_none()) } +fn io_base_cm_enter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(instance, None)]); + Ok(instance.clone()) +} + +fn io_base_cm_exit(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + // The context manager protocol requires these, but we don't use them + required = [ + (_instance, None), + (_exception_type, None), + (_exception_value, None), + (_traceback, None) + ] + ); + Ok(vm.get_none()) +} + fn buffered_io_base_init(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(buffered, None), (raw, None)]); vm.ctx.set_attr(&buffered, "raw", raw.clone()); @@ -347,7 +367,10 @@ pub fn io_open(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn mk_module(ctx: &PyContext) -> PyObjectRef { //IOBase the abstract base class of the IO Module - let io_base = py_class!(ctx, "IOBase", ctx.object(), {}); + let io_base = py_class!(ctx, "IOBase", ctx.object(), { + "__enter__" => ctx.new_rustfunc(io_base_cm_enter), + "__exit__" => ctx.new_rustfunc(io_base_cm_exit) + }); // IOBase Subclasses let raw_io_base = py_class!(ctx, "RawIOBase", ctx.object(), {});