Skip to content

Commit 17fbc90

Browse files
committed
Test PatternNodeRewriter doesn't support multi-output nodes in pattern
But it's fine if they're just root inputs
1 parent 58de233 commit 17fbc90

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

tests/graph/rewriting/test_basic.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
op_y,
4242
op_z,
4343
)
44+
from tests.unittest_tools import assert_equal_computations
4445

4546

4647
class AssertNoChanges(Feature):
@@ -725,22 +726,35 @@ def test_patternsub_invalid_dtype(out_pattern):
725726
assert e.type.is_super(fg.outputs[0].type)
726727

727728

728-
def test_patternsub_different_output_lengths():
729-
# Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
730-
ps = PatternNodeRewriter(
731-
(op1, "x"),
732-
("x"),
729+
def test_patternsub_multi_output_nodes():
730+
# Test that PatternNodeRewriter won't attempt to replace multi-output nodes
731+
multiple_op_ps = PatternNodeRewriter(
732+
(op_multiple_outputs, "x"),
733+
"x",
733734
name="ps",
734735
)
735-
rewriter = in2out(ps)
736+
737+
single_op_ps = PatternNodeRewriter(
738+
(op_y, "x"),
739+
"x",
740+
name="ps",
741+
)
742+
743+
rewriter = in2out(multiple_op_ps, single_op_ps)
736744

737745
x = MyVariable("x")
738746
e1, e2 = op_multiple_outputs(x)
739-
o = op1(e1)
747+
o1, o2 = op_y(e1), op_y(e2)
748+
749+
fgraph = FunctionGraph(inputs=[x], outputs=[e2, e1], copy_inputs=False)
750+
rewriter.rewrite(fgraph)
751+
# This shouldn't rewrite because PatternNodeRewriter has no way of specifying which output(s) are being matched
752+
assert_equal_computations(fgraph.outputs, [e2, e1])
740753

741-
fgraph = FunctionGraph(inputs=[x], outputs=[o])
754+
fgraph = FunctionGraph(inputs=[x], outputs=[o2, o1], copy_inputs=False)
742755
rewriter.rewrite(fgraph)
743-
assert fgraph.outputs[0].owner.op == op1
756+
# Having a variable that comes out of a multi-output node should be fine
757+
assert_equal_computations(fgraph.outputs, [e2, e1])
744758

745759

746760
class TestSequentialNodeRewriter:

tests/graph/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def make_node(self, *inputs):
107107

108108

109109
class MyOpMultipleOutputs(MyOp):
110+
def __init__(self, name, dmap=None, x=None):
111+
super().__init__(name=name, dmap=dmap, x=x, n_outs=2)
112+
110113
def make_node(self, input):
111114
outputs = [input.type(), input.type()]
112115
return Apply(self, [input], outputs)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy