diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py index b7e5784b48..804cce1dba 100644 --- a/Lib/test/test_code.py +++ b/Lib/test/test_code.py @@ -222,8 +222,6 @@ class List(list): obj = List([1, 2, 3]) self.assertEqual(obj[0], "Foreign getitem: 1") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_constructor(self): def func(): pass co = func.__code__ @@ -255,8 +253,6 @@ def test_qualname(self): CodeTest.test_qualname.__qualname__ ) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_replace(self): def func(): x = 1 @@ -297,8 +293,6 @@ def func2(): self.assertEqual(new_code.co_varnames, code2.co_varnames) self.assertEqual(new_code.co_nlocals, code2.co_nlocals) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_nlocals_mismatch(self): def func(): x = 1 diff --git a/Lib/test/test_funcattrs.py b/Lib/test/test_funcattrs.py index 3d5378092b..e06e9f7f4a 100644 --- a/Lib/test/test_funcattrs.py +++ b/Lib/test/test_funcattrs.py @@ -65,8 +65,6 @@ def duplicate(): return 3 self.assertNotEqual(self.b, duplicate) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_copying___code__(self): def test(): pass self.assertEqual(test(), None) diff --git a/compiler/core/src/marshal.rs b/compiler/core/src/marshal.rs index 5e16e59102..f803317772 100644 --- a/compiler/core/src/marshal.rs +++ b/compiler/core/src/marshal.rs @@ -165,6 +165,19 @@ impl<'a> ReadBorrowed<'a> for &'a [u8] { } } +/// Parses bytecode bytes into CodeUnit instructions. +/// Each instruction is 2 bytes: opcode and argument. +pub fn parse_instructions_from_bytes(bytes: &[u8]) -> Result> { + bytes + .chunks_exact(2) + .map(|cu| { + let op = Instruction::try_from(cu[0])?; + let arg = OpArgByte(cu[1]); + Ok(CodeUnit { op, arg }) + }) + .collect() +} + pub struct Cursor { pub data: B, pub position: usize, @@ -185,14 +198,7 @@ pub fn deserialize_code( ) -> Result> { let len = rdr.read_u32()?; let instructions = rdr.read_slice(len * 2)?; - let instructions = instructions - .chunks_exact(2) - .map(|cu| { - let op = Instruction::try_from(cu[0])?; - let arg = OpArgByte(cu[1]); - Ok(CodeUnit { op, arg }) - }) - .collect::>>()?; + let instructions = parse_instructions_from_bytes(instructions)?; let len = rdr.read_u32()?; let locations = (0..len) diff --git a/vm/src/builtins/code.rs b/vm/src/builtins/code.rs index 79ad896aaf..2a22993a9e 100644 --- a/vm/src/builtins/code.rs +++ b/vm/src/builtins/code.rs @@ -2,21 +2,24 @@ */ -use super::{PyStrRef, PyTupleRef, PyType, PyTypeRef}; +use super::{PyBytesRef, PyStrRef, PyTupleRef, PyType, PyTypeRef}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::PyStrInterned, - bytecode::{self, AsBag, BorrowedConstant, CodeFlags, Constant, ConstantBag}, + bytecode::{self, AsBag, BorrowedConstant, CodeFlags, CodeUnit, Constant, ConstantBag}, class::{PyClassImpl, StaticType}, convert::ToPyObject, frozen, - function::{FuncArgs, OptionalArg}, - types::Representable, + function::OptionalArg, + types::{Constructor, Representable}, }; use malachite_bigint::BigInt; use num_traits::Zero; -use rustpython_compiler_core::OneIndexed; -use rustpython_compiler_core::bytecode::PyCodeLocationInfoKind; +use rustpython_compiler_core::{ + OneIndexed, + bytecode::PyCodeLocationInfoKind, + marshal::{MarshalError, parse_instructions_from_bytes}, +}; use std::{borrow::Borrow, fmt, ops::Deref}; /// State for iterating through code address ranges @@ -367,13 +370,158 @@ impl Representable for PyCode { } } -#[pyclass(with(Representable))] -impl PyCode { - #[pyslot] - fn slot_new(_cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("Cannot directly create code object")) +// Arguments for code object constructor +#[derive(FromArgs)] +pub struct PyCodeNewArgs { + argcount: u32, + posonlyargcount: u32, + kwonlyargcount: u32, + nlocals: u32, + stacksize: u32, + flags: u16, + co_code: PyBytesRef, + consts: PyTupleRef, + names: PyTupleRef, + varnames: PyTupleRef, + filename: PyStrRef, + name: PyStrRef, + qualname: PyStrRef, + firstlineno: i32, + linetable: PyBytesRef, + exceptiontable: PyBytesRef, + freevars: PyTupleRef, + cellvars: PyTupleRef, +} + +impl Constructor for PyCode { + type Args = PyCodeNewArgs; + + fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult { + // Convert names tuple to vector of interned strings + let names: Box<[&'static PyStrInterned]> = args + .names + .iter() + .map(|obj| { + let s = obj.downcast_ref::().ok_or_else(|| { + vm.new_type_error("names must be tuple of strings".to_owned()) + })?; + Ok(vm.ctx.intern_str(s.as_str())) + }) + .collect::>>()? + .into_boxed_slice(); + + let varnames: Box<[&'static PyStrInterned]> = args + .varnames + .iter() + .map(|obj| { + let s = obj.downcast_ref::().ok_or_else(|| { + vm.new_type_error("varnames must be tuple of strings".to_owned()) + })?; + Ok(vm.ctx.intern_str(s.as_str())) + }) + .collect::>>()? + .into_boxed_slice(); + + let cellvars: Box<[&'static PyStrInterned]> = args + .cellvars + .iter() + .map(|obj| { + let s = obj.downcast_ref::().ok_or_else(|| { + vm.new_type_error("cellvars must be tuple of strings".to_owned()) + })?; + Ok(vm.ctx.intern_str(s.as_str())) + }) + .collect::>>()? + .into_boxed_slice(); + + let freevars: Box<[&'static PyStrInterned]> = args + .freevars + .iter() + .map(|obj| { + let s = obj.downcast_ref::().ok_or_else(|| { + vm.new_type_error("freevars must be tuple of strings".to_owned()) + })?; + Ok(vm.ctx.intern_str(s.as_str())) + }) + .collect::>>()? + .into_boxed_slice(); + + // Check nlocals matches varnames length + if args.nlocals as usize != varnames.len() { + return Err(vm.new_value_error(format!( + "nlocals ({}) != len(varnames) ({})", + args.nlocals, + varnames.len() + ))); + } + + // Parse and validate bytecode from bytes + let bytecode_bytes = args.co_code.as_bytes(); + let instructions = parse_bytecode(bytecode_bytes) + .map_err(|e| vm.new_value_error(format!("invalid bytecode: {}", e)))?; + + // Convert constants + let constants: Box<[Literal]> = args + .consts + .iter() + .map(|obj| { + // Convert PyObject to Literal constant + // For now, just wrap it + Literal(obj.clone()) + }) + .collect::>() + .into_boxed_slice(); + + // Create locations + let row = if args.firstlineno > 0 { + OneIndexed::new(args.firstlineno as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + }; + let locations: Box<[rustpython_compiler_core::SourceLocation]> = vec![ + rustpython_compiler_core::SourceLocation { + line: row, + character_offset: OneIndexed::from_zero_indexed(0), + }; + instructions.len() + ] + .into_boxed_slice(); + + // Build the CodeObject + let code = CodeObject { + instructions, + locations, + flags: CodeFlags::from_bits_truncate(args.flags), + posonlyarg_count: args.posonlyargcount, + arg_count: args.argcount, + kwonlyarg_count: args.kwonlyargcount, + source_path: vm.ctx.intern_str(args.filename.as_str()), + first_line_number: if args.firstlineno > 0 { + OneIndexed::new(args.firstlineno as usize) + } else { + None + }, + max_stackdepth: args.stacksize, + obj_name: vm.ctx.intern_str(args.name.as_str()), + qualname: vm.ctx.intern_str(args.qualname.as_str()), + cell2arg: None, // TODO: reuse `fn cell2arg` + constants, + names, + varnames, + cellvars, + freevars, + linetable: args.linetable.as_bytes().to_vec().into_boxed_slice(), + exceptiontable: args.exceptiontable.as_bytes().to_vec().into_boxed_slice(), + }; + + Ok(PyCode::new(code) + .into_ref_with_type(vm, cls)? + .to_pyobject(vm)) } +} +#[pyclass(with(Representable, Constructor))] +impl PyCode { #[pygetset] const fn co_posonlyargcount(&self) -> usize { self.code.posonlyarg_count as usize @@ -397,9 +545,7 @@ impl PyCode { #[pygetset] pub fn co_cellvars(&self, vm: &VirtualMachine) -> PyTupleRef { let cellvars = self - .code .cellvars - .deref() .iter() .map(|name| name.to_pyobject(vm)) .collect(); @@ -408,7 +554,7 @@ impl PyCode { #[pygetset] fn co_nlocals(&self) -> usize { - self.varnames.len() + self.code.varnames.len() } #[pygetset] @@ -690,42 +836,62 @@ impl PyCode { #[pymethod] pub fn replace(&self, args: ReplaceArgs, vm: &VirtualMachine) -> PyResult { - let posonlyarg_count = match args.co_posonlyargcount { + let ReplaceArgs { + co_posonlyargcount, + co_argcount, + co_kwonlyargcount, + co_filename, + co_firstlineno, + co_consts, + co_name, + co_names, + co_flags, + co_varnames, + co_nlocals, + co_stacksize, + co_code, + co_linetable, + co_exceptiontable, + co_freevars, + co_cellvars, + co_qualname, + } = args; + let posonlyarg_count = match co_posonlyargcount { OptionalArg::Present(posonlyarg_count) => posonlyarg_count, OptionalArg::Missing => self.code.posonlyarg_count, }; - let arg_count = match args.co_argcount { + let arg_count = match co_argcount { OptionalArg::Present(arg_count) => arg_count, OptionalArg::Missing => self.code.arg_count, }; - let source_path = match args.co_filename { + let source_path = match co_filename { OptionalArg::Present(source_path) => source_path, OptionalArg::Missing => self.code.source_path.to_owned(), }; - let first_line_number = match args.co_firstlineno { + let first_line_number = match co_firstlineno { OptionalArg::Present(first_line_number) => OneIndexed::new(first_line_number as _), OptionalArg::Missing => self.code.first_line_number, }; - let kwonlyarg_count = match args.co_kwonlyargcount { + let kwonlyarg_count = match co_kwonlyargcount { OptionalArg::Present(kwonlyarg_count) => kwonlyarg_count, OptionalArg::Missing => self.code.kwonlyarg_count, }; - let constants = match args.co_consts { + let constants = match co_consts { OptionalArg::Present(constants) => constants, OptionalArg::Missing => self.code.constants.iter().map(|x| x.0.clone()).collect(), }; - let obj_name = match args.co_name { + let obj_name = match co_name { OptionalArg::Present(obj_name) => obj_name, OptionalArg::Missing => self.code.obj_name.to_owned(), }; - let names = match args.co_names { + let names = match co_names { OptionalArg::Present(names) => names, OptionalArg::Missing => self .code @@ -736,37 +902,36 @@ impl PyCode { .collect(), }; - let flags = match args.co_flags { + let flags = match co_flags { OptionalArg::Present(flags) => flags, OptionalArg::Missing => self.code.flags.bits(), }; - let varnames = match args.co_varnames { + let varnames = match co_varnames { OptionalArg::Present(varnames) => varnames, OptionalArg::Missing => self.code.varnames.iter().map(|s| s.to_object()).collect(), }; - let qualname = match args.co_qualname { + let qualname = match co_qualname { OptionalArg::Present(qualname) => qualname, OptionalArg::Missing => self.code.qualname.to_owned(), }; - let max_stackdepth = match args.co_stacksize { + let max_stackdepth = match co_stacksize { OptionalArg::Present(stacksize) => stacksize, OptionalArg::Missing => self.code.max_stackdepth, }; - let instructions = match args.co_code { - OptionalArg::Present(_code_bytes) => { - // Convert bytes back to instructions - // For now, keep the original instructions - // TODO: Properly parse bytecode from bytes - self.code.instructions.clone() + let instructions = match co_code { + OptionalArg::Present(code_bytes) => { + // Parse and validate bytecode from bytes + parse_bytecode(code_bytes.as_bytes()) + .map_err(|e| vm.new_value_error(format!("invalid bytecode: {}", e)))? } OptionalArg::Missing => self.code.instructions.clone(), }; - let cellvars = match args.co_cellvars { + let cellvars = match co_cellvars { OptionalArg::Present(cellvars) => cellvars .into_iter() .map(|o| o.as_interned_str(vm).unwrap()) @@ -774,7 +939,7 @@ impl PyCode { OptionalArg::Missing => self.code.cellvars.clone(), }; - let freevars = match args.co_freevars { + let freevars = match co_freevars { OptionalArg::Present(freevars) => freevars .into_iter() .map(|o| o.as_interned_str(vm).unwrap()) @@ -783,7 +948,7 @@ impl PyCode { }; // Validate co_nlocals if provided - if let OptionalArg::Present(nlocals) = args.co_nlocals + if let OptionalArg::Present(nlocals) = co_nlocals && nlocals as usize != varnames.len() { return Err(vm.new_value_error(format!( @@ -794,48 +959,50 @@ impl PyCode { } // Handle linetable and exceptiontable - let linetable = match args.co_linetable { + let linetable = match co_linetable { OptionalArg::Present(linetable) => linetable.as_bytes().to_vec().into_boxed_slice(), OptionalArg::Missing => self.code.linetable.clone(), }; - let exceptiontable = match args.co_exceptiontable { + let exceptiontable = match co_exceptiontable { OptionalArg::Present(exceptiontable) => { exceptiontable.as_bytes().to_vec().into_boxed_slice() } OptionalArg::Missing => self.code.exceptiontable.clone(), }; - Ok(Self { - code: CodeObject { - flags: CodeFlags::from_bits_truncate(flags), - posonlyarg_count, - arg_count, - kwonlyarg_count, - source_path: source_path.as_object().as_interned_str(vm).unwrap(), - first_line_number, - obj_name: obj_name.as_object().as_interned_str(vm).unwrap(), - qualname: qualname.as_object().as_interned_str(vm).unwrap(), - - max_stackdepth, - instructions, - locations: self.code.locations.clone(), - constants: constants.into_iter().map(Literal).collect(), - names: names - .into_iter() - .map(|o| o.as_interned_str(vm).unwrap()) - .collect(), - varnames: varnames - .into_iter() - .map(|o| o.as_interned_str(vm).unwrap()) - .collect(), - cellvars, - freevars, - cell2arg: self.code.cell2arg.clone(), - linetable, - exceptiontable, - }, - }) + let new_code = CodeObject { + flags: CodeFlags::from_bits_truncate(flags), + posonlyarg_count, + arg_count, + kwonlyarg_count, + source_path: source_path.as_object().as_interned_str(vm).unwrap(), + first_line_number, + obj_name: obj_name.as_object().as_interned_str(vm).unwrap(), + qualname: qualname.as_object().as_interned_str(vm).unwrap(), + + max_stackdepth, + instructions, + // FIXME: invalid locations. Actually locations is a duplication of linetable. + // It can be removed once we move every other code to use linetable only. + locations: self.code.locations.clone(), + constants: constants.into_iter().map(Literal).collect(), + names: names + .into_iter() + .map(|o| o.as_interned_str(vm).unwrap()) + .collect(), + varnames: varnames + .into_iter() + .map(|o| o.as_interned_str(vm).unwrap()) + .collect(), + cellvars, + freevars, + cell2arg: self.code.cell2arg.clone(), + linetable, + exceptiontable, + }; + + Ok(PyCode::new(new_code)) } #[pymethod] @@ -866,6 +1033,19 @@ impl ToPyObject for bytecode::CodeObject { } } +/// Validates and parses bytecode bytes into CodeUnit instructions. +/// Returns MarshalError if bytecode is invalid (odd length or contains invalid opcodes). +/// Note: Returning MarshalError is not necessary at this point because this is not a part of marshalling API. +/// However, we (temporarily) reuse MarshalError for simplicity. +fn parse_bytecode(bytecode_bytes: &[u8]) -> Result, MarshalError> { + // Bytecode must have even length (each instruction is 2 bytes) + if !bytecode_bytes.len().is_multiple_of(2) { + return Err(MarshalError::InvalidBytecode); + } + + parse_instructions_from_bytes(bytecode_bytes) +} + // Helper struct for reading linetable struct LineTableReader<'a> { data: &'a [u8], diff --git a/vm/src/builtins/function.rs b/vm/src/builtins/function.rs index 02e983f2ff..19daf885fb 100644 --- a/vm/src/builtins/function.rs +++ b/vm/src/builtins/function.rs @@ -28,7 +28,7 @@ use rustpython_jit::CompiledCode; #[pyclass(module = false, name = "function", traverse = "manual")] #[derive(Debug)] pub struct PyFunction { - code: PyRef, + code: PyMutex>, globals: PyDictRef, builtins: PyObjectRef, closure: Option>>, @@ -73,7 +73,7 @@ impl PyFunction { let qualname = vm.ctx.new_str(code.qualname.as_str()); let func = Self { - code: code.clone(), + code: PyMutex::new(code.clone()), globals, builtins, closure: None, @@ -96,7 +96,7 @@ impl PyFunction { func_args: FuncArgs, vm: &VirtualMachine, ) -> PyResult<()> { - let code = &*self.code; + let code = &*self.code.lock(); let nargs = func_args.args.len(); let n_expected_args = code.arg_count as usize; let total_args = code.arg_count as usize + code.kwonlyarg_count as usize; @@ -392,14 +392,15 @@ impl Py { Err(err) => info!( "jit: function `{}` is falling back to being interpreted because of the \ error: {}", - self.code.obj_name, err + self.code.lock().obj_name, + err ), } } - let code = &self.code; + let code = self.code.lock().clone(); - let locals = if self.code.flags.contains(bytecode::CodeFlags::NEW_LOCALS) { + let locals = if code.flags.contains(bytecode::CodeFlags::NEW_LOCALS) { ArgMapping::from_dict_exact(vm.ctx.new_dict()) } else if let Some(locals) = locals { locals @@ -451,7 +452,18 @@ impl PyPayload for PyFunction { impl PyFunction { #[pygetset] fn __code__(&self) -> PyRef { - self.code.clone() + self.code.lock().clone() + } + + #[pygetset(setter)] + fn set___code__(&self, code: PyRef) { + *self.code.lock() = code; + // TODO: jit support + // #[cfg(feature = "jit")] + // { + // // If available, clear cached compiled code. + // let _ = self.jitted_code.take(); + // } } #[pygetset] @@ -595,7 +607,8 @@ impl PyFunction { .get_or_try_init(|| { let arg_types = jit::get_jit_arg_types(&zelf, vm)?; let ret_type = jit::jit_ret_type(&zelf, vm)?; - rustpython_jit::compile(&zelf.code.code, &arg_types, ret_type) + let code = zelf.code.lock(); + rustpython_jit::compile(&code.code, &arg_types, ret_type) .map_err(|err| jit::new_jit_error(err.to_string(), vm)) }) .map(drop) diff --git a/vm/src/builtins/function/jit.rs b/vm/src/builtins/function/jit.rs index c528c9bb31..21d8c9c0ab 100644 --- a/vm/src/builtins/function/jit.rs +++ b/vm/src/builtins/function/jit.rs @@ -65,10 +65,10 @@ fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResu } pub fn get_jit_arg_types(func: &Py, vm: &VirtualMachine) -> PyResult> { - let arg_names = func.code.arg_names(); + let code = func.code.lock(); + let arg_names = code.arg_names(); - if func - .code + if code .flags .intersects(CodeFlags::HAS_VARARGS | CodeFlags::HAS_VARKEYWORDS) { @@ -157,9 +157,13 @@ pub(crate) fn get_jit_args<'a>( ) -> Result, ArgsError> { let mut jit_args = jitted_code.args_builder(); let nargs = func_args.args.len(); - let arg_names = func.code.arg_names(); - if nargs > func.code.arg_count as usize || nargs < func.code.posonlyarg_count as usize { + let code = func.code.lock(); + let arg_names = code.arg_names(); + let arg_count = code.arg_count; + let posonlyarg_count = code.posonlyarg_count; + + if nargs > arg_count as usize || nargs < posonlyarg_count as usize { return Err(ArgsError::WrongNumberOfArgs); } @@ -178,7 +182,7 @@ pub(crate) fn get_jit_args<'a>( } jit_args.set(arg_idx, get_jit_value(vm, value)?)?; } else if let Some(kwarg_idx) = arg_pos(arg_names.kwonlyargs, name) { - let arg_idx = kwarg_idx + func.code.arg_count as usize; + let arg_idx = kwarg_idx + arg_count as usize; if jit_args.is_set(arg_idx) { return Err(ArgsError::ArgPassedMultipleTimes); } @@ -193,7 +197,7 @@ pub(crate) fn get_jit_args<'a>( // fill in positional defaults if let Some(defaults) = defaults { for (i, default) in defaults.iter().enumerate() { - let arg_idx = i + func.code.arg_count as usize - defaults.len(); + let arg_idx = i + arg_count as usize - defaults.len(); if !jit_args.is_set(arg_idx) { jit_args.set(arg_idx, get_jit_value(vm, default)?)?; } @@ -203,7 +207,7 @@ pub(crate) fn get_jit_args<'a>( // fill in keyword only defaults if let Some(kw_only_defaults) = kwdefaults { for (i, name) in arg_names.kwonlyargs.iter().enumerate() { - let arg_idx = i + func.code.arg_count as usize; + let arg_idx = i + arg_count as usize; if !jit_args.is_set(arg_idx) { let default = kw_only_defaults .get_item(&**name, vm) @@ -214,5 +218,7 @@ pub(crate) fn get_jit_args<'a>( } } + drop(code); + jit_args.into_args().ok_or(ArgsError::NotAllArgsPassed) }