Skip to content

Commit 0a3eb7b

Browse files
committed
Use __class__ cell in super
1 parent 2eb1f8f commit 0a3eb7b

File tree

3 files changed

+65
-28
lines changed

3 files changed

+65
-28
lines changed

tests/snippets/class.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,20 @@ class Me2(Me):
8585
def test(me):
8686
return super().test()
8787

88+
class A():
89+
def f(self):
90+
pass
91+
92+
class B(A):
93+
def f(self):
94+
super().f()
95+
96+
class C(B):
97+
def f(self):
98+
super().f()
99+
100+
C().f()
101+
88102
me = Me2()
89103
assert me.test() == 100
90104

vm/src/frame.rs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,20 @@ impl Scope {
106106
}
107107
}
108108

109+
pub fn get(&self, name: &str) -> Option<PyObjectRef> {
110+
for dict in self.locals.iter() {
111+
if let Some(value) = dict.get_item(name) {
112+
return Some(value);
113+
}
114+
}
115+
116+
if let Some(value) = self.globals.get_item(name) {
117+
return Some(value);
118+
}
119+
120+
None
121+
}
122+
109123
pub fn get_only_locals(&self) -> Option<PyObjectRef> {
110124
self.locals.iter().next().cloned()
111125
}
@@ -130,17 +144,11 @@ pub trait NameProtocol {
130144

131145
impl NameProtocol for Scope {
132146
fn load_name(&self, vm: &VirtualMachine, name: &str) -> Option<PyObjectRef> {
133-
for dict in self.locals.iter() {
134-
if let Some(value) = dict.get_item(name) {
135-
return Some(value);
136-
}
137-
}
138-
139-
if let Some(value) = self.globals.get_item(name) {
140-
return Some(value);
147+
if let Some(value) = self.get(name) {
148+
Some(value)
149+
} else {
150+
vm.builtins.get_item(name)
141151
}
142-
143-
vm.builtins.get_item(name)
144152
}
145153

146154
fn store_name(&self, vm: &VirtualMachine, key: &str, value: PyObjectRef) {

vm/src/obj/objsuper.rs

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use super::objtype;
1919
#[derive(Debug)]
2020
pub struct PySuper {
2121
obj: PyObjectRef,
22+
typ: PyObjectRef,
2223
}
2324

2425
impl PyValue for PySuper {
@@ -68,8 +69,9 @@ fn super_getattribute(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
6869
);
6970

7071
let inst = super_obj.payload::<PySuper>().unwrap().obj.clone();
72+
let typ = super_obj.payload::<PySuper>().unwrap().typ.clone();
7173

72-
match inst.typ().payload::<PyClass>() {
74+
match typ.payload::<PyClass>() {
7375
Some(PyClass { ref mro, .. }) => {
7476
for class in mro {
7577
if let Ok(item) = vm.get_attribute(class.as_object().clone(), name_str.clone()) {
@@ -99,6 +101,29 @@ fn super_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
99101
return Err(vm.new_type_error(format!("{:?} is not a subtype of super", cls)));
100102
}
101103

104+
// Get the type:
105+
let py_type = if let Some(ty) = py_type {
106+
ty.clone()
107+
} else {
108+
match vm.current_scope().get("__class__") {
109+
Some(obj) => obj.clone(),
110+
_ => {
111+
return Err(vm.new_type_error(
112+
"super must be called with 1 argument or from inside class method".to_string(),
113+
));
114+
}
115+
}
116+
};
117+
118+
// Check type argument:
119+
if !objtype::isinstance(&py_type, &vm.get_type()) {
120+
let type_name = objtype::get_type_name(&py_type.typ());
121+
return Err(vm.new_type_error(format!(
122+
"super() argument 1 must be type, not {}",
123+
type_name
124+
)));
125+
}
126+
102127
// Get the bound object:
103128
let py_obj = if let Some(obj) = py_obj {
104129
obj.clone()
@@ -119,28 +144,18 @@ fn super_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
119144
}
120145
};
121146

122-
// Get the type:
123-
let py_type = if let Some(ty) = py_type {
124-
ty.clone()
125-
} else {
126-
py_obj.typ().clone()
127-
};
128-
129-
// Check type argument:
130-
if !objtype::isinstance(&py_type, &vm.get_type()) {
131-
let type_name = objtype::get_type_name(&py_type.typ());
132-
return Err(vm.new_type_error(format!(
133-
"super() argument 1 must be type, not {}",
134-
type_name
135-
)));
136-
}
137-
138147
// Check obj type:
139148
if !(objtype::isinstance(&py_obj, &py_type) || objtype::issubclass(&py_obj, &py_type)) {
140149
return Err(vm.new_type_error(
141150
"super(type, obj): obj must be an instance or subtype of type".to_string(),
142151
));
143152
}
144153

145-
Ok(PyObject::new(PySuper { obj: py_obj }, cls.clone()))
154+
Ok(PyObject::new(
155+
PySuper {
156+
obj: py_obj,
157+
typ: py_type,
158+
},
159+
cls.clone(),
160+
))
146161
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy