Rust: Type inference for tuples #20041

Open
wants to merge 8 commits into
base: main
1 change: 0 additions & 1 deletion rust/ql/.generated.list

1 change: 0 additions & 1 deletion rust/ql/.gitattributes

16 changes: 14 additions & 2 deletions rust/ql/lib/codeql/rust/elements/internal/TuplePatImpl.qll
@@ -1,4 +1,3 @@
// generated by codegen, remove this comment if you wish to edit this file
/**
* This module provides a hand-modifiable wrapper around the generated class `TuplePat`.
*
@@ -12,12 +11,25 @@ private import codeql.rust.elements.internal.generated.TuplePat
* be referenced directly.
*/
module Impl {
private import rust

// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* A tuple pattern. For example:
* ```rust
* let (x, y) = (1, 2);
* let (a, b, .., z) = (1, 2, 3, 4, 5);
* ```
*/
class TuplePat extends Generated::TuplePat { }
class TuplePat extends Generated::TuplePat {
/**
* Gets the arity of the tuple matched by this pattern, if any.
*
* This is the number of fields in the tuple pattern if and only if the
* pattern does not contain a `..` pattern.
*/
int getTupleArity() {
result = this.getNumberOfFields() and not this.getAField() instanceof RestPat
}
}
}
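
The rule that `getTupleArity` encodes can be illustrated with a minimal Rust sketch. The `Pat` enum and the `tuple_arity` function are hypothetical stand-ins for the QL classes, not part of CodeQL: the arity is the field count only when no field is a rest pattern.

```rust
/// Hypothetical, simplified model of the fields of a tuple pattern.
#[derive(Debug, Clone, PartialEq)]
enum Pat {
    /// A binding such as `x`.
    Ident(String),
    /// The rest pattern `..`.
    Rest,
}

/// Mirrors the idea behind `getTupleArity`: `Some(n)` if the pattern has
/// `n` fields and none of them is `..`, otherwise `None`.
fn tuple_arity(fields: &[Pat]) -> Option<usize> {
    if fields.iter().any(|p| *p == Pat::Rest) {
        None
    } else {
        Some(fields.len())
    }
}
```

Under this model `(x, y)` has arity 2, while `(a, b, .., z)` has no arity because the `..` can stand for any number of fields.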
59 changes: 52 additions & 7 deletions rust/ql/lib/codeql/rust/internal/Type.qll
@@ -9,14 +9,23 @@ private import codeql.rust.elements.internal.generated.Synth

cached
newtype TType =
TUnit() or
TStruct(Struct s) { Stages::TypeInferenceStage::ref() } or
TTuple(int arity) {
arity =
[
any(TupleTypeRepr t).getNumberOfFields(),
any(TupleExpr e).getNumberOfFields(),
any(TuplePat p).getNumberOfFields()
] and
Stages::TypeInferenceStage::ref()
} or
TStruct(Struct s) or
TEnum(Enum e) or
TTrait(Trait t) or
TArrayType() or // todo: add size?
TRefType() or // todo: add mut?
TImplTraitType(ImplTraitTypeRepr impl) or
TSliceType() or
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
Contributor

This way of defining arity would make TType's definition recursive, right? Not sure if that is a problem, but we might want to avoid it.

Contributor Author

Good question. No recursion markers show up in VS Code, so I think what's happening is that TTuple is materialized first and TTupleTypeParameter afterwards. TTupleTypeParameter therefore depends on TTuple, but there is no recursion. I'm only guessing, though, so I may be wrong.

Contributor

That's fine with me. @hvitved once commented on a case where I had defined a recursive type where it wasn't necessary; I'm not sure whether it is a problem here, though. In any case, it's easy to rewrite if it turns out to be a problem.

Contributor Author

Agree. Let's keep it and change it if it's actually suboptimal.
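
The dependency discussed in this thread can be pictured as a two-stage computation. The sketch below is only an illustration in Rust with hypothetical names (`tuple_arities`, `tuple_type_parameters`); it does not claim to show how the QL evaluator materializes `TType`, which the thread itself leaves as a guess.

```rust
use std::collections::BTreeSet;

/// Stage 1: the set of arities observed in the program
/// (the role played by the `TTuple` branch).
fn tuple_arities(field_counts: &[usize]) -> BTreeSet<usize> {
    field_counts.iter().copied().collect()
}

/// Stage 2: one `(arity, index)` pair for every position of every known
/// arity (the role played by the `TTupleTypeParameter` branch).
/// It only reads the finished result of stage 1.
fn tuple_type_parameters(arities: &BTreeSet<usize>) -> Vec<(usize, usize)> {
    arities
        .iter()
        .flat_map(|&arity| (0..arity).map(move |i| (arity, i)))
        .collect()
}
```

Stage 2 never feeds back into stage 1, which is the sense in which there is a dependency but no recursion.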

TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
TArrayTypeParameter() or
@@ -55,21 +64,33 @@ abstract class Type extends TType {
abstract Location getLocation();
}

/** The unit type `()`. */
class UnitType extends Type, TUnit {
UnitType() { this = TUnit() }
/** A tuple type `(T, ...)`. */
class TupleType extends Type, TTuple {
private int arity;

TupleType() { this = TTuple(arity) }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }
override TypeParameter getTypeParameter(int i) { result = TTupleTypeParameter(arity, i) }

override string toString() { result = "()" }
/** Gets the arity of this tuple type. */
int getArity() { result = arity }

override string toString() { result = "(T_" + arity + ")" }

override Location getLocation() { result instanceof EmptyLocation }
}

/** The unit type `()`. */
class UnitType extends TupleType, TTuple {
UnitType() { this = TTuple(0) }

override string toString() { result = "()" }
}

abstract private class StructOrEnumType extends Type {
abstract ItemNode asItemNode();
}
@@ -329,6 +350,30 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
override Location getLocation() { result = typeAlias.getLocation() }
}

/**
* A tuple type parameter. For instance the `T` in `(T, U)`.
*
* Since tuples are structural their type parameters can be represented as their
* positional index. The type inference library requires that type parameters
* belong to a single type, so we also include the arity of the tuple type.
*/
class TupleTypeParameter extends TypeParameter, TTupleTypeParameter {
private int arity;
Contributor

Is it really necessary to have arity as a field? Is it important to distinguish the first element of a pair from the first element of a triple?

Contributor Author

Unintuitively, the answer is yes. Before 7c04c9f, arity was not part of TupleTypeParameter. However, the type inference library relies on the assumption that every type parameter corresponds to exactly one type, so omitting the arity caused problems.

I've noted this in the QLdoc for TupleTypeParameter now.

Contributor

Thanks, I was wondering this as well.
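
A small sketch of why the arity matters, assuming (as stated above) that every type parameter must belong to exactly one type. The `owners_by_index` and `owners_by_arity_and_index` helpers are hypothetical and only count which tuple types would own a given parameter under each keying scheme.

```rust
/// Tuple types (identified by arity) that would own a parameter
/// keyed by its positional `index` alone.
fn owners_by_index(arities: &[usize], index: usize) -> Vec<usize> {
    arities.iter().copied().filter(|&a| index < a).collect()
}

/// Tuple types that own a parameter keyed by `(arity, index)`.
fn owners_by_arity_and_index(arities: &[usize], param: (usize, usize)) -> Vec<usize> {
    let (arity, index) = param;
    arities
        .iter()
        .copied()
        .filter(|&a| a == arity && index < a)
        .collect()
}
```

With pairs and triples both present, position 0 alone is shared by two types, whereas `(2, 0)` and `(3, 0)` each have a single owner.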

private int index;

TupleTypeParameter() { this = TTupleTypeParameter(arity, index) }

override string toString() { result = index.toString() + "(" + arity + ")" }

override Location getLocation() { result instanceof EmptyLocation }

/** Gets the index of this tuple type parameter. */
int getIndex() { result = index }

/** Gets the tuple type that corresponds to this tuple type parameter. */
TupleType getTupleType() { result = TTuple(arity) }
}

/** An implicit array type parameter. */
class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
override string toString() { result = "[T;...]" }
69 changes: 68 additions & 1 deletion rust/ql/lib/codeql/rust/internal/TypeInference.qll
@@ -103,6 +103,13 @@ private module Input1 implements InputSig1<Location> {
node = tp0.(SelfTypeParameter).getTrait() or
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
)
or
exists(TupleTypeParameter ttp, int maxArity |
maxArity = max(int i | i = any(TupleType tt).getArity()) and
tp0 = ttp and
kind = 2 and
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
)
|
tp0 order by kind, id
)
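
The ordering key `arity * maxArity + index` gives distinct ids because `index < arity <= maxArity`. The following Rust sketch checks this for small arities; `tuple_param_id` and `all_ids` are hypothetical helpers written for illustration, not CodeQL APIs.

```rust
/// Mirrors `id = arity * maxArity + index` from the ordering above.
fn tuple_param_id(arity: usize, index: usize, max_arity: usize) -> usize {
    arity * max_arity + index
}

/// Collects the ids of all `(arity, index)` pairs with
/// `index < arity <= max_arity`.
fn all_ids(max_arity: usize) -> Vec<usize> {
    (1..=max_arity)
        .flat_map(move |arity| {
            (0..arity).map(move |i| tuple_param_id(arity, i, max_arity))
        })
        .collect()
}
```

Since every index is strictly smaller than `max_arity`, two parameters of different arities can never collide, and parameters of a smaller tuple always sort before those of a larger one.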
@@ -229,7 +236,7 @@ private Type inferLogicalOperationType(AstNode n, TypePath path) {
private Type inferAssignmentOperationType(AstNode n, TypePath path) {
n instanceof AssignmentOperation and
path.isEmpty() and
result = TUnit()
result instanceof UnitType
}

pragma[nomagic]
@@ -321,6 +328,17 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
prefix1.isEmpty() and
prefix2 = TypePath::singleton(TRefTypeParameter())
or
exists(int i, int arity |
prefix1.isEmpty() and
prefix2 = TypePath::singleton(TTupleTypeParameter(arity, i))
|
arity = n2.(TupleExpr).getNumberOfFields() and
n1 = n2.(TupleExpr).getField(i)
or
arity = n2.(TuplePat).getTupleArity() and
n1 = n2.(TuplePat).getField(i)
)
or
exists(BlockExpr be |
n1 = be and
n2 = be.getStmtList().getTailExpr() and
@@ -534,6 +552,12 @@ private Type inferStructExprType(AstNode n, TypePath path) {
)
}

pragma[nomagic]
private Type inferTupleRootType(AstNode n) {
// `typeEquality` handles the non-root cases
result = TTuple([n.(TupleExpr).getNumberOfFields(), n.(TuplePat).getTupleArity()])
}

pragma[nomagic]
private Type inferPathExprType(PathExpr pe, TypePath path) {
// nullary struct/variant constructors
@@ -1055,6 +1079,42 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
)
}

pragma[nomagic]
private Type inferTupleIndexExprType(FieldExpr fe, TypePath path) {
Contributor

This looks fine to me, although I had expected something more similar to inferFieldExprType (or even integrated into that predicate). Or do things really work differently for named fields compared to numeric fields?

Contributor Author

I looked into that, but inferFieldExprType uses the Matching module which is about propagating type info from declarations. Tuple types need not correspond to any declaration, so they're different enough that I don't see a clear way to handle them in inferFieldExprType.

exists(int i, TypePath path0 |
fe.getIdentifier().getText() = i.toString() and
result = inferType(fe.getContainer(), path0) and
path0.isCons(TTupleTypeParameter(_, i), path)
)
}

/** Infers the type of `t` in `t.n` when `t` is a tuple. */
private Type inferTupleContainerExprType(Expr e, TypePath path) {
// NOTE: For a field expression `t.n` where `n` is a number `t` might be a
// tuple as in:
// ```rust
// let t = (Default::default(), 2);
// let s: String = t.0;
// ```
// But it could also be a tuple struct as in:
// ```rust
// struct T(String, u32);
// let t = T(Default::default(), 2);
// let s: String = t.0;
// ```
// We need type information to flow from `t.n` to tuple type parameters of `t`
// in the former case but not the latter case. Hence we include the condition
// that the root type of `t` must be a tuple type.
exists(int i, TypePath path0, FieldExpr fe, int arity |
e = fe.getContainer() and
fe.getIdentifier().getText() = i.toString() and
arity = inferType(fe.getContainer()).(TupleType).getArity() and
result = inferType(fe, path0) and
path = TypePath::cons(TTupleTypeParameter(arity, i), path0)
)
}
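
The two directions of flow handled by `inferTupleIndexExprType` and `inferTupleContainerExprType` can be sketched in Rust. The `Ty` enum and both functions are hypothetical simplifications, not CodeQL's representation: element types may be unknown (`None`), and information is pushed from `t.n` into `t` only when the root type of `t` is a tuple.

```rust
/// Hypothetical, simplified type model.
#[derive(Debug, Clone, PartialEq)]
enum Ty {
    I64,
    Bool,
    /// A structural tuple whose element types may still be unknown.
    Tuple(Vec<Option<Ty>>),
    /// A nominal tuple struct; its field types come from its declaration.
    TupleStruct(&'static str),
}

/// Direction 1: the type of `t.i`, read from the known type of `t`.
fn index_type(container: &Ty, i: usize) -> Option<Ty> {
    match container {
        Ty::Tuple(elems) => elems.get(i).cloned().flatten(),
        _ => None,
    }
}

/// Direction 2: push the known type of `t.i` into `t`, but only if the
/// root type of `t` is a tuple. Returns whether anything was updated.
fn refine_container(container: &mut Ty, i: usize, field_ty: Ty) -> bool {
    match container {
        Ty::Tuple(elems) if i < elems.len() => {
            elems[i] = Some(field_ty);
            true
        }
        _ => false,
    }
}
```

The guard in `refine_container` plays the role of the root-type condition described in the comment above: a tuple struct is left untouched.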

/** Gets the root type of the reference node `ref`. */
pragma[nomagic]
private Type inferRefNodeType(AstNode ref) {
@@ -1943,12 +2003,19 @@ private module Cached {
or
result = inferStructExprType(n, path)
or
result = inferTupleRootType(n) and
path.isEmpty()
or
result = inferPathExprType(n, path)
or
result = inferCallExprBaseType(n, path)
or
result = inferFieldExprType(n, path)
or
result = inferTupleIndexExprType(n, path)
or
result = inferTupleContainerExprType(n, path)
or
result = inferRefNodeType(n) and
path.isEmpty()
or
12 changes: 12 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
@@ -14,6 +14,18 @@ abstract class TypeMention extends AstNode {
final Type resolveType() { result = this.resolveTypeAt(TypePath::nil()) }
}

class TupleTypeReprMention extends TypeMention instanceof TupleTypeRepr {
override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
result = TTuple(super.getNumberOfFields())
or
exists(TypePath suffix, int i |
result = super.getField(i).(TypeMention).resolveTypeAt(suffix) and
path = TypePath::cons(TTupleTypeParameter(super.getNumberOfFields(), i), suffix)
)
}
}
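
The recursion in `TupleTypeReprMention.resolveTypeAt` can be mirrored by a short Rust sketch. `TypeRepr` and `resolve_type_at` are hypothetical names, and path steps are simplified to plain positions rather than `(arity, index)` pairs: the empty path yields the tuple type itself, and a non-empty path descends into the selected field.

```rust
/// Hypothetical model of a type as written in source code.
#[derive(Debug, Clone, PartialEq)]
enum TypeRepr {
    Named(&'static str),
    Tuple(Vec<TypeRepr>),
}

/// Resolves the type found at `path`, where each step is a tuple position.
/// Tuple types are rendered like the `toString` in the diff: `()` or `(T_n)`.
fn resolve_type_at(repr: &TypeRepr, path: &[usize]) -> Option<String> {
    match (repr, path) {
        (TypeRepr::Named(name), []) => Some(name.to_string()),
        (TypeRepr::Tuple(fields), []) if fields.is_empty() => Some("()".to_string()),
        (TypeRepr::Tuple(fields), []) => Some(format!("(T_{})", fields.len())),
        // Descend into field `i` and resolve the remaining suffix there.
        (TypeRepr::Tuple(fields), [i, suffix @ ..]) => {
            fields.get(*i).and_then(|f| resolve_type_at(f, suffix))
        }
        // A named type has no tuple positions to descend into.
        (TypeRepr::Named(_), _) => None,
    }
}
```

For `(i64, (bool, S1))` the empty path resolves to `(T_2)`, the path `[0]` to `i64`, and the path `[1, 1]` to `S1`.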

class ArrayTypeReprMention extends TypeMention instanceof ArrayTypeRepr {
override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
@@ -0,0 +1,4 @@
---
category: minorAnalysis
---
* Type inference now supports tuple types.
@@ -0,0 +1,2 @@
multipleCallTargets
| main.rs:445:18:445:24 | n.len() |
@@ -979,6 +979,7 @@ readStep
| main.rs:442:25:442:29 | names | file://:0:0:0:0 | element | main.rs:442:9:442:20 | TuplePat |
| main.rs:444:41:444:67 | [post] \|...\| ... | main.rs:441:9:441:20 | captured default_name | main.rs:444:41:444:67 | [post] default_name |
| main.rs:444:44:444:55 | this | main.rs:441:9:441:20 | captured default_name | main.rs:444:44:444:55 | default_name |
| main.rs:445:18:445:18 | [post] receiver for n | file://:0:0:0:0 | &ref | main.rs:445:18:445:18 | [post] n |
| main.rs:469:13:469:13 | [post] receiver for b | file://:0:0:0:0 | &ref | main.rs:469:13:469:13 | [post] b |
| main.rs:470:18:470:18 | [post] receiver for b | file://:0:0:0:0 | &ref | main.rs:470:18:470:18 | [post] b |
| main.rs:481:10:481:11 | vs | file://:0:0:0:0 | element | main.rs:481:10:481:14 | vs[0] |
@@ -1078,6 +1079,7 @@ storeStep
| main.rs:429:30:429:30 | 3 | file://:0:0:0:0 | element | main.rs:429:23:429:31 | [...] |
| main.rs:432:18:432:27 | source(...) | file://:0:0:0:0 | element | main.rs:432:5:432:11 | [post] mut_arr |
| main.rs:444:41:444:67 | default_name | main.rs:441:9:441:20 | captured default_name | main.rs:444:41:444:67 | \|...\| ... |
| main.rs:445:18:445:18 | n | file://:0:0:0:0 | &ref | main.rs:445:18:445:18 | receiver for n |
| main.rs:469:13:469:13 | b | file://:0:0:0:0 | &ref | main.rs:469:13:469:13 | receiver for b |
| main.rs:470:18:470:18 | b | file://:0:0:0:0 | &ref | main.rs:470:18:470:18 | receiver for b |
| main.rs:479:15:479:24 | source(...) | file://:0:0:0:0 | element | main.rs:479:14:479:34 | [...] |
44 changes: 30 additions & 14 deletions rust/ql/test/library-tests/type-inference/main.rs
@@ -2334,20 +2334,36 @@ mod tuples {
}

pub fn f() {
let a = S1::get_pair(); // $ target=get_pair MISSING: type=a:?
let mut b = S1::get_pair(); // $ target=get_pair MISSING: type=b:?
let (c, d) = S1::get_pair(); // $ target=get_pair MISSING: type=c:? type=d:?
let (mut e, f) = S1::get_pair(); // $ target=get_pair MISSING: type=e: type=f:
let (mut g, mut h) = S1::get_pair(); // $ target=get_pair MISSING: type=g:? type=h:?

a.0.foo(); // $ MISSING: target=foo
b.1.foo(); // $ MISSING: target=foo
c.foo(); // $ MISSING: target=foo
d.foo(); // $ MISSING: target=foo
e.foo(); // $ MISSING: target=foo
f.foo(); // $ MISSING: target=foo
g.foo(); // $ MISSING: target=foo
h.foo(); // $ MISSING: target=foo
let a = S1::get_pair(); // $ target=get_pair type=a:(T_2)
let mut b = S1::get_pair(); // $ target=get_pair type=b:(T_2)
let (c, d) = S1::get_pair(); // $ target=get_pair type=c:S1 type=d:S1
let (mut e, f) = S1::get_pair(); // $ target=get_pair type=e:S1 type=f:S1
let (mut g, mut h) = S1::get_pair(); // $ target=get_pair type=g:S1 type=h:S1

a.0.foo(); // $ target=foo
b.1.foo(); // $ target=foo
c.foo(); // $ target=foo
d.foo(); // $ target=foo
e.foo(); // $ target=foo
f.foo(); // $ target=foo
g.foo(); // $ target=foo
h.foo(); // $ target=foo

// Here type information must flow from `pair.0` and `pair.1` into
// `pair` and from `(a, b)` into `a` and `b` in order for the types of
// `a` and `b` to be inferred.
let a = Default::default(); // $ target=default type=a:i64
let b = Default::default(); // $ target=default type=b:bool
let pair = (a, b); // $ type=pair:0(2).i64 type=pair:1(2).bool
let i: i64 = pair.0;
let j: bool = pair.1;

let pair = [1, 1].into(); // $ type=pair:(T_2) type=pair:0(2).i32 type=pair:1(2).i32 MISSING: target=into
match pair {
(0,0) => print!("unexpected"),
_ => print!("expected"),
}
let x = pair.0; // $ type=x:i32
}
}
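
The backwards flow exercised by the `pair.0` / `pair.1` test case can be reproduced as a standalone function. This is a minimal sketch with a hypothetical name (`infer_through_pair`); the only type information for `a` and `b` comes from the annotations on `i` and `j`.

```rust
/// Returns the values read back out of the pair. The types of `a` and `b`
/// are inferred from how the pair's fields are used further down.
fn infer_through_pair() -> (i64, bool) {
    let a = Default::default();
    let b = Default::default();
    let pair = (a, b);
    let i: i64 = pair.0;
    let j: bool = pair.1;
    (i, j)
}
```

Because `a` and `b` are resolved to `i64` and `bool`, `Default::default()` produces `0` and `false` respectively.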
