From 9b1d30f9943a967f9ae56387e9e1a096b4e4f11c Mon Sep 17 00:00:00 2001 From: Noa Date: Thu, 9 Jan 2025 20:55:45 -0600 Subject: [PATCH] Switch to newer thread::LocalKey convenience methods --- Cargo.lock | 8 ++++++ Cargo.toml | 2 ++ common/src/str.rs | 7 +++-- stdlib/src/contextvars.rs | 27 +++++++------------- vm/Cargo.toml | 2 ++ vm/src/stdlib/sys.rs | 18 ++++++------- vm/src/stdlib/thread.rs | 16 +++++------- vm/src/vm/thread.rs | 25 +++++++----------- wasm/lib/src/vm_class.rs | 54 +++++++++++++-------------------------- 9 files changed, 65 insertions(+), 94 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3ecbdf9cea..94805fd5b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2558,6 +2558,8 @@ dependencies = [ "rustpython-sre_engine", "rustyline", "schannel", + "scoped-tls", + "scopeguard", "serde", "static_assertions", "strum", @@ -2668,6 +2670,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index ddf45d65c4..c373860324 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -207,6 +207,8 @@ rustix = { version = "1.0", features = ["event"] } rustyline = "17.0.0" serde = { version = "1.0.133", default-features = false } schannel = "0.1.27" +scoped-tls = "1" +scopeguard = "1" static_assertions = "1.1" strum = "0.27" strum_macros = "0.27" diff --git a/common/src/str.rs b/common/src/str.rs index 10d0296619..2d867130ed 100644 --- a/common/src/str.rs +++ b/common/src/str.rs @@ -533,10 +533,9 @@ pub mod levenshtein { return max_cost + 1; } - BUFFER.with(|buffer| { - let mut buffer = buffer.borrow_mut(); - for i in 0..a_end { - buffer[i] = (i + 1) * MOVE_COST; + BUFFER.with_borrow_mut(|buffer| { + for (i, x) in buffer.iter_mut().take(a_end).enumerate() { + *x = (i + 1) * MOVE_COST; } let mut result = 0usize; diff --git a/stdlib/src/contextvars.rs b/stdlib/src/contextvars.rs index 72eba70389..62f7dfc73d 100644 --- a/stdlib/src/contextvars.rs +++ b/stdlib/src/contextvars.rs @@ -107,8 +107,7 @@ mod _contextvars { return Err(vm.new_runtime_error(msg)); } - super::CONTEXTS.with(|ctxs| { - let mut ctxs = ctxs.borrow_mut(); + super::CONTEXTS.with_borrow_mut(|ctxs| { zelf.inner.idx.set(ctxs.len()); ctxs.push(zelf.to_owned()); }); @@ -126,18 +125,12 @@ mod _contextvars { return Err(vm.new_runtime_error(msg)); } - super::CONTEXTS.with(|ctxs| { - let mut ctxs = ctxs.borrow_mut(); - // TODO: use Vec::pop_if once stabilized - if ctxs.last().is_some_and(|ctx| ctx.get_id() == zelf.get_id()) { - let _ = ctxs.pop(); - Ok(()) - } else { - let msg = - "cannot exit context: thread state references a different context object" - .to_owned(); - Err(vm.new_runtime_error(msg)) - } + super::CONTEXTS.with_borrow_mut(|ctxs| { + let err_msg = + "cannot exit context: thread state references a different context object"; + ctxs.pop_if(|ctx| ctx.get_id() == zelf.get_id()) + .map(drop) + .ok_or_else(|| vm.new_runtime_error(err_msg)) })?; zelf.inner.entered.set(false); @@ -145,8 +138,7 @@ mod _contextvars { } fn current(vm: &VirtualMachine) -> PyRef { - super::CONTEXTS.with(|ctxs| { - let mut ctxs = ctxs.borrow_mut(); + super::CONTEXTS.with_borrow_mut(|ctxs| { if let Some(ctx) = ctxs.last() { ctx.clone() } else { @@ -382,8 +374,7 @@ mod _contextvars { default: OptionalArg, vm: &VirtualMachine, ) -> PyResult> { - let found = super::CONTEXTS.with(|ctxs| { - let ctxs = ctxs.borrow(); + let found = super::CONTEXTS.with_borrow(|ctxs| { let ctx = ctxs.last()?; let cached_ptr = zelf.cached.as_ptr(); debug_assert!(!cached_ptr.is_null()); diff --git a/vm/Cargo.toml b/vm/Cargo.toml index a55d2d8325..803eae014e 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -65,6 +65,8 @@ num_enum = { workspace = true } once_cell = { workspace = true } parking_lot = { workspace = true } paste = { workspace = true } +scoped-tls = { workspace = true } +scopeguard = { workspace = true } serde = { workspace = true, optional = true } static_assertions = { workspace = true } strum = { workspace = true } diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index a8fb4031a9..5f30876b30 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -830,13 +830,13 @@ mod sys { if depth < 0 { return Err(vm.new_value_error("depth must be >= 0")); } - crate::vm::thread::COROUTINE_ORIGIN_TRACKING_DEPTH.with(|cell| cell.set(depth as _)); + crate::vm::thread::COROUTINE_ORIGIN_TRACKING_DEPTH.set(depth as u32); Ok(()) } #[pyfunction] fn get_coroutine_origin_tracking_depth() -> i32 { - crate::vm::thread::COROUTINE_ORIGIN_TRACKING_DEPTH.with(|cell| cell.get()) as _ + crate::vm::thread::COROUTINE_ORIGIN_TRACKING_DEPTH.get() as i32 } #[pyfunction] @@ -887,14 +887,10 @@ mod sys { } if let Some(finalizer) = args.finalizer.into_option() { - crate::vm::thread::ASYNC_GEN_FINALIZER.with(|cell| { - cell.replace(finalizer); - }); + crate::vm::thread::ASYNC_GEN_FINALIZER.set(finalizer); } if let Some(firstiter) = args.firstiter.into_option() { - crate::vm::thread::ASYNC_GEN_FIRSTITER.with(|cell| { - cell.replace(firstiter); - }); + crate::vm::thread::ASYNC_GEN_FIRSTITER.set(firstiter); } Ok(()) @@ -914,9 +910,11 @@ mod sys { fn get_asyncgen_hooks(vm: &VirtualMachine) -> PyAsyncgenHooks { PyAsyncgenHooks { firstiter: crate::vm::thread::ASYNC_GEN_FIRSTITER - .with(|cell| cell.borrow().clone().to_pyobject(vm)), + .with_borrow(Clone::clone) + .to_pyobject(vm), finalizer: crate::vm::thread::ASYNC_GEN_FINALIZER - .with(|cell| cell.borrow().clone().to_pyobject(vm)), + .with_borrow(Clone::clone) + .to_pyobject(vm), } } diff --git a/vm/src/stdlib/thread.rs b/vm/src/stdlib/thread.rs index 90e363774f..9e49445653 100644 --- a/vm/src/stdlib/thread.rs +++ b/vm/src/stdlib/thread.rs @@ -334,13 +334,11 @@ pub(crate) mod _thread { ); } } - SENTINELS.with(|sentinels| { - for lock in sentinels.replace(Default::default()) { - if lock.mu.is_locked() { - unsafe { lock.mu.unlock() }; - } + for lock in SENTINELS.take() { + if lock.mu.is_locked() { + unsafe { lock.mu.unlock() }; } - }); + } vm.state.thread_count.fetch_sub(1); } @@ -355,14 +353,12 @@ pub(crate) mod _thread { Err(vm.new_exception_empty(vm.ctx.exceptions.system_exit.to_owned())) } - thread_local! { - static SENTINELS: RefCell>> = const { RefCell::new(Vec::new()) }; - } + thread_local!(static SENTINELS: RefCell>> = const { RefCell::new(Vec::new()) }); #[pyfunction] fn _set_sentinel(vm: &VirtualMachine) -> PyRef { let lock = Lock { mu: RawMutex::INIT }.into_ref(&vm.ctx); - SENTINELS.with(|sentinels| sentinels.borrow_mut().push(lock.clone())); + SENTINELS.with_borrow_mut(|sentinels| sentinels.push(lock.clone())); lock } diff --git a/vm/src/vm/thread.rs b/vm/src/vm/thread.rs index a5a0a7b63b..2e687d9982 100644 --- a/vm/src/vm/thread.rs +++ b/vm/src/vm/thread.rs @@ -8,30 +8,26 @@ use std::{ thread_local! { pub(super) static VM_STACK: RefCell>> = Vec::with_capacity(1).into(); - static VM_CURRENT: RefCell<*const VirtualMachine> = std::ptr::null::().into(); pub(crate) static COROUTINE_ORIGIN_TRACKING_DEPTH: Cell = const { Cell::new(0) }; pub(crate) static ASYNC_GEN_FINALIZER: RefCell> = const { RefCell::new(None) }; pub(crate) static ASYNC_GEN_FIRSTITER: RefCell> = const { RefCell::new(None) }; } +scoped_tls::scoped_thread_local!(static VM_CURRENT: VirtualMachine); + pub fn with_current_vm(f: impl FnOnce(&VirtualMachine) -> R) -> R { - VM_CURRENT.with(|x| unsafe { - f(x.clone() - .into_inner() - .as_ref() - .expect("call with_current_vm() but VM_CURRENT is null")) - }) + if !VM_CURRENT.is_set() { + panic!("call with_current_vm() but VM_CURRENT is null"); + } + VM_CURRENT.with(f) } pub fn enter_vm(vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { VM_STACK.with(|vms| { vms.borrow_mut().push(vm.into()); - let prev = VM_CURRENT.with(|current| current.replace(vm)); - let ret = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); - vms.borrow_mut().pop(); - VM_CURRENT.with(|current| current.replace(prev)); - ret.unwrap_or_else(|e| std::panic::resume_unwind(e)) + scopeguard::defer! { vms.borrow_mut().pop(); } + VM_CURRENT.set(vm, f) }) } @@ -55,10 +51,7 @@ where // SAFETY: all references in VM_STACK should be valid, and should not be changed or moved // at least until this function returns and the stack unwinds to an enter_vm() call let vm = unsafe { interp.as_ref() }; - let prev = VM_CURRENT.with(|current| current.replace(vm)); - let ret = f(vm); - VM_CURRENT.with(|current| current.replace(prev)); - Some(ret) + Some(VM_CURRENT.set(vm, || f(vm))) }) } diff --git a/wasm/lib/src/vm_class.rs b/wasm/lib/src/vm_class.rs index bbd895c989..ad5df8dab9 100644 --- a/wasm/lib/src/vm_class.rs +++ b/wasm/lib/src/vm_class.rs @@ -58,8 +58,8 @@ impl StoredVirtualMachine { setup_browser_module(vm); } - VM_INIT_FUNCS.with(|cell| { - for f in cell.borrow().iter() { + VM_INIT_FUNCS.with_borrow(|funcs| { + for f in funcs { f(vm) } }); @@ -78,7 +78,7 @@ impl StoredVirtualMachine { /// Add a hook to add builtins or frozen modules to the RustPython VirtualMachine while it's /// initializing. pub fn add_init_func(f: fn(&mut VirtualMachine)) { - VM_INIT_FUNCS.with(|cell| cell.borrow_mut().push(f)) + VM_INIT_FUNCS.with_borrow_mut(|funcs| funcs.push(f)) } // It's fine that it's thread local, since WASM doesn't even have threads yet. thread_local! @@ -97,17 +97,15 @@ pub fn get_vm_id(vm: &VirtualMachine) -> &str { .expect("VirtualMachine inside of WASM crate should have wasm_id set") } pub(crate) fn stored_vm_from_wasm(wasm_vm: &WASMVirtualMachine) -> Rc { - STORED_VMS.with(|cell| { - cell.borrow() - .get(&wasm_vm.id) + STORED_VMS.with_borrow(|vms| { + vms.get(&wasm_vm.id) .expect("VirtualMachine is not valid") .clone() }) } pub(crate) fn weak_vm(vm: &VirtualMachine) -> Weak { let id = get_vm_id(vm); - STORED_VMS - .with(|cell| Rc::downgrade(cell.borrow().get(id).expect("VirtualMachine is not valid"))) + STORED_VMS.with_borrow(|vms| Rc::downgrade(vms.get(id).expect("VirtualMachine is not valid"))) } #[wasm_bindgen(js_name = vmStore)] @@ -116,8 +114,7 @@ pub struct VMStore; #[wasm_bindgen(js_class = vmStore)] impl VMStore { pub fn init(id: String, inject_browser_module: Option) -> WASMVirtualMachine { - STORED_VMS.with(|cell| { - let mut vms = cell.borrow_mut(); + STORED_VMS.with_borrow_mut(|vms| { if !vms.contains_key(&id) { let stored_vm = StoredVirtualMachine::new(id.clone(), inject_browser_module.unwrap_or(true)); @@ -128,14 +125,7 @@ impl VMStore { } pub(crate) fn _get(id: String) -> Option { - STORED_VMS.with(|cell| { - let vms = cell.borrow(); - if vms.contains_key(&id) { - Some(WASMVirtualMachine { id }) - } else { - None - } - }) + STORED_VMS.with_borrow(|vms| vms.contains_key(&id).then_some(WASMVirtualMachine { id })) } pub fn get(id: String) -> JsValue { @@ -146,24 +136,19 @@ impl VMStore { } pub fn destroy(id: String) { - STORED_VMS.with(|cell| { - use std::collections::hash_map::Entry; - match cell.borrow_mut().entry(id) { - Entry::Occupied(o) => { - let (_k, stored_vm) = o.remove_entry(); - // for f in stored_vm.drop_handlers.iter() { - // f(); - // } - // deallocate the VM - drop(stored_vm); - } - Entry::Vacant(_v) => {} + STORED_VMS.with_borrow_mut(|vms| { + if let Some(stored_vm) = vms.remove(&id) { + // for f in stored_vm.drop_handlers.iter() { + // f(); + // } + // deallocate the VM + drop(stored_vm); } }); } pub fn ids() -> Vec { - STORED_VMS.with(|cell| cell.borrow().keys().map(|k| k.into()).collect()) + STORED_VMS.with_borrow(|vms| vms.keys().map(|k| k.into()).collect()) } } @@ -179,10 +164,7 @@ impl WASMVirtualMachine { where F: FnOnce(&StoredVirtualMachine) -> R, { - let stored_vm = STORED_VMS.with(|cell| { - let mut vms = cell.borrow_mut(); - vms.get_mut(&self.id).unwrap().clone() - }); + let stored_vm = STORED_VMS.with_borrow_mut(|vms| vms.get_mut(&self.id).unwrap().clone()); f(&stored_vm) } @@ -202,7 +184,7 @@ impl WASMVirtualMachine { } pub fn valid(&self) -> bool { - STORED_VMS.with(|cell| cell.borrow().contains_key(&self.id)) + STORED_VMS.with_borrow(|vms| vms.contains_key(&self.id)) } pub(crate) fn push_held_rc(