Skip to content

GH-135904: Optimize the JIT's assembly control flow #135905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jun 27, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Switch to a linked list
  • Loading branch information
brandtbucher committed Jun 12, 2025
commit 858624af55e42b7c7eceabf21877f744b00d0ee8
215 changes: 123 additions & 92 deletions Tools/jit/_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,44 +35,59 @@
"_R", _schema.COFFRelocation, _schema.ELFRelocation, _schema.MachORelocation
)

inverted_branches = {}
branches = {}
for op, nop in [
("ja", "jna"),
("jae", "jnae"),
("jb", "jnb"),
("jbe", "jnbe"),
("jc", "jnc"),
("jcxz", None),
("je", "jne"),
("jecxz", None),
("jg", "jng"),
("jge", "jnge"),
("jl", "jnl"),
("jle", "jnle"),
("jo", "jno"),
("jp", "jnp"),
("js", "jns"),
("jz", "jnz"),
("jpe", "jpo"),
("jcxz", None),
("jecxz", None),
("jrxz", None),
("js", "jns"),
("jz", "jnz"),
("loop", None),
("loope", None),
("loopne", None),
("loopnz", None),
("loopz", None),
]:
inverted_branches[op] = nop
branches[op] = nop
if nop is not None:
inverted_branches[nop] = op
branches[nop] = op


@dataclasses.dataclass
class _Line:
    """A single line of assembly text, stored as a node of a linked list.

    Each line knows the line that follows it (``link``) and the lines that
    have been registered as able to transfer control to it
    (``predecessors``).
    """

    # Whether control can continue from this line into ``link``. Subclasses
    # that never fall through override this with ``False``.
    fallthrough: typing.ClassVar[bool] = True
    # The raw text of the line.
    text: str
    # Set by heat(); False until then.
    hot: bool = dataclasses.field(init=False, default=False)
    # Lines registered as transferring control to this one.
    predecessors: list["_Line"] = dataclasses.field(
        init=False, repr=False, default_factory=list
    )
    # The next line in the list, or None for the last line.
    link: "_Line | None" = dataclasses.field(init=False, repr=False, default=None)

    def heat(self) -> None:
        """Mark this line and every line connected to it as hot.

        Heat spreads backwards to all predecessors and, for lines that fall
        through, forwards to the linked line. Lines that are already hot are
        not revisited.

        The walk uses an explicit worklist rather than recursion, so its
        depth does not grow with the length of the listing and it cannot
        exhaust the interpreter's recursion limit. Note that this marks the
        reached lines directly instead of calling their own heat() methods.
        """
        pending = [self]
        while pending:
            line = pending.pop()
            if line.hot:
                continue
            line.hot = True
            pending.extend(line.predecessors)
            if line.fallthrough and line.link is not None:
                pending.append(line.link)

    def optimize(self) -> None:
        """Forward the optimize request to the following line, if any.

        The base class performs no optimization of its own. This call still
        recurses once per remaining line in the list.
        """
        if self.link is not None:
            self.link.optimize()


@dataclasses.dataclass
Expand All @@ -82,18 +97,23 @@ class _Label(_Line):

@dataclasses.dataclass
class _Jump(_Line):
target: _Label | None
fallthrough = False
target: _Label

def optimize(self) -> None:
super().optimize()
target_aliases = _aliases(self.target)
if any(alias in target_aliases for alias in _aliases(self.link)):
self.remove()

def remove(self) -> None:
self.text = ""
if self.target is not None:
self.target.predecessors.remove(self)
if not self.target.predecessors:
self.target.text = ""
self.target = None
[predecessor] = self.predecessors
assert predecessor.link is self
self.target.predecessors.remove(self)
predecessor.link = self.link
self.target.predecessors.append(predecessor)

def update(self, target: _Label) -> None:
assert self.target is not None
self.target.predecessors.remove(self)
assert self.target.label in self.text
self.text = self.text.replace(self.target.label, target.label)
Expand All @@ -105,22 +125,28 @@ def update(self, target: _Label) -> None:
class _Branch(_Line):
op: str
target: _Label
fallthrough: _Line | None = None

def optimize(self) -> None:
super().optimize()
if self.target.hot:
for jump in _aliases(self.link):
if isinstance(jump, _Jump) and self.invert(jump):
jump.optimize()

def update(self, target: _Label) -> None:
assert self.target is not None
self.target.predecessors.remove(self)
assert self.target.label in self.text
self.text = self.text.replace(self.target.label, target.label)
self.target = target
self.target.predecessors.append(self)

def invert(self, jump: _Jump) -> bool:
inverted = inverted_branches[self.op]
if inverted is None or jump.target is None:
inverted = branches[self.op]
if inverted is None:
return False
assert self.op in self.text
self.text = self.text.replace(self.op, inverted)
self.op = inverted
old_target = self.target
self.update(jump.target)
jump.update(old_target)
Expand All @@ -129,7 +155,7 @@ def invert(self, jump: _Jump) -> bool:

@dataclasses.dataclass
class _Return(_Line):
pass
fallthrough = False


@dataclasses.dataclass
Expand All @@ -140,7 +166,7 @@ class _Noise(_Line):
def _branch(
line: str, use_label: typing.Callable[[str], _Label]
) -> tuple[str, _Label] | None:
branch = re.match(rf"\s*({'|'.join(inverted_branches)})\s+([\w\.]+)", line)
branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line)
return branch and (branch.group(1), use_label(branch.group(2)))


Expand All @@ -163,25 +189,57 @@ def _noise(line: str) -> bool:
return re.match(r"\s*[#\.]|\s*$", line) is not None


def _apply_asm_transformations(path: pathlib.Path) -> None:
labels = {}
def _aliases(line: _Line | None) -> list[_Line]:
    """Collect the run of lines that begins at *line*.

    The result holds every consecutive label or noise line starting at
    *line*, followed by the first line of any other kind (when the list
    does not end first). Passing ``None`` gives an empty list.
    """
    found: list[_Line] = []
    cursor = line
    while cursor is not None:
        found.append(cursor)
        if not isinstance(cursor, (_Label, _Noise)):
            # First line that is neither a label nor noise: it ends the run.
            break
        cursor = cursor.link
    return found

def use_label(label: str) -> _Label:
if label not in labels:
labels[label] = _Label("", label)
return labels[label]

def new_line(text: str) -> _Line:
if branch := _branch(text, use_label):
@dataclasses.dataclass
class _AssemblyTransformer:
_path: pathlib.Path
_alignment: int = 1
_lines: _Line = dataclasses.field(init=False)
_labels: dict[str, _Label] = dataclasses.field(init=False, default_factory=dict)
_ran: bool = dataclasses.field(init=False, default=False)

def __post_init__(self) -> None:
    """Read the file at ``self._path`` and build the linked list of lines.

    Every line becomes a node via ``self._new_line``. Each node is linked
    after the previous one and, when the previous node falls through,
    records it as a predecessor. The first real node is stored in
    ``self._lines``.
    """
    # Placeholder head so the first real line needs no special case.
    sentinel = _Noise("")
    tail: _Line = sentinel
    for text in self._path.read_text().splitlines(keepends=True):
        node = self._new_line(text)
        if tail.fallthrough:
            node.predecessors.append(tail)
        tail.link = node
        tail = node
    head = sentinel.link
    assert head is not None
    self._lines = head

def __iter__(self) -> typing.Iterator[_Line]:
    """Yield every line in list order, starting at ``self._lines``."""
    cursor: _Line | None = self._lines
    while cursor is not None:
        yield cursor
        # Follow the chain; a ``None`` link ends the iteration.
        cursor = cursor.link

def _use_label(self, label: str) -> _Label:
    """Return the ``_Label`` registered under *label*, creating it if new.

    A newly created label starts with empty text and is stored in
    ``self._labels`` so later lookups return the same object.
    """
    try:
        return self._labels[label]
    except KeyError:
        created = _Label("", label)
        self._labels[label] = created
        return created

def _new_line(self, text: str) -> _Line:
if branch := _branch(text, self._use_label):
op, label = branch
line = _Branch(text, op, label)
label.predecessors.append(line)
return line
if label := _jump(text, use_label):
if label := _jump(text, self._use_label):
line = _Jump(text, label)
label.predecessors.append(line)
return line
if line := _label(text, use_label):
if line := _label(text, self._use_label):
assert line.text == ""
line.text = text
return line
Expand All @@ -191,66 +249,39 @@ def new_line(text: str) -> _Line:
return _Noise(text)
return _Line(text)

# Build graph:
lines = []
line = _Noise("") # Dummy.
with path.open() as file:
for i, text in enumerate(file):
new = new_line(text)
if not isinstance(line, (_Jump, _Return)):
new.predecessors.append(line)
lines.append(new)
line = new
for i, line in enumerate(reversed(lines)):
if not isinstance(line, (_Label, _Noise)):
break
new = new_line("_JIT_CONTINUE:\n")
lines.insert(len(lines) - i, new)
line = new
# Mark hot lines:
todo = labels["_JIT_CONTINUE"].predecessors.copy()
while todo:
line = todo.pop()
line.hot = True
for predecessor in line.predecessors:
if not predecessor.hot:
todo.append(predecessor)
for pair in itertools.pairwise(
filter(lambda line: line.text and not isinstance(line, _Noise), lines)
):
match pair:
case (_Branch(hot=True) as branch, _Jump(hot=False) as jump):
branch.invert(jump)
jump.hot = True
for pair in itertools.pairwise(lines):
match pair:
case (_Jump() | _Return(), _):
pass
case (_Line(hot=True), _Line(hot=False) as cold):
cold.hot = True
# Reorder blocks:
hot = []
cold = []
for line in lines:
if line.hot:
hot.append(line)
else:
cold.append(line)
lines = hot + cold
# Remove zero-length jumps:
again = True
while again:
again = False
for pair in itertools.pairwise(
filter(lambda line: line.text and not isinstance(line, _Noise), lines)
):
match pair:
case (_Jump(target=target) as jump, label) if target is label:
jump.remove()
again = True
# Write new assembly:
with path.open("w") as file:
file.writelines(line.text for line in lines)
def _dump(self) -> str:
    """Return the text of every line, concatenated in list order."""
    pieces = [line.text for line in self]
    return "".join(pieces)

def _break_on(self, name: str) -> None:
    """Debugging aid: dump the listing and enter the debugger for one file.

    Does nothing unless the stem of ``self._path`` equals *name*.
    """
    if self._path.stem != name:
        return
    print(self._dump())
    breakpoint()

def run(self) -> None:
    """Optimize the assembly's control flow and rewrite the file in place.

    Steps: splice an alignment directive and a ``_JIT_CONTINUE`` label in
    after the last line that is neither a label nor noise, heat everything
    connected to that label, run the per-line optimizations, and write the
    resulting text back to ``self._path``.

    May only be called once per instance (enforced by an assertion).
    """
    assert not self._ran
    self._ran = True
    # Find the last line that is neither a label nor noise.
    last_line = None
    for line in self:
        if not isinstance(line, (_Label, _Noise)):
            last_line = line
    assert last_line is not None
    # Both splices insert directly after last_line, so the resulting order
    # is: last_line -> "_JIT_CONTINUE:" -> ".balign N" -> (old tail).
    # NOTE(review): confirm the label is meant to precede the alignment
    # directive rather than follow it.
    new = self._new_line(f".balign {self._alignment}\n")
    new.link = last_line.link
    last_line.link = new
    new = self._new_line("_JIT_CONTINUE:\n")
    new.link = last_line.link
    last_line.link = new
    # Mark hot lines and optimize:
    recursion_limit = sys.getrecursionlimit()
    # The linked-list walks below can recurse once per line, so make sure
    # the limit is at least 10_000. Never lower a limit that is already
    # higher than that.
    sys.setrecursionlimit(max(recursion_limit, 10_000))
    try:
        self._labels["_JIT_CONTINUE"].heat()
        # self._break_on("_BUILD_TUPLE")
        self._lines.optimize()
    finally:
        # Always restore the caller's limit, even if optimization fails.
        sys.setrecursionlimit(recursion_limit)
    # Write new assembly:
    self._path.write_text(self._dump())


@dataclasses.dataclass
Expand Down Expand Up @@ -377,7 +408,7 @@ async def _compile(
*self.args,
]
await _llvm.run("clang", args_s, echo=self.verbose)
_apply_asm_transformations(s)
_AssemblyTransformer(s, self.alignment).run()
args_o = [
f"--target={self.triple}",
"-c",
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy