Fine-Grained Causal Dynamics Learning With Quantization For Improving Robustness in Reinforcement Learning
Inwoo Hwang 1 Yunhyeok Kwak 1 Suhyung Choi 1 Byoung-Tak Zhang 1 Sanghack Lee 1 2
it is unclear under which circumstances the inferred dependencies hold, making them hard to interpret and challenging to generalize to unseen states.

In this work, we propose a dynamics model that infers fine-grained causal structures and employs them for prediction, leading to improved robustness in MBRL. For this, we establish a principled way to examine fine-grained causal relationships based on the quantization of the state-action space. Importantly, this provides a clear understanding of the meaningful contexts displaying sparse dependencies (Fig. 1-(c)). However, this involves the optimization of the regularized maximum likelihood score over the quantization, which is generally intractable. To this end, we present a practical differentiable method that jointly learns the dynamics model and a discrete latent variable that decomposes the state-action space into subgroups by utilizing vector quantization (Van Den Oord et al., 2017). Theoretically, we show that this joint optimization leads to identifying meaningful contexts and fine-grained causal structures.

We evaluate our method on both discrete and continuous control environments where fine-grained causal reasoning is crucial. Experimental results demonstrate the superior robustness of our approach to locally spurious correlations and unseen states in downstream tasks compared to prior causal/non-causal approaches. Finally, we illustrate that our method infers fine-grained relationships in a more effective and robust manner compared to sample-specific approaches.

Our contributions are summarized as follows.

• We establish a principled way to examine fine-grained causal relationships based on the quantization of the state-action space, which offers an identifiability guarantee and better interpretability.

• We present a theoretically grounded and practical approach to dynamics learning that infers fine-grained causal relationships by utilizing vector quantization.

Structural causal model. A structural causal model (SCM) is a tuple ⟨V, U, F, P(U)⟩, where V = {X1, · · · , Xd} is a set of endogenous variables and U is a set of exogenous variables. A set of functions F = {f1, · · · , fd} determines how each variable is generated: Xj = fj(Pa(j), Uj), where Pa(j) ⊆ V \ {Xj} is the set of parents of Xj and Uj ⊆ U. An SCM M induces a directed acyclic graph (DAG) G = (V, E), i.e., a causal graph (CG) (Peters et al., 2017), where V = {1, · · · , d} and E ⊆ V × V are the set of nodes and edges, respectively. Each edge denotes a direct causal relationship from Xi to Xj. An SCM entails the conditional independence relationship of each variable (namely, the local Markov property): Xi ⊥⊥ ND(Xi) | Pa(Xi), where ND(Xi) denotes the non-descendants of Xi, which can be read off from the corresponding causal graph.

Factored Markov Decision Process. A Markov Decision Process (MDP) (Sutton & Barto, 2018) is defined as a tuple ⟨S, A, T, r, γ⟩, where S is a state space, A is an action space, T : S × A → P(S) is a transition dynamics, r is a reward function, and γ is a discount factor. We consider a factored MDP (Kearns & Koller, 1999) where the state and action spaces are factorized as S = S1 × · · · × SN and A = A1 × · · · × AM. The transition dynamics is factorized as p(s′ | s, a) = ∏j p(s′j | s, a), where s = (s1, · · · , sN) and a = (a1, · · · , aM).

Assumptions and notations. We are concerned with an SCM associated with the transition dynamics in a factored MDP where the states are fully observable. To properly identify the causal relationships, we make assumptions standard in the field, namely, the Markov property (Pearl, 2009), faithfulness (Peters et al., 2017), causal sufficiency (Spirtes et al., 2000), and that causal connections only appear within consecutive time steps. With these assumptions, we consider a bipartite causal graph G = (V, E) which consists of the set of nodes V = X ∪ Y and the set of edges E ⊆ X × Y, where X = {S1, · · · , SN, A1, · · · , AM} and Y = {S′1, · · · , S′N}. Pa(j) denotes the parent variables of S′j.
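As a concrete illustration of this factorization, here is a minimal Python sketch with hypothetical parent sets and a deterministic toy update rule (not the paper's implementation):

```python
# Toy factored transition: p(s' | s, a) = prod_j p(s'_j | s, a), where each
# next-state factor j depends only on its parents Pa(j).
# Parent sets and the update rule below are hypothetical.
parents = {0: ["s0", "a0"], 1: ["s0", "s1"]}  # Pa(j) for each next-state factor

def factor_prob(j, sa, s_next_j):
    """Toy per-factor conditional p(s'_j | Pa(j)): s'_j = sum of parents mod 2."""
    pred = sum(sa[p] for p in parents[j]) % 2
    return 1.0 if s_next_j == pred else 0.0

def transition_prob(sa, s_next):
    """Joint transition probability as the product over factors."""
    prob = 1.0
    for j, s_next_j in enumerate(s_next):
        prob *= factor_prob(j, sa, s_next_j)
    return prob

sa = {"s0": 1, "s1": 0, "a0": 1}
print(transition_prob(sa, [0, 1]))  # -> 1.0 (both factors match the toy rule)
```

Because the joint likelihood is a product over factors, each factor can be modeled, and later masked, independently.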
(De Haan et al., 2019; Buesing et al., 2019; Zhang et al., 2020a; Sontakke et al., 2021; Schölkopf et al., 2021; Zholus et al., 2022; Zhang et al., 2020b). One focus is dynamics learning, which involves the causal structure of the transition dynamics (Li et al., 2020; Yao et al., 2022; Bongers et al., 2018; Wang et al., 2022; Ding et al., 2022; Feng et al., 2022; Huang et al., 2022) (the broader literature on causal reasoning in RL is discussed in Appendix A.1). Recent works proposed causal dynamics models that make robust predictions based on the causal dependencies (Fig. 1-(a)), utilizing conditional independence tests (Ding et al., 2022) or conditional mutual information (Wang et al., 2022) to infer the causal graph in a factored MDP. However, prior methods cannot harness fine-grained causal relationships that provide a more detailed understanding of the dynamics. In contrast, our work aims to discover and incorporate them into dynamics modeling, demonstrating that fine-grained causal reasoning leads to improved robustness in MBRL.

Discovering fine-grained causal relationships. In the context of RL, a fine-grained structure of the environment dynamics has been leveraged in various ways, e.g., for data augmentation (Pitis et al., 2022), efficient planning (Hoey et al., 1999; Chitnis et al., 2021), or exploration (Seitzer et al., 2021). For this, previous works often exploited domain knowledge (Pitis et al., 2022) or the true dynamics model (Chitnis et al., 2021). However, such prior knowledge is often unavailable in the context of dynamics learning. Existing methods for discovering fine-grained relationships examine the gradient (Wang et al., 2023) or attention score (Pitis et al., 2020) of each sample (Fig. 1-(b)). However, such sample-specific approaches lack an understanding of under which circumstances the inferred dependencies hold, and it is unclear whether they can generalize to unseen states.

In the field of causality, fine-grained causal relationships have been widely studied, e.g., context-specific independence (Boutilier et al., 2013; Zhang & Poole, 1999; Poole, 1998; Dal et al., 2018; Tikka et al., 2019; Jamshidi et al., 2023) (see Appendix A.2 for the background). Recently, Hwang et al. (2023) proposed an auxiliary network that examines local independence for each sample. However, it also does not explicitly capture the context where the local independence holds. In contrast to existing approaches relying on sample-specific inference (Löwe et al., 2022; Pitis et al., 2020; Hwang et al., 2023), we propose to examine causal dependencies at a subgroup level through quantization (Fig. 1-(c)), providing a more robust and principled way of discovering fine-grained causal relationships with a theoretical guarantee.

3. Fine-Grained Causal Dynamics Learning

In this section, we first describe a brief background on local independence and the intuition of our approach (Sec. 3.1). We then provide a principled way to examine fine-grained causal relationships (Sec. 3.2). Based on this, we propose a theoretically grounded and practical method for fine-grained causal dynamics modeling (Sec. 3.3). Finally, we provide a theoretical analysis with discussions (Sec. 3.4). All omitted proofs are provided in Appendix B.

3.1. Preliminary

Analogous to the conditional independence explaining the causal relationship between the variables (i.e., Eq. (1)), their fine-grained relationships can be understood with local independence (Hwang et al., 2023). This is written as:

    S′j ⊥⊥ X \ Pa(j; D) | Pa(j; D), D,    (2)

where D ⊆ X = S × A is a local subset of the joint state-action space, which we call a context, and Pa(j; D) ⊆ X is the set of state and action variables locally relevant for predicting S′j under D. We provide a formal definition and detailed background of local independence in Appendix A.3. For example, consider a mobile home robot interacting with various objects (Pa(j)). Under the context of the door closed (D), only objects within the same room (Pa(j; D)) become relevant. On the other hand, all objects remain relevant under the context of the door opened. We say that a context is meaningful if it displays sparse dependencies: Pa(j; D) ⊊ Pa(j), e.g., door closed. We are concerned with the subgraph of the (global) causal graph G as a graphical representation of such local dependencies.

Definition 1. The local subgraph of the causal graph¹ (LCG) on D ⊆ X is GD = (V, ED), where ED = {(i, j) | i ∈ Pa(j; D)}.

The LCG GD represents a causal structure of the transition dynamics specific to a certain context D. It is useful for our approach to fine-grained dynamics modeling, e.g., it is sufficient to consider only the objects in the same room when the door is closed. In contrast, prior causal dynamics models consider all objects under any circumstances (Fig. 1-(a)). Importantly, such information (e.g., D and GD) is not known in advance, and it is our goal to discover it. For this, existing sample-specific approaches have focused on inferring the LCG directly from individual samples (Pitis et al., 2020; Hwang et al., 2023) (Fig. 1-(b)). However, it is unclear under which context the inferred dependencies hold.

Our approach is to quantize the state-action space into subgroups and examine causal structures on each subgroup (Fig. 1-(c)). This makes it clear that each inferred LCG represents fine-grained causal relationships under the corresponding subgroup. We now proceed to describe a principled way to discover LCGs based on quantization.

¹For brevity, we will henceforth denote it as the local causal graph.
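The door example above can be sketched in code; the variables, parent sets, and context rule here are hypothetical illustrations of Definition 1, not part of the method:

```python
# Sketch of an LCG: under a context D, only a subset Pa(j; D) of the global
# parents Pa(j) stays relevant. Variable names are hypothetical.
global_parents = {"robot'": {"robot", "door", "obj_room1", "obj_room2"}}

def local_parents(door_closed):
    """Pa(j; D) for j = robot' under the door-closed/door-opened context."""
    if door_closed:  # D: door closed -> objects in other rooms are irrelevant
        return {"robot", "door", "obj_room1"}
    return set(global_parents["robot'"])  # door opened -> all parents relevant

pa_closed = local_parents(True)
assert pa_closed < global_parents["robot'"]  # meaningful: Pa(j; D) strict subset
assert local_parents(False) == global_parents["robot'"]
```

The strict-subset check mirrors the definition of a meaningful context, Pa(j; D) ⊊ Pa(j).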
(a) Local Causal Graph Inference (b) Masked Prediction with LCG
Figure 2. Overall framework. (a) For each sample (s, a), our method determines the subgroup to which the sample belongs through
quantization and infers the local causal graph (LCG) that represents fine-grained causal relationships specific to the corresponding
subgroup. (b) The dynamics model predicts the future state based on the inferred LCG. All components (e.g., dynamics model and
codebook) are jointly learned throughout the training in an end-to-end manner.
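The pipeline of Fig. 2 can be sketched end to end; the encoder, codebook, masks, and dynamics model below are toy stand-ins for the learned components, with hypothetical values:

```python
# Toy pipeline: encode (s, a) -> quantize to the nearest code -> decode the
# code to an LCG mask -> predict using only the unmasked parents.
codebook = [(0.0, 0.0), (1.0, 1.0)]   # K = 2 prototype vectors (toy)
masks = {0: [1, 1, 0], 1: [1, 0, 1]}  # code z -> adjacency column for s'_j

def encode(s, a):  # stand-in for a learned encoder
    return (float(s[0]), float(a))

def quantize(h):  # nearest code by squared L2 distance
    dists = [sum((x - y) ** 2 for x, y in zip(h, e)) for e in codebook]
    return min(range(len(codebook)), key=dists.__getitem__)

def predict(s, a, mask):  # stand-in dynamics model: sum of unmasked inputs
    inputs = list(s) + [a]
    return sum(x for x, m in zip(inputs, mask) if m)

s, a = (1.0, 5.0), 1.0
z = quantize(encode(s, a))           # embedding (1.0, 1.0) -> code 1
print(predict(s, a, masks[z]))       # -> 2.0, using s[0] and a only
```

In the actual method, the encoder, decoder, and dynamics model are neural networks trained jointly, and the mask is sampled from Bernoulli parameters rather than stored as a fixed table.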
3.2. Score for Decomposition and Graphs

Let us consider an arbitrary decomposition {Ez} (z = 1, · · · , K) of the state-action space X, where K is the degree of the quantization. The transition dynamics can be decomposed as:

    p(s′j | s, a) = Σz p(s′j | s, a, z) p(z | s, a)
                  = Σz p(s′j | Pa(j; Ez), z) p(z | s, a),    (3)

where p(z | s, a) = 1 if (s, a) ∈ Ez. This illustrates our approach to fine-grained dynamics modeling, employing only the locally relevant dependencies according to GEz on each subgroup Ez. We now aim to learn each LCG GEz based on Eq. (3). Specifically, we consider the regularized maximum likelihood score S({Gz, Ez}) of the graphs {Gz} and decomposition {Ez}, which is defined as:

    S({Gz, Ez}) = sup_ϕ E_p(s,a,s′) [ log p̂(s′ | s, a; {Gz, Ez}, ϕ) − λ|Gz| ],    (4)

where ϕ is the parameters of the dynamics model p̂, which employs the graph Gz for prediction on the corresponding subgroup Ez. We now show that graphs that maximize the score faithfully represent causal dependencies on each subgroup.

Theorem 1 (Identifiability of LCGs). With Assumptions 1 to 4, let {Ĝz} ∈ argmax S({Gz, Ez}) for λ > 0 small enough. Then, each Ĝz is the true LCG on Ez, i.e., Ĝz = GEz.

Given the subgroups, the corresponding LCGs can be recovered by score maximization. Therefore, this provides a principled way to discover LCGs, which is valid for any quantization.

Unfortunately, not all quantizations are useful for fine-grained dynamics modeling; e.g., by dividing into lights on and lights off, the model still needs to consider all objects under both circumstances. Thus, it is crucial for the quantization to capture meaningful contexts displaying sparse dependencies. Such a useful quantization allows sparser dynamics modeling, i.e., a higher score in Eq. (4). Therefore, the decomposition is now also a learning objective towards maximizing Eq. (4), i.e., {Gz∗, Ez∗} ∈ argmax S({Gz, Ez}). However, a naive optimization with respect to the decomposition is generally intractable. Thus, we devise a practical method allowing joint training with the dynamics model.

3.3. Fine-Grained Causal Dynamics Learning with Quantization

We propose a practical differentiable method that allows joint optimization of Eq. (4) over the dynamics model p̂, decomposition {Ez}, and graphs {Gz} in an end-to-end manner. The key component is a discrete latent codebook C = {ez}, where each code ez represents the pair of a subgroup Ez and a graph Gz. The codebook learning is differentiable, and these pairs are learned throughout the training with the dynamics model. The overall framework is shown in Fig. 2.

Quantization. The encoder genc maps each sample (s, a) to a latent embedding h, which is then quantized to the nearest prototype vector e (i.e., code) in the codebook C = {e1, · · · , eK}, following Van Den Oord et al. (2017):

    e = ez, where z = argmin_{j∈[K]} ∥h − ej∥2.    (5)

This entails the subgroups since each sample corresponds to exactly one of the codes, i.e., each code ez represents the subgroup Ez = {(s, a) | e = ez}. Thus, this corresponds to the term p(z | s, a) in Eq. (3). In other words, the codebook C serves as a proxy for the decomposition {Ez}.

Local causal graphs. The quantized embedding e is then decoded to an adjacency matrix A ∈ {0, 1}^((N+M)×N). The output of the decoder gdec is the parameters of the Bernoulli distributions from which the matrix is sampled: A ∼ gdec(e). In other words, each code ez corresponds to the matrix Az that represents the graph Gz. To properly backpropagate gradients, we adopt the Gumbel-Softmax reparametrization trick (Jang et al., 2017; Maddison et al., 2017).

Dynamics learning. The dynamics model p̂ employs the matrix A for prediction: Σj log p̂(s′j | s, a; A^(j)), where A^(j) ∈ {0, 1}^(N+M) is the j-th column of A. Each entry of A^(j) indicates whether the corresponding state or action variable is used to predict the next state s′j. This corresponds to the term p(s′j | Pa(j; Ez), z) in Eq. (3). For the implementation, we mask out the features of unused variables according to A. We found that this is more stable compared to input masking (Brouillard et al., 2020).

Training objective. We employ a regularization loss λ · ∥A∥1 to induce a sparse LCG, where λ is a hyperparameter. To update the codebook, we use a quantization loss (Van Den Oord et al., 2017). The training objective is as follows:

    Ltotal = Lpred + Lquant, where
    Lpred = −log p̂(s′ | s, a; A) + λ · ∥A∥1,
    Lquant = ∥sg[h] − e∥²₂ + β · ∥h − sg[e]∥²₂.    (6)

Here, Lpred is the masked prediction loss with regularization, and Lquant is the quantization loss, where sg[·] is the stop-gradient operator and β is a hyperparameter. Specifically, ∥sg[h] − e∥²₂ moves each code toward the center of the embeddings assigned to it, and β · ∥h − sg[e]∥²₂ encourages the encoder to output embeddings close to the codes. This allows us to jointly train the dynamics model and the codebook in an end-to-end manner. Intuitively, vector quantization clusters the samples under a similar context and reconstructs the LCGs for each cluster. The rationale is that any error in the graph Gz or clustering Ez would lead to a prediction error of the dynamics model. We provide the details of our model in Appendix C.4.

We note that prior works on learning a discrete latent codebook have mostly focused on the reconstruction of the observation (Van Den Oord et al., 2017; Ozair et al., 2021). To the best of our knowledge, our work is the first to utilize vector quantization for discovering diverse causal structures.

Discussion on codebook collapsing. It is well known that training a discrete latent codebook with vector quantization often suffers from codebook collapsing, where many codes learn the same output and converge to a trivial solution. For this, we employ exponential moving averages (EMA) to update the codebook, following Van Den Oord et al. (2017). In practice, we found that the training was relatively stable for any choice of the codebook size K > 2. In our experiments, we simply fixed it to 16 across all environments since they all performed comparably well, which we will demonstrate in Sec. 4.2.

3.4. Theoretical Analysis and Discussions

So far, we have described how our method learns the decomposition and LCGs through the discrete latent codebook C as a proxy. Our method can be viewed as a practical approach towards the maximization of S({Gz, Ez}), since Lpred corresponds to Eq. (4) and Lquant is a mean squared error in the latent space which can be minimized to 0. In this section, we provide its implications and discussions.

Proposition 1. Let {Gz∗, Ez∗} ∈ argmax S({Gz, Ez}) for λ > 0 small enough, with Assumptions 1 to 5. Then, (i) each Gz∗ is the true LCG on Ez∗, and (ii) E[|Gz∗|] ≤ E[|Gz|], where {Gz} are the LCGs on an arbitrary decomposition {Ez}.

In other words, the decomposition that maximizes the score is optimal in terms of E[|Gz|] = Σz p(Ez)|Gz|. This is an important property involving the contexts that are more likely (i.e., large p(E)) and more meaningful (i.e., sparse GE). Therefore, Prop. 1 implies that score maximization leads to the finest-grained understanding of the dynamics achievable given the quantization degree K.

We now illustrate how the optimal decomposition {Ez∗} in Prop. 1 with a sufficient quantization degree identifies an important context D (e.g., door closed) that displays fine-grained causal relationships. We say the context D is canonical if GF = GD for any F ⊂ D.

Theorem 2 (Identifiability of contexts). Let {Gz∗, Ez∗} ∈ argmax S({Gz, Ez}) for λ > 0 small enough, with Assumptions 1 to 5. Suppose X = ∪_{m∈[H]} Dm, where GDm is distinct for all m ∈ [H], and D1, · · · , DH are disjoint and canonical. Suppose K ≥ H. Then, for all m ∈ [H], there exists Im ⊂ [K] such that Dm = ∪_{z∈Im} Ez∗ almost surely.

In other words, the joint optimization of Eq. (4) over the quantization and the dynamics model with a sufficient quantization degree perfectly captures the meaningful contexts that exist in the system (Thm. 2) and recovers the corresponding LCGs (Prop. 1-(i)), thereby leading to a fine-grained understanding of the dynamics. Our method described in the previous section serves as a practical approach toward this goal.

Discussion on the codebook size. Thm. 2 also implies that the identification of the meaningful contexts is agnostic to the quantization degree K, as long as K ≥ H. In Sec. 4.2, we demonstrate that our method works reasonably well for various quantization degrees in practice. We note that determining a minimal and sufficient quantization degree H is not our primary focus, because over-parametrization of the quantization incurs only a small memory cost for additional codebook vectors in practice. Note that even if K < H, Prop. 1-(ii) guarantees that our method would still discover meaningful fine-grained causal relationships, optimal in terms of E[|Gz|].

Relationship to past approaches. To better understand our approach, we draw connections to (i) prior causal dynamics models and (ii) sample-specific approaches to discovering fine-grained dependencies. First, our method with the quantization degree K = 1 degenerates to prior causal dynamics models (Wang et al., 2022; Ding et al., 2022): it would learn only a single CG over the entire data domain.
4. Experiments
In this section, we evaluate our method, coined Fine-Grained Causal Dynamics Learning (FCDL), to investigate the following questions: (1) Does our method improve robustness in MBRL (Tables 1 and 2)? (2) Does our method discover fine-grained causal relationships and capture meaningful contexts (Figs. 5 to 7)? (3) Is our method more effective and robust compared to sample-specific approaches (Figs. 6 and 7)? (4) How does the degree of quantization affect performance (Fig. 7 and Table 3)?

Figure 3. Illustrations for each environment. (a) In Chemical, colors change by the action according to the underlying causal graph. (b) In Magnetic, the red object exhibits magnetism.

4.1. Experimental Setup

The environments are designed to exhibit fine-grained causal relationships under a particular context D. The state variables (e.g., position, velocity) are fully observable, following prior works (Ding et al., 2022; Wang et al., 2022; Seitzer et al., 2021; Pitis et al., 2020; 2022). Experimental details are provided in Appendix C.2.

4.1.1. Environments

Chemical (Ke et al., 2021). This is a widely used benchmark for systematic evaluation of causal reasoning in RL. There are 10 nodes, each colored with one of 5 colors. According to the underlying causal graph, an action changes the colors of the intervened node's descendants, as depicted in Fig. 3(a). The task is to match the colors of each node to the given target. We designed two settings, named full-fork and full-chain. In both settings, the underlying CG is full. When the color of the root node is red (D), the colors change according to fork or chain, respectively (GD). For example, in full-fork, all parent nodes other than the root become irrelevant under this context. Otherwise (Dᶜ), the transition respects the graph full (i.e., GDᶜ = G). During the test, the root color is set to red, and the LCG (fork or chain) is activated. Here, the agent receives a noisy observation for some nodes, and the task is to match the colors of the other, clean nodes, as depicted in Appendix C.1 (Fig. 8). An agent capable of fine-grained causal reasoning would generalize well since corrupted nodes are locally spurious for predicting the other nodes.

Magnetic. We designed a robot arm manipulation environment based on the Robosuite framework (Zhu et al., 2020). There is a moving ball and a box on the table, each colored red or black (Fig. 3(b)). Red color indicates that the object is magnetic and attracts the other magnetic object. For example, when both are red, magnetic force is applied and the ball moves toward the box. Otherwise, i.e., under the non-magnetic context, the box has no influence on the ball. The color and position of the objects are randomly initialized for each episode, i.e., each episode is under either the magnetic or the non-magnetic context during training. The task is to reach the ball, predicting its trajectory. In this environment, the non-magnetic context D displays sparse dependencies (GD ⊊ G) because the box no longer influences the ball under this context. In contrast, all causal dependencies remain the same under the magnetic context Dᶜ, i.e., GDᶜ = G. The CG and LCGs are shown in Appendix C.1 (Fig. 9). During the test, one of the objects is black, and the box is located at an unseen position. Under this non-magnetic context, the box becomes locally spurious, and thus an agent aware of fine-grained causal relationships would generalize well to unseen out-of-distribution (OOD) states.

4.1.2. Experimental Details

Baselines. We first consider dense models, i.e., a monolithic network implemented as an MLP which learns p(s′ | s, a), and a modular network having a separate network for each variable: ∏j p(s′j | s, a). We also include a graph neural network (GNN) (Kipf et al., 2020).

²Our code is publicly available at https://github.com/iwhwang/Fine-Grained-Causal-RL.
Table 1. Average episode reward on training and downstream tasks in each environment. In Chemical, n denotes the number of noisy nodes in downstream tasks.

                          |            Chemical (full-fork)                |            Chemical (full-chain)               |        Magnetic
Methods                   | Train (n=0) Test (n=2)  Test (n=4)  Test (n=6) | Train (n=0) Test (n=2)  Test (n=4)  Test (n=6) | Train      Test
MLP                       | 19.00±0.83  6.49±0.48   5.93±0.71   6.84±1.17  | 17.91±0.87  7.39±0.65   6.63±0.58   6.78±0.93  | 8.37±0.74  0.86±0.45
Modular                   | 18.55±1.00  6.05±0.70   5.65±0.50   6.43±1.00  | 17.37±1.63  6.61±0.63   7.01±0.55   7.04±1.07  | 8.45±0.80  0.88±0.52
GNN (Kipf et al., 2020)   | 18.60±1.19  6.61±0.92   6.15±0.74   6.95±0.78  | 16.97±1.85  6.89±0.28   6.38±0.28   6.56±0.53  | 8.53±0.83  0.92±0.51
NPS (Goyal et al., 2021a) | 7.71±1.22   5.82±0.83   5.75±0.57   5.54±0.80  | 8.20±0.54   6.92±1.03   6.88±0.79   6.80±0.39  | 3.13±1.00  0.91±0.69
CDL (Wang et al., 2022)   | 18.95±1.40  9.37±1.33   8.23±0.40   9.50±1.18  | 17.95±0.83  8.71±0.55   8.65±0.38   10.23±0.50 | 8.75±0.69  1.10±0.67
GRADER (Ding et al., 2022)| 18.65±0.98  9.27±1.31   8.79±0.65   10.61±1.31 | 17.71±0.54  8.69±0.56   8.75±0.80   10.14±0.33 | -          -
Oracle                    | 19.64±1.18  7.83±0.87   8.04±0.62   9.66±0.21  | 17.79±0.76  8.47±0.69   8.85±0.78   10.29±0.37 | 8.42±0.86  0.95±0.55
NCD (Hwang et al., 2023)  | 19.30±0.95  10.95±1.63  9.11±0.63   10.32±0.93 | 18.27±0.27  9.60±1.52   8.86±0.23   10.32±0.37 | 8.48±0.70  1.31±0.77
FCDL (Ours)               | 19.28±0.87  15.27±2.53  14.73±1.68  13.62±2.56 | 17.22±0.61  13.36±3.60  12.35±3.23  12.00±1.21 | 8.52±0.74  4.81±3.01
[Figure: learning curves of average episode reward versus environment steps (×10⁵) in three panels: (a) ID (all), (b) ID (fork), (c) OOD (fork).]
(a) FCDL (CG) (b) FCDL (LCG) (c) NCD (LCG, ID) (d) NCD (LCG, OOD)
Figure 6. Red boxes indicate edges included in global CG, but not in LCG under the non-magnetic context. (a) CG inferred by our
method. (b-d) LCG under the non-magnetic context inferred by (b) our method, and NCD on (c) ID and (d) OOD state.
variables, merely above 20%, which is the expected accuracy of random guessing. As expected, causal dynamics models tend to be more robust compared to dense models, but they still suffer from OOD states. NCD is more robust than causal models when n = 2, but eventually becomes similar to them as the number of noisy nodes increases. In contrast, our method outperforms the baselines by a large margin across all downstream tasks, which demonstrates its effectiveness and robustness in fine-grained causal reasoning.

Recognizing important contexts and fine-grained causal relationships (Fig. 5). To illustrate the fine-grained causal reasoning of our method, we closely examine the behavior of our model with the quantization degree K = 4 in Chemical (full-fork, n = 2). Recall that each code corresponds to the pair of a subgroup and an LCG. Fig. 5(a) shows how ID samples in the batch are allocated to one of the four codes. Interestingly, ID samples corresponding to the LCG fork are all allocated to the last code (Fig. 5(b)), i.e., the subgroup corresponding to the last code identifies this context. Furthermore, the LCG decoded from this code (Fig. 5(e)) accurately captures the true fork structure (Fig. 5(d)). This demonstrates that our method successfully recognizes meaningful contexts and fine-grained causal relationships. Notably, Fig. 5(c) shows that most of the OOD samples under fork are correctly allocated to the last code. This illustrates the robustness of our method, i.e., its inference is consistent between ID and OOD states. Additional examples, including the visualization of all learned LCGs from all codes, are provided in Appendix C.5.

Inferred LCGs compared to the sample-specific approach (Fig. 6). We investigated the effectiveness and robustness of our method in fine-grained causal reasoning compared to the sample-specific approach. For this, we examine the inferred LCGs in Magnetic, where the true LCGs and CG are shown in Appendix C.1 (Fig. 9). First, our method accurately learns the LCG under the non-magnetic context (Fig. 6(b)). On the other hand, the LCG inferred by NCD is rather inaccurate (Fig. 6(c)), including some locally spurious dependencies (3 among 6 red boxes). Furthermore, its inference is inconsistent between ID and OOD states in the same non-magnetic context, and it completely fails on OOD states (Fig. 6(d)). This demonstrates that our approach is more effective and robust in discovering fine-grained causal relationships.

Evaluation of local causal discovery (Fig. 7). We evaluate our method and NCD using structural Hamming distance (SHD) in Magnetic. For each sample, we compare the inferred LCG with the true LCG based on the magnetic/non-magnetic context, and the SHD scores are averaged over the data samples in the evaluation batch. As expected, our method infers fine-grained relationships more accurately and maintains better performance on OOD states across various quantization degrees, which validates its effectiveness and robustness compared to NCD. Lastly, we note that our method with the quantization degree K = 1 learns only a single CG over the entire data domain, as shown in Fig. 6(a). This explains its mean SHD score of 6 on non-magnetic samples in Fig. 7, since the CG includes six redundant edges in the non-magnetic context (i.e., red boxes in Fig. 9).

Ablation on the quantization degree (Table 3). Finally, we observe that our method works reasonably well across various quantization degrees on all downstream tasks in
Table 3. Average episode reward on downstream tasks in Chemical (full-fork) for varying quantization degrees K.

Methods       | (n = 2)     (n = 4)     (n = 6)
CDL           | 9.37±1.33   8.23±0.40   9.50±1.18
NCD           | 10.95±1.63  9.11±0.63   10.32±0.93
FCDL (K = 2)  | 13.44±5.41  12.86±5.58  12.99±5.27
FCDL (K = 4)  | 15.73±4.13  16.50±3.40  12.40±2.81
FCDL (K = 8)  | 14.95±1.16  15.03±2.61  13.42±2.67
FCDL (K = 16) | 15.27±2.53  14.73±1.68  13.62±2.56
FCDL (K = 32) | 16.12±1.43  14.35±1.37  14.79±2.13

[Figure: three panels plotting SHD (y-axis, 0 to 6) against quantization degree (x-axis: 1, 2, 4, 8, 16, 32).]
Figure 7. Evaluation of local causal discovery in Magnetic environment.
Chemical (full-fork). Our method consistently outperforms the prior causal dynamics model (CDL) and the sample-specific approach (NCD), which corroborates the results in Fig. 7. During our experiments, we found that the training was relatively stable for any quantization degree K > 2. We also found that instability often occurs under K = 2, where the samples frequently fluctuate between the two prototype vectors, resulting in codebook collapsing. This is also shown in Table 3, where the performance of K = 2 is worse compared to other choices of K. We speculate that over-parametrization of the quantization could alleviate such fluctuation in general.

5. Discussions and Future Works

High-dimensional observation. The factorization of the state space is natural in many real-world domains (e.g., healthcare, recommender systems, social science, economics) where discovering causal relationships is an important problem. Extending our framework to images would require extracting causal factors from pixels (Schölkopf et al., 2021), which is orthogonal to our work and could be combined with it.

Scalability and stability in training. Vector quantization (VQ) is a well-established component in generative models, where the quantization degree is usually very high (e.g., K = 512, 1024) yet effectively captures diverse visual features. Its scalability is further showcased on complex large-scale datasets (Razavi et al., 2019). In this sense, we believe our framework could extend to complex real-world environments.

… domain knowledge, it would still be useful for discovering fine-grained relationships more efficiently (e.g., Thm. 1).

Implications to real-world scenarios. We believe our work has potential implications in many practical applications, since context-dependent causal relationships are prevalent in real-world scenarios. For example, in healthcare, a dynamic treatment regime is the task of determining a sequence of decision rules (e.g., treatment type, drug dosage) based on the patient's health status, where it is known that many pathological factors involve fine-grained causal relationships (Barash & Friedman, 2001; Edwards & Toma, 1985). Our experiments illustrate that existing causal/non-causal RL approaches can suffer from locally spurious correlations and fail to generalize in downstream tasks. We believe our work serves as a stepping stone for further investigation into fine-grained causal reasoning in RL systems and their robustness in real-world deployment.

6. Conclusion

We present a novel approach to dynamics learning that infers fine-grained causal relationships, leading to improved robustness in MBRL. We provide a principled way to examine fine-grained dependencies under certain contexts. As a practical approach, our method learns a discrete latent variable that represents pairs of a subgroup and a local causal graph (LCG), allowing joint optimization with the dynamics model. Consequently, our method infers fine-grained causal structures in a more effective and robust manner compared to sample-specific approaches.
ments. For training stability, techniques have been recently pared to prior approaches. As one of the first steps towards
proposed to prevent codebook collapsing, such as codebook fine-grained causal reasoning in sequential decision-making
reset (Williams et al., 2020) and stochastic quantization systems, we hope our work stimulates future research to-
(Takida et al., 2022). We consider that such techniques and ward this goal.
tricks could be incorporated into our framework.
Conditional independence test (CIT). A CIT is an effec- Impact Statement
tive tool for understanding causal relationships, although
In real-world applications, model-based RL requires a large
often computation-costly. Our method may utilize it to fur-
amount of data. As a large-scale dataset may contain sensi-
ther calibrate the learned LCGs, e.g., applying CIT on each
tive information, it would be advisable to discreetly evaluate
subgroup after the training, which we defer to future work.
the models within simulated environments before their real-
Domain knowledge. Our method could leverage prior infor- world deployment.
mation on important contexts displaying sparse dependen-
cies, if available. While our method does not rely on such
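The codebook-reset technique mentioned in the discussion on training stability can be sketched in a few lines. The following is a toy NumPy illustration under our own simplified assumptions (nearest-prototype assignment, a usage counter, and re-initializing unused prototypes to random encoder outputs); it is not the implementation used in this work.

```python
import numpy as np

rng = np.random.default_rng(0)

class VectorQuantizer:
    """Minimal nearest-neighbor quantizer with a codebook-reset heuristic.

    Illustrative sketch only: unused codes are re-initialized to random
    encoder outputs to mitigate codebook collapse.
    """

    def __init__(self, num_codes, dim):
        self.codebook = rng.normal(size=(num_codes, dim))
        self.usage = np.zeros(num_codes)

    def quantize(self, h):
        # h: (batch, dim) encoder outputs -> index of the nearest prototype.
        dists = ((h[:, None, :] - self.codebook[None, :, :]) ** 2).sum(-1)
        idx = dists.argmin(axis=1)
        self.usage += np.bincount(idx, minlength=len(self.codebook))
        return idx, self.codebook[idx]

    def reset_dead_codes(self, h, min_usage=1):
        # Codebook reset: move never-used prototypes onto random encoder outputs.
        dead = np.flatnonzero(self.usage < min_usage)
        if dead.size:
            self.codebook[dead] = h[rng.integers(0, len(h), size=dead.size)]
        self.usage[:] = 0

vq = VectorQuantizer(num_codes=8, dim=4)
h = rng.normal(size=(32, 4))
idx, quantized = vq.quantize(h)
vq.reset_dead_codes(h)
```

In a training loop, the reset would typically be triggered every few hundred steps; the straight-through gradient estimator and commitment loss of standard VQ training are omitted here for brevity.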
Acknowledgements

We would like to thank Sujin Jeon and Hyundo Lee for the useful discussions. We also thank anonymous reviewers for their constructive comments. This work was partly supported by the IITP (RS-2021-II212068-AIHub/10%, RS-2021-II211343-GSAI/10%, 2022-0-00951-LBA/10%, 2022-0-00953-PICA/20%), NRF (RS-2023-00211904/20%, RS-2023-00274280/10%, RS-2024-00353991/10%), and KEIT (RS-2024-00423940/10%) grant funded by the Korean government.

References

Acid, S. and de Campos, L. M. Searching for bayesian network structures in the space of restricted acyclic partially directed graphs. Journal of Artificial Intelligence Research, 18:445–490, 2003.

Barash, Y. and Friedman, N. Context-specific bayesian clustering for gene expression data. In Proceedings of the Fifth Annual International Conference on Computational Biology, pp. 12–21, 2001.

Bareinboim, E., Forney, A., and Pearl, J. Bandits with unobserved confounders: A causal approach. Advances in Neural Information Processing Systems, 28, 2015.

Bica, I., Jarrett, D., and van der Schaar, M. Invariant causal imitation learning for generalizable policies. Advances in Neural Information Processing Systems, 34:3952–3964, 2021.

Bongers, S., Blom, T., and Mooij, J. M. Causal modeling of dynamical systems. arXiv preprint arXiv:1803.08784, 2018.

Boutilier, C., Friedman, N., Goldszmidt, M., and Koller, D. Context-specific independence in bayesian networks. CoRR, abs/1302.3562, 2013.

Brouillard, P., Lachapelle, S., Lacoste, A., Lacoste-Julien, S., and Drouin, A. Differentiable causal discovery from interventional data. Advances in Neural Information Processing Systems, 33:21865–21877, 2020.

Buesing, L., Weber, T., Zwols, Y., Heess, N., Racaniere, S., Guez, A., and Lespiau, J.-B. Woulda, coulda, shoulda: Counterfactually-guided policy search. In International Conference on Learning Representations, 2019.

Camacho, E. F. and Alba, C. B. Model Predictive Control. Springer Science & Business Media, 2013.

Chitnis, R., Silver, T., Kim, B., Kaelbling, L., and Lozano-Perez, T. Camps: Learning context-specific abstractions for efficient planning in factored mdps. In Conference on Robot Learning, pp. 64–79. PMLR, 2021.

Chung, J., Gulcehre, C., Cho, K., and Bengio, Y. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555, 2014.

Dal, G. H., Laarman, A. W., and Lucas, P. J. Parallel probabilistic inference by weighted model counting. In International Conference on Probabilistic Graphical Models, pp. 97–108. PMLR, 2018.

De Haan, P., Jayaraman, D., and Levine, S. Causal confusion in imitation learning. Advances in Neural Information Processing Systems, 32, 2019.

Ding, W., Lin, H., Li, B., and Zhao, D. Generalizing goal-conditioned reinforcement learning with variational causal reasoning. In Advances in Neural Information Processing Systems, 2022.

Edwards, D. and Toma, H. A fast procedure for model search in multidimensional contingency tables. Biometrika, 72:339–351, 1985.

Feng, F. and Magliacane, S. Learning dynamic attribute-factored world models for efficient multi-object reinforcement learning. In Advances in Neural Information Processing Systems, volume 36, pp. 19117–19144, 2023.

Feng, F., Huang, B., Zhang, K., and Magliacane, S. Factored adaptation for non-stationary reinforcement learning. In Advances in Neural Information Processing Systems, 2022.

Goyal, A., Didolkar, A. R., Ke, N. R., Blundell, C., Beaudoin, P., Heess, N., Mozer, M. C., and Bengio, Y. Neural production systems. In Advances in Neural Information Processing Systems, 2021a.

Goyal, A., Lamb, A., Gampa, P., Beaudoin, P., Blundell, C., Levine, S., Bengio, Y., and Mozer, M. C. Factorizing declarative and procedural knowledge in structured, dynamical environments. In International Conference on Learning Representations, 2021b.

Goyal, A., Lamb, A., Hoffmann, J., Sodhani, S., Levine, S., Bengio, Y., and Schölkopf, B. Recurrent independent mechanisms. In International Conference on Learning Representations, 2021c.

Hoey, J., St-Aubin, R., Hu, A., and Boutilier, C. Spudd: Stochastic planning using decision diagrams. In Proceedings of the Fifteenth Conference on Uncertainty in Artificial Intelligence, pp. 279–288, 1999.

Huang, B., Lu, C., Leqi, L., Hernández-Lobato, J. M., Glymour, C., Schölkopf, B., and Zhang, K. Action-sufficient state representation learning for control with structural constraints. In International Conference on Machine Learning, pp. 9260–9279. PMLR, 2022.

Hwang, I., Kwak, Y., Song, Y.-J., Zhang, B.-T., and Lee, S. On discovery of local independence over continuous variables via neural contextual decomposition. In Conference on Causal Learning and Reasoning, pp. 448–472. PMLR, 2023.

Jamshidi, F., Akbari, S., and Kiyavash, N. Causal imitability under context-specific independence relations. In Advances in Neural Information Processing Systems, volume 36, pp. 26810–26830, 2023.

Jang, E., Gu, S., and Poole, B. Categorical reparameterization with gumbel-softmax. In International Conference on Learning Representations, 2017.

Kaiser, Ł., Babaeizadeh, M., Miłos, P., Osiński, B., Campbell, R. H., Czechowski, K., Erhan, D., Finn, C., Kozakowski, P., Levine, S., et al. Model based reinforcement learning for atari. In International Conference on Learning Representations, 2020.

Ke, N. R., Didolkar, A. R., Mittal, S., Goyal, A., Lajoie, G., Bauer, S., Rezende, D. J., Mozer, M. C., Bengio, Y., and Pal, C. Systematic evaluation of causal discovery in visual model based reinforcement learning. In Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 2), 2021.

Kearns, M. and Koller, D. Efficient reinforcement learning in factored mdps. In IJCAI, volume 16, pp. 740–747, 1999.

Killian, T. W., Ghassemi, M., and Joshi, S. Counterfactually guided policy transfer in clinical settings. In Conference on Health, Inference, and Learning, pp. 5–31. PMLR, 2022.

Kipf, T., van der Pol, E., and Welling, M. Contrastive learning of structured world models. In International Conference on Learning Representations, 2020.

Kumor, D., Zhang, J., and Bareinboim, E. Sequential causal imitation learning with unobserved confounders. In Advances in Neural Information Processing Systems, volume 34, pp. 14669–14680, 2021.

Lee, S. and Bareinboim, E. Structural causal bandits: Where to intervene? In Advances in Neural Information Processing Systems, volume 31, 2018.

Lee, S. and Bareinboim, E. Characterizing optimal mixed policies: Where to intervene and what to observe. In Advances in Neural Information Processing Systems, volume 33, pp. 8565–8576, 2020.

Li, M., Zhang, J., and Bareinboim, E. Causally aligned curriculum learning. In The Twelfth International Conference on Learning Representations, 2024.

Li, Y., Torralba, A., Anandkumar, A., Fox, D., and Garg, A. Causal discovery in physical systems from videos. Advances in Neural Information Processing Systems, 33:9180–9192, 2020.

Löwe, S., Madras, D., Zemel, R., and Welling, M. Amortized causal discovery: Learning to infer causal graphs from time-series data. In Conference on Causal Learning and Reasoning, pp. 509–525. PMLR, 2022.

Lu, C., Schölkopf, B., and Hernández-Lobato, J. M. Deconfounding reinforcement learning in observational settings. arXiv preprint arXiv:1812.10576, 2018.

Lu, C., Huang, B., Wang, K., Hernández-Lobato, J. M., Zhang, K., and Schölkopf, B. Sample-efficient reinforcement learning via counterfactual-based data augmentation. arXiv preprint arXiv:2012.09092, 2020.

Lyle, C., Zhang, A., Jiang, M., Pineau, J., and Gal, Y. Resolving causal confusion in reinforcement learning via robust exploration. In Self-Supervision for Reinforcement Learning Workshop, ICLR, 2021.

Maddison, C. J., Mnih, A., and Teh, Y. W. The concrete distribution: A continuous relaxation of discrete random variables. In International Conference on Learning Representations, 2017.

Madumal, P., Miller, T., Sonenberg, L., and Vetere, F. Explainable reinforcement learning through a causal lens. In Proceedings of the AAAI Conference on Artificial Intelligence, pp. 2493–2500, 2020.

Mesnard, T., Weber, T., Viola, F., Thakoor, S., Saade, A., Harutyunyan, A., Dabney, W., Stepleton, T. S., Heess, N., Guez, A., et al. Counterfactual credit assignment in model-free reinforcement learning. In International Conference on Machine Learning, pp. 7654–7664. PMLR, 2021.

Mutti, M., De Santi, R., Rossi, E., Calderon, J. F., Bronstein, M., and Restelli, M. Provably efficient causal model-based reinforcement learning for systematic generalization. In Proceedings of the AAAI Conference on Artificial Intelligence, pp. 9251–9259, 2023.

Nair, S., Zhu, Y., Savarese, S., and Fei-Fei, L. Causal induction from visual observations for goal directed tasks. arXiv preprint arXiv:1910.01751, 2019.

Oberst, M. and Sontag, D. Counterfactual off-policy evaluation with gumbel-max structural causal models. In International Conference on Machine Learning, pp. 4881–4890. PMLR, 2019.

Ozair, S., Li, Y., Razavi, A., Antonoglou, I., Van Den Oord, A., and Vinyals, O. Vector quantized models for planning. In International Conference on Machine Learning, pp. 8302–8313. PMLR, 2021.

Pearl, J. Causality. Cambridge University Press, 2009.

Peters, J., Janzing, D., and Schölkopf, B. Elements of Causal Inference: Foundations and Learning Algorithms. The MIT Press, 2017.

Pitis, S., Creager, E., and Garg, A. Counterfactual data augmentation using locally factored dynamics. Advances in Neural Information Processing Systems, 33, 2020.

Pitis, S., Creager, E., Mandlekar, A., and Garg, A. MoCoDA: Model-based counterfactual data augmentation. In Advances in Neural Information Processing Systems, 2022.

Poole, D. Context-specific approximation in probabilistic inference. In Proceedings of the Fourteenth Conference on Uncertainty in Artificial Intelligence, pp. 447–454, 1998.

Ramsey, J., Spirtes, P., and Zhang, J. Adjacency-faithfulness and conservative causal inference. In Proceedings of the Twenty-Second Conference on Uncertainty in Artificial Intelligence, pp. 401–408, 2006.

Razavi, A., van den Oord, A., and Vinyals, O. Generating diverse high-fidelity images with vq-vae-2. In Advances in Neural Information Processing Systems, volume 32, 2019.

Rezende, D. J., Danihelka, I., Papamakarios, G., Ke, N. R., Jiang, R., Weber, T., Gregor, K., Merzic, H., Viola, F., Wang, J., et al. Causally correct partial models for reinforcement learning. arXiv preprint arXiv:2002.02836, 2020.

Rubinstein, R. Y. and Kroese, D. P. The Cross-Entropy Method: A Unified Approach to Combinatorial Optimization, Monte-Carlo Simulation, and Machine Learning, volume 133. Springer, 2004.

Schölkopf, B., Locatello, F., Bauer, S., Ke, N. R., Kalchbrenner, N., Goyal, A., and Bengio, Y. Toward causal representation learning. Proceedings of the IEEE, 109(5):612–634, 2021.

Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., Guez, A., Lockhart, E., Hassabis, D., Graepel, T., et al. Mastering atari, go, chess and shogi by planning with a learned model. Nature, 588(7839):604–609, 2020.

Seitzer, M., Schölkopf, B., and Martius, G. Causal influence detection for improving efficiency in reinforcement learning. In Advances in Neural Information Processing Systems, 2021.

Sontakke, S. A., Mehrjou, A., Itti, L., and Schölkopf, B. Causal curiosity: Rl agents discovering self-supervised experiments for causal representation learning. In International Conference on Machine Learning, pp. 9848–9858. PMLR, 2021.

Spirtes, P., Glymour, C. N., Scheines, R., and Heckerman, D. Causation, Prediction, and Search. MIT Press, 2000.

Sutton, R. S. and Barto, A. G. Reinforcement Learning: An Introduction. MIT Press, 2018.

Takida, Y., Shibuya, T., Liao, W., Lai, C.-H., Ohmura, J., Uesaka, T., Murata, N., Takahashi, S., Kumakura, T., and Mitsufuji, Y. Sq-vae: Variational bayes on discrete representation with self-annealed stochastic quantization. In International Conference on Machine Learning, pp. 20987–21012. PMLR, 2022.

Tikka, S., Hyttinen, A., and Karvanen, J. Identifying causal effects via context-specific independence relations. Advances in Neural Information Processing Systems, 32:2804–2814, 2019.

Tomar, M., Zhang, A., Calandra, R., Taylor, M. E., and Pineau, J. Model-invariant state abstractions for model-based reinforcement learning. arXiv preprint arXiv:2102.09850, 2021.

Van Den Oord, A., Vinyals, O., et al. Neural discrete representation learning. Advances in Neural Information Processing Systems, 30, 2017.

Volodin, S., Wichers, N., and Nixon, J. Resolving spurious correlations in causal models of environments via interventions. arXiv preprint arXiv:2002.05217, 2020.

Wang, Z., Xiao, X., Zhu, Y., and Stone, P. Task-independent causal state abstraction. In Proceedings of the 35th International Conference on Neural Information Processing Systems, Robot Learning Workshop, 2021.

Wang, Z., Xiao, X., Xu, Z., Zhu, Y., and Stone, P. Causal dynamics learning for task-independent state abstraction. In International Conference on Machine Learning, pp. 23151–23180. PMLR, 2022.

Wang, Z., Hu, J., Stone, P., and Martín-Martín, R. ELDEN: Exploration via local dependencies. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.

Williams, W., Ringer, S., Ash, T., MacLeod, D., Dougherty, J., and Hughes, J. Hierarchical quantized autoencoders. Advances in Neural Information Processing Systems, 33:4524–4535, 2020.
CSI has been widely studied especially for discrete variables with low cardinality, e.g., binary variables. Context-set specific independence (CSSI) generalizes the notion of CSI, allowing continuous variables.

Definition 3 (Context-Set Specific Independence (CSSI) (Hwang et al., 2023)). Let $X = \{X_1, \cdots, X_d\}$ be a non-empty set of the parents of $Y$ in a causal graph, and $\mathcal{E} \subseteq \mathcal{X}$ be an event with a positive probability. $\mathcal{E}$ is said to be a context set which induces context-set specific independence (CSSI) of $X_{A^c}$ from $Y$ if $p(y \mid x_{A^c}, x_A) = p(y \mid x'_{A^c}, x_A)$ holds for every $(x_{A^c}, x_A), (x'_{A^c}, x_A) \in \mathcal{E}$. This will be denoted by $Y \perp\!\!\!\perp X_{A^c} \mid X_A, \mathcal{E}$.

Intuitively, it denotes that the conditional distribution $p(y \mid x) = p(y \mid x_{A^c}, x_A)$ is the same for different values of $x_{A^c}$, for all $x = (x_{A^c}, x_A) \in \mathcal{E}$. In other words, only a subset of the parent variables is sufficient for modeling $p(y \mid x)$ on $\mathcal{E}$.
It implies that only a subset of the parent variables ($x_T$) is locally relevant on $\mathcal{E}$, and any other remaining variables ($x_{T^c}$) are locally irrelevant, i.e., $p(s'_j \mid x)$ is a function of $x_T$ on $\mathcal{E}$.³ Local independence generalizes conditional independence.

³ $T$ denotes an index set of $\mathbf{T}$.
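A minimal numerical illustration of the CSSI notion above on a binary toy example; the table values are ours, chosen only for illustration.

```python
# A toy conditional table p(y = 1 | x1, x2): globally Y depends on both
# parents, but on the context set E = {x1 = 1} the value of x2 is
# irrelevant, i.e., Y is locally independent of X2 given X1 on E.
p_y1 = {
    (0, 0): 0.2,
    (0, 1): 0.9,   # x2 matters when x1 = 0 ...
    (1, 0): 0.5,
    (1, 1): 0.5,   # ... but not when x1 = 1 (the context set E)
}

def cssi_holds(table, context_x1):
    """Check that p(y | x1, x2) is constant in x2 within {x1 = context_x1}."""
    vals = [table[(context_x1, x2)] for x2 in (0, 1)]
    return all(abs(v - vals[0]) < 1e-12 for v in vals)

holds_on_E = cssi_holds(p_y1, 1)       # True: local independence on E
holds_elsewhere = cssi_holds(p_y1, 0)  # False: dependence on x2 remains
```

This mirrors the definition directly: the conditional is compared across values of the locally irrelevant parent while the relevant parent is held fixed within the context set.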
14
Fine-Grained Causal Dynamics Learning
In other words, $Pa(j; \mathcal{E})$ is a minimal subset of $Pa(j)$ for which the local independence on $\mathcal{E}$ holds. Clearly, $Pa(j; \mathcal{X}) = Pa(j)$, i.e., local independence on $\mathcal{X}$ is equivalent to (global) conditional independence.

LCG (Def. 1) describes fine-grained causal relationships specific to $\mathcal{E}$. An LCG is always a subgraph of the (global) causal graph, i.e., $\mathcal{G}_{\mathcal{E}} \subseteq \mathcal{G}$, because if a dependency (i.e., an edge) does not exist under the whole domain, it cannot exist under any context. Note that $\mathcal{G}_{\mathcal{X}} = \mathcal{G}$, i.e., local independence and the LCG under $\mathcal{X}$ are equivalent to conditional independence and the CG, respectively.
Analogous to the faithfulness assumption (Peters et al., 2017) that no conditional independences other than the ones entailed by the CG are present, we introduce a similar assumption for LCGs and local independence.

Assumption 2 ($\mathcal{E}$-Faithfulness). For any $\mathcal{E}$, no local independences on $\mathcal{E}$ other than the ones entailed by $\mathcal{G}_{\mathcal{E}}$ are present, i.e., for any $j$, there does not exist any $T$ such that $Pa(j; \mathcal{E}) \setminus T \neq \emptyset$ and $S'_j \perp\!\!\!\perp X \setminus T \mid T, \mathcal{E}$.

Regardless of the $\mathcal{E}$-faithfulness assumption, an LCG always exists because $Pa(j; \mathcal{E})$ always exists. However, such an LCG may not be unique without this assumption (see Hwang et al. (2023, Example 2)). Assumption 2 implies the uniqueness of $Pa(j; \mathcal{E})$ and $\mathcal{G}_{\mathcal{E}}$, and thus it is required to properly identify fine-grained causal relationships between the variables.
Such fine-grained causal relationships are prevalent in the real world. Physical law: to move a static object, a force exceeding frictional resistance must be exerted; otherwise, the object will not move. Logic: consider $A \lor B \lor C$. When $A$ is true, any changes of $B$ or $C$ no longer affect the outcome. Biology: in general, smoking has a causal effect on blood pressure. However, one's blood pressure becomes independent of smoking if the ratio of alpha and beta lipoproteins is larger than a certain threshold (Edwards & Toma, 1985).
$$\hat{p}(s' \mid s, a) = \sum_{z=1}^{K} p(z \mid s, a)\, \hat{p}(s' \mid s, a; \mathcal{G}_z, \phi_z),$$

where $p(z \mid s, a) = 1$ if $(s, a) \in \mathcal{E}_z$. This illustrates our approach to dynamics modeling based on fine-grained causal dependencies: $p(s'_j \mid s, a)$ is a function of $Pa(j, \mathcal{E}_z)$ on $\mathcal{E}_z$, and our dynamics model employs the locally relevant dependencies $Pa(j, \mathcal{E}_z)$ for predicting $S'_j$. Our dynamics modeling with some graphs $\{\mathcal{G}_z\}_{z=1}^{K}$ is:

$$\hat{p}(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi) = \sum_{z=1}^{K} p(z \mid s, a) \prod_j \hat{p}_j\big(s'_j \mid Pa_{\mathcal{G}_z}(j); \phi_z^{(j)}\big),$$

where $\phi_z^{(j)}$ takes $Pa_{\mathcal{G}_z}(j)$ as an input and outputs the parameters of the density function $\hat{p}_j$, and $\phi := \{\phi_z^{(j)}\}$. We denote $\hat{p}_{\{\mathcal{G}_z,\mathcal{E}_z\},\phi} := \hat{p}(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi)$ and $\hat{p}_{\mathcal{G}_z,\phi_z} := \hat{p}(s' \mid s, a; \mathcal{G}_z, \phi_z)$. In other words, $\hat{p}_{\{\mathcal{G}_z,\mathcal{E}_z\},\phi}(s' \mid s, a) = \hat{p}_{\mathcal{G}_z,\phi_z}(s' \mid s, a)$ if $(s, a) \in \mathcal{E}_z$.

Now, we revisit the score function in Eq. (4):

$$S(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^{K}) := \sup_{\phi} \mathbb{E}_{p(s,a,s')}\big[\log \hat{p}(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi) - \lambda |\mathcal{G}_z|\big], \tag{9}$$

where $\hat{p}(s' \mid s, a; \mathcal{G}_z, \phi_z) = \prod_j \hat{p}_j\big(s'_j \mid Pa_{\mathcal{G}_z}(j); \phi_z^{(j)}\big)$.
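The prediction rule described above, where a code-specific graph masks out locally irrelevant inputs, can be sketched with a toy linear parameterization of our own (hard code assignment and per-code binary masks standing in for $Pa_{\mathcal{G}_z}(j)$); this is not the paper's network architecture.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, K = 3, 2, 2

# G[z, j, i] = 1 iff input i is a parent of s'_j under code z.
G = np.zeros((K, d_out, d_in))
G[0] = 1                      # code 0: dense graph (all edges)
G[1, 0, 0] = G[1, 1, 1] = 1   # code 1: sparse local causal graph

W = rng.normal(size=(K, d_out, d_in))  # per-code predictors (stand-in for phi_z)

def assign_code(x):
    # Hard assignment p(z | s, a) in {0, 1}: a simple threshold rule
    # standing in for the learned quantizer.
    return 0 if x[0] > 0 else 1

def predict(x):
    z = assign_code(x)
    masked_W = W[z] * G[z]     # zero out non-parents: only Pa_{G_z}(j) is read
    return masked_W @ x, z

x = np.array([-1.0, 2.0, 0.5])
pred, z = predict(x)
# Under code 1, s'_0 depends only on x[0] and s'_1 only on x[1].
```

The point of the mask is exactly the one made in the text: within a subgroup, the predictor for each $S'_j$ cannot access locally irrelevant variables, so locally spurious correlations cannot leak into the prediction.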
Assumption 3 states that the model parametrized by a neural network has sufficient capacity to represent the ground-truth density. Assumption 4 is a technical tool for handling the score differences, as we will see later.
Lemma 1. Let $\mathcal{G}_z^*$ be the true LCG on $\mathcal{E}_z$ for all $z$. Then, $S(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^{K}) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z^*|\big]$.

Proof. First,

$$0 \le D_{\mathrm{KL}}\big(p \,\|\, \hat{p}_{\{\mathcal{G}_z^*,\mathcal{E}_z\},\phi}\big) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \mathbb{E}_{p(s,a,s')} \log \hat{p}(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi), \tag{13}$$

where the equality holds because $\mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) < \infty$ by Assumption 4. Therefore,

$$\sup_{\phi} \mathbb{E}_{p(s,a,s')} \log \hat{p}(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi) \le \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a). \tag{14}$$

On the other hand, by Assumption 3, there exists $\phi^*$ such that $p = \hat{p}_{\{\mathcal{G}_z^*,\mathcal{E}_z\},\phi^*}$. Hence,

$$\sup_{\phi} \mathbb{E}_{p(s,a,s')} \log \hat{p}(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi) \ge \mathbb{E}_{p(s,a,s')} \log \hat{p}(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi^*) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a). \tag{15}$$

By Eqs. (14) and (15), we have $\sup_{\phi} \mathbb{E}_{p(s,a,s')} \log \hat{p}(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a)$. Therefore, we have $S(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^{K}) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z^*|\big]$. $\square$
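The inequality in Eq. (13) is the standard likelihood–KL decomposition. A quick numerical sanity check of that identity on a toy discrete conditional (numbers are ours):

```python
import numpy as np

# Check: D_KL(p || p_hat) equals E_p[log p(s'|s,a)] - E_p[log p_hat(s'|s,a)]
# under the joint p(s, a, s'), and is non-negative.
p_sa = np.array([0.3, 0.7])                  # marginal over two (s, a) contexts
p_cond = np.array([[0.8, 0.2], [0.4, 0.6]])  # true p(s' | s, a)
q_cond = np.array([[0.6, 0.4], [0.5, 0.5]])  # a misspecified model p_hat

kl = (p_sa[:, None] * p_cond * np.log(p_cond / q_cond)).sum()
gap = (p_sa[:, None] * p_cond * np.log(p_cond)).sum() - \
      (p_sa[:, None] * p_cond * np.log(q_cond)).sum()

assert abs(kl - gap) < 1e-12 and kl >= 0.0
```

Equality ($\mathrm{KL} = 0$) holds exactly when the model matches the true conditional, which is the situation Assumption 3 guarantees for some $\phi^*$.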
$$S(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^{K}) = \sup_{\phi} \mathbb{E}_{p(s,a,s')} \log \hat{p}(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi) - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z|\big] \tag{17}$$
The last equality holds by Assumption 4. Subtracting Eq. (19) from Lemma 1, we obtain:

$$S(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^{K}) - S(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^{K}) = \inf_{\phi} D_{\mathrm{KL}}\big(p \,\|\, \hat{p}_{\{\mathcal{G}_z,\mathcal{E}_z\},\phi}\big) + \lambda \sum_{z} p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big),$$

where $p_z(s, a) := p(s, a \mid z) = p(s, a)/p(\mathcal{E}_z)$ for all $(s, a) \in \mathcal{E}_z$, i.e., the density function of the distribution $P_{\mathcal{S} \times \mathcal{A} \mid \mathcal{E}_z}$.

Now, let $\hat{f}(s' \mid s, a) := \prod_j p_z(s'_j \mid Pa_{\mathcal{G}_z}(j))$ for all $(s, a) \in \mathcal{E}_z$. Then, for any $f \in \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$,

$$\begin{aligned}
\int p_z(s,a)\, D_{\mathrm{KL}}\big(p(\cdot \mid s,a) \,\|\, f(\cdot \mid s,a)\big)
&= \int p_z(s,a) \int p(s' \mid s,a) \log \frac{p(s' \mid s,a)}{\hat{f}(s' \mid s,a)} \frac{\hat{f}(s' \mid s,a)}{f(s' \mid s,a)} \\
&= \int p_z(s,a)\, D_{\mathrm{KL}}\big(p \,\|\, \hat{f}\big) + \int p_z(s,a) \int p(s' \mid s,a) \log \frac{\hat{f}(s' \mid s,a)}{f(s' \mid s,a)} \\
&\ge \int p_z(s,a)\, D_{\mathrm{KL}}\big(p \,\|\, \hat{f}\big).
\end{aligned}$$
Proof. To simplify the notation, let $\mathcal{G}_z^*$ be the true LCG on $\mathcal{E}_z$ for all $z$, i.e., $\mathcal{G}_z^* := \mathcal{G}_{\mathcal{E}_z}$ for brevity. It is enough to show that $S(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^{K}) > S(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^{K})$ if $\mathcal{G}_z \neq \mathcal{G}_z^*$ for some $z$.
Now, by Lemma 2,

$$\begin{aligned}
& S(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^{K}) - S(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^{K}) && (24) \\
&= \inf_{\phi} D_{\mathrm{KL}}\big(p \,\|\, \hat{p}_{\{\mathcal{G}_z,\mathcal{E}_z\},\phi}\big) + \lambda \sum_{z} p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (25) \\
&= \inf_{\phi} \int p(s,a)\, D_{\mathrm{KL}}\big(p(\cdot \mid s,a) \,\|\, \hat{p}_{\{\mathcal{G}_z,\mathcal{E}_z\},\phi}(\cdot \mid s,a)\big) + \lambda \sum_{z} p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (26) \\
&= \inf_{\phi} \sum_{z} \int_{(s,a) \in \mathcal{E}_z} p(s,a)\, D_{\mathrm{KL}}\big(p(\cdot \mid s,a) \,\|\, \hat{p}_{\mathcal{G}_z,\phi_z}(\cdot \mid s,a)\big) + \lambda \sum_{z} p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (27) \\
&= \sum_{z} p(\mathcal{E}_z) \inf_{\phi_z} \int p_z(s,a)\, D_{\mathrm{KL}}\big(p(\cdot \mid s,a) \,\|\, \hat{p}_{\mathcal{G}_z,\phi_z}(\cdot \mid s,a)\big) + \lambda \sum_{z} p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (28) \\
&= \sum_{z} p(\mathcal{E}_z) \inf_{\phi_z} D_{\mathrm{KL}}\big(p_z \,\|\, \hat{p}_{\mathcal{G}_z,\phi_z}\big) + \lambda \sum_{z} p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (29) \\
&= \sum_{z} p(\mathcal{E}_z) \Big[ \inf_{\phi_z} D_{\mathrm{KL}}\big(p_z \,\|\, \hat{p}_{\mathcal{G}_z,\phi_z}\big) + \lambda \big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) \Big] = \sum_{z} p(\mathcal{E}_z) \cdot A_z. && (30)
\end{aligned}$$

For brevity, we denote $D_{\mathrm{KL}}(p_z \,\|\, \hat{p}_{\mathcal{G}_z,\phi_z}) := \int p_z(s,a)\, D_{\mathrm{KL}}\big(p(\cdot \mid s,a) \,\|\, \hat{p}_{\mathcal{G}_z,\phi_z}(\cdot \mid s,a)\big)$ and $A_z := \inf_{\phi_z} D_{\mathrm{KL}}(p_z \,\|\, \hat{p}_{\mathcal{G}_z,\phi_z}) + \lambda(|\mathcal{G}_z| - |\mathcal{G}_z^*|)$. Now, we will show that for all $z \in [K]$, $A_z > 0$ if and only if $\mathcal{G}_z^* \neq \mathcal{G}_z$.

Case 0: $\mathcal{G}_z^* = \mathcal{G}_z$. Clearly, $A_z = 0$ in this case.

Case 1: $\mathcal{G}_z^* \subsetneq \mathcal{G}_z$. Then $|\mathcal{G}_z| > |\mathcal{G}_z^*|$, and thus $A_z > 0$ since $\lambda(|\mathcal{G}_z| - |\mathcal{G}_z^*|) > 0$.

Case 2: $\mathcal{G}_z^* \not\subseteq \mathcal{G}_z$. In this case, there exists $(i \to j) \in \mathcal{G}_z^*$ such that $(i \to j) \notin \mathcal{G}_z$. Thus, $S'_j \perp\!\!\!\perp_{\mathcal{G}_z} X_i \mid X \setminus \{X_i\}$ and $S'_j \not\perp\!\!\!\perp_{\mathcal{G}_z^*} X_i \mid X \setminus \{X_i\}$. Therefore, $S'_j \not\perp\!\!\!\perp_p X_i \mid X \setminus \{X_i\}, \mathcal{E}_z$ by Assumption 2. Thus, $p \notin \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$ and we have $\inf_{\phi_z} D_{\mathrm{KL}}(p_z \,\|\, \hat{p}_{\mathcal{G}_z,\phi_z}) > 0$ by Lemma 3.

Now, we consider two subcases: (i) $\mathcal{G}_z \in \mathbb{G}_z^+ := \{\mathcal{G}' \mid \mathcal{G}_z^* \not\subseteq \mathcal{G}', |\mathcal{G}'| \ge |\mathcal{G}_z^*|\}$, and (ii) $\mathcal{G}_z \in \mathbb{G}_z^- := \{\mathcal{G}' \mid \mathcal{G}_z^* \not\subseteq \mathcal{G}', |\mathcal{G}'| < |\mathcal{G}_z^*|\}$. Clearly, if $\mathcal{G}_z \in \mathbb{G}_z^+$ then $A_z > 0$. Suppose $\mathcal{G}_z \in \mathbb{G}_z^-$. Then,

$$\lambda \le \eta_z := \frac{1}{N(N+M)+1} \min_{\mathcal{G}' \in \mathbb{G}_z^-} \inf_{\phi_z} D_{\mathrm{KL}}\big(p_z \,\|\, \hat{p}_{\mathcal{G}',\phi_z}\big) \tag{31}$$

$$\implies \lambda \le \frac{\inf_{\phi_z} D_{\mathrm{KL}}(p_z \,\|\, \hat{p}_{\mathcal{G}',\phi_z})}{N(N+M)+1} < \frac{\inf_{\phi_z} D_{\mathrm{KL}}(p_z \,\|\, \hat{p}_{\mathcal{G}',\phi_z})}{|\mathcal{G}_z^*| - |\mathcal{G}'|} \quad \text{for all } \mathcal{G}' \in \mathbb{G}_z^- \tag{32}$$

$$\implies \inf_{\phi_z} D_{\mathrm{KL}}\big(p_z \,\|\, \hat{p}_{\mathcal{G}',\phi_z}\big) + \lambda\big(|\mathcal{G}'| - |\mathcal{G}_z^*|\big) > 0 \quad \text{for all } \mathcal{G}' \in \mathbb{G}_z^-. \tag{33}$$

Here, we use the fact that $|\mathcal{G}_z^*| - |\mathcal{G}'| \le |\mathcal{G}_z^*| < N(N+M)+1$. Therefore, for any $0 < \lambda \le \eta_z$, we have $A_z > 0$ if $\mathcal{G}_z^* \neq \mathcal{G}_z$. Here, we note that $\eta_z > 0$ for all $z$, since $\mathbb{G}_z^-$ is finite and $\inf_{\phi_z} D_{\mathrm{KL}}(p_z \,\|\, \hat{p}_{\mathcal{G}',\phi_z}) > 0$ for any $\mathcal{G}' \in \mathbb{G}_z^-$ by Lemma 3.
Recall that Thm. 1 holds for $0 < \lambda \le \eta(\{\mathcal{E}_z\})$. Here, $\eta(\{\mathcal{E}_z\})$ is the value corresponding to the specific decomposition $\{\mathcal{E}_z\}$. For the arguments henceforth, we consider an arbitrary decomposition and thus introduce the following assumption.

Assumption 5. $\inf_{\{\mathcal{E}_z\} \in \mathbb{T}} \eta(\{\mathcal{E}_z\}) > 0$.

Note that $\eta(\{\mathcal{E}_z\}) > 0$ for any $\{\mathcal{E}_z\}$, and thus $\inf_{\{\mathcal{E}_z\} \in \mathbb{T}} \eta(\{\mathcal{E}_z\}) \ge 0$. We now take $0 < \lambda \le \inf_{\{\mathcal{E}_z\} \in \mathbb{T}} \eta(\{\mathcal{E}_z\})$ with Assumption 5, which allows Thm. 1 to hold on any arbitrary decomposition. It is worth noting that this assumption is purely technical because, for a small fixed $\lambda > 0$, the arguments henceforth hold for all $\{\mathcal{E}_z\} \in \mathbb{T}_\lambda$, where $\mathbb{T}_\lambda \to \mathbb{T}$ as $\lambda \to 0$.
Proof. Let $0 < \lambda \le \inf_{\{\mathcal{E}_z\} \in \mathbb{T}} \eta(\{\mathcal{E}_z\})$. (i) First, $\{\mathcal{G}_z^*, \mathcal{E}_z^*\}_{z=1}^{K} \in \operatorname{argmax}_{\{\mathcal{G}_z, \mathcal{E}_z\}} S(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^{K})$ implies that $\{\mathcal{G}_z^*\}_{z=1}^{K}$ also maximizes the score on the fixed $\{\mathcal{E}_z^*\}_{z=1}^{K}$, i.e., $\{\mathcal{G}_z^*\}_{z=1}^{K} \in \operatorname{argmax}_{\{\mathcal{G}_z\}} S(\{\mathcal{G}_z, \mathcal{E}_z^*\}_{z=1}^{K})$. Thus, each $\mathcal{G}_z^*$ is the true LCG on $\mathcal{E}_z^*$ by Thm. 1, i.e., $\mathcal{G}_z^* = \mathcal{G}_{\mathcal{E}_z^*}$.

(ii) Also, since $\{\mathcal{E}_z\}_{z=1}^{K}$ is an arbitrary decomposition, $S(\{\mathcal{G}_z^*, \mathcal{E}_z^*\}) \ge S(\{\mathcal{G}_z, \mathcal{E}_z\})$ holds. Since $\{\mathcal{G}_z\}$ are the true LCGs on each $\mathcal{E}_z$, i.e., $\mathcal{G}_z = \mathcal{G}_{\mathcal{E}_z}$, by Lemma 1,

$$S(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^{K}) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \sum_{z} p(\mathcal{E}_z) \cdot |\mathcal{G}_z|. \tag{34}$$

Therefore, $0 \le S(\{\mathcal{G}_z^*, \mathcal{E}_z^*\}) - S(\{\mathcal{G}_z, \mathcal{E}_z\}) = \lambda\big(\mathbb{E}\big[|\mathcal{G}_z|\big] - \mathbb{E}\big[|\mathcal{G}_z^*|\big]\big)$ holds, and thus $\mathbb{E}\big[|\mathcal{G}_z^*|\big] \le \mathbb{E}\big[|\mathcal{G}_z|\big]$.

$$0 \le \mathbb{E}\big[|\mathcal{G}_z|\big] - \mathbb{E}\big[|\mathcal{G}_z^*|\big] = \sum_i p(\mathcal{F}_i)|\mathcal{G}_i| - \sum_j p(\mathcal{E}_j^*)|\mathcal{G}_j^*| = \sum_{i,j} p(\mathcal{F}_i \cap \mathcal{E}_j^*)\big(|\mathcal{G}_i| - |\mathcal{G}_j^*|\big). \tag{36}$$

Suppose $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$ for some $i, j$, and let $\mathcal{C}_{ij} := \mathcal{F}_i \cap \mathcal{E}_j^*$. Since $\mathcal{F}_i \subset \mathcal{D}_m$ for some $m$ and $\mathcal{D}_m$ is canonical, $\mathcal{F}_i$ is also canonical. Therefore, $\mathcal{G}_i = \mathcal{G}_{\mathcal{C}_{ij}}$ since $\mathcal{C}_{ij} \subset \mathcal{F}_i$. Since $\mathcal{C}_{ij} \subset \mathcal{E}_j^*$, we have $\mathcal{G}_{\mathcal{C}_{ij}} \subseteq \mathcal{G}_j^*$ by Lemma 5, and hence $\mathcal{G}_i \subseteq \mathcal{G}_j^*$. Therefore, $|\mathcal{G}_i| - |\mathcal{G}_j^*| \le 0$ for any $i, j$ such that $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$. Thus, by Eq. (36), $|\mathcal{G}_i| = |\mathcal{G}_j^*|$ if $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$. Since $\mathcal{G}_i \subseteq \mathcal{G}_j^*$ if $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$, we conclude that

$$\mathcal{G}_i = \mathcal{G}_j^* \quad \text{if } p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0. \tag{37}$$

Now, for an arbitrary $\mathcal{E}_j^*$, suppose there exist $s \neq t$ such that $p(\mathcal{D}_s \cap \mathcal{E}_j^*) > 0$ and $p(\mathcal{D}_t \cap \mathcal{E}_j^*) > 0$. Then, there exist some $\mathcal{F}_i \subset \mathcal{D}_s$ and $\mathcal{F}_k \subset \mathcal{D}_t$ such that $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$ and $p(\mathcal{F}_k \cap \mathcal{E}_j^*) > 0$. By Eq. (37), we have $\mathcal{G}_i = \mathcal{G}_j^* = \mathcal{G}_k$. Also, $\mathcal{G}_{\mathcal{D}_s} = \mathcal{G}_i$ and $\mathcal{G}_{\mathcal{D}_t} = \mathcal{G}_k$ since $\mathcal{D}_s, \mathcal{D}_t$ are canonical. Therefore, we have $\mathcal{G}_{\mathcal{D}_s} = \mathcal{G}_{\mathcal{D}_t}$, which contradicts that $\mathcal{G}_{\mathcal{D}_m}$ is distinct
19
Fine-Grained Causal Dynamics Learning
Figure 8. Illustration of the Chemical (full-fork) environment with 4 nodes. (Left) The color of the root node determines the activation of the local causal graph fork. (Right) The noisy nodes are redundant for predicting the colors of other nodes under the local causal graph.
for all $m$. Therefore, for any $\mathcal{E}_j^*$, there exists a unique $\mathcal{D}_m$ such that $p(\mathcal{D}_m \cap \mathcal{E}_j^*) > 0$, which leads to $p(\mathcal{E}_j^* \setminus \mathcal{D}_m) = 0$ since $\{\mathcal{D}_m\}_{m \in [H]}$ is a decomposition of $\mathcal{X}$. Let $I_m = \{j \in [K] \mid p(\mathcal{D}_m \cap \mathcal{E}_j^*) > 0\}$. Here, we have

$$p\Big(\bigcup_{z \in I_m} \mathcal{E}_z^* \setminus \mathcal{D}_m\Big) = \sum_{z \in I_m} p\big(\mathcal{E}_z^* \setminus \mathcal{D}_m\big) = 0. \tag{38}$$
C.1.1. Chemical
Here, we describe two settings, namely full-fork and full-chain, modified from Ke et al. (2021). In both settings, there are 10
state variables representing the color of corresponding nodes, with each color represented as a one-hot encoding. The action
variable is a 50-dimensional categorical variable that changes the color of a specific node to a new color (e.g., changing
the color of the third node to blue). According to the underlying causal graph and pre-defined conditional probability
distributions, implemented with randomly initialized neural networks, an action changes the colors of the intervened object’s
descendants, as depicted in Fig. 8. As shown in Fig. 3(a), the (global) causal graph is full in both settings, and the LCG is fork and chain, respectively. For example, in full-fork, the LCG fork is activated according to the particular color of the root node, as shown in Fig. 8.
In both settings, the task is to match the colors of each node to the given target. The reward function is defined as:

$$r = \frac{1}{|O|} \sum_{i \in O} \mathbb{1}\left[s_i = g_i\right], \tag{40}$$
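Eq. (40) can be transcribed directly (variable names are ours):

```python
def chemical_reward(colors, goals, observable):
    """Reward of Eq. (40): fraction of observable nodes matching the target color."""
    return sum(colors[i] == goals[i] for i in observable) / len(observable)

# e.g., 2 of 4 observable nodes match the goal -> r = 0.5
r = chemical_reward([0, 1, 2, 3], [0, 1, 5, 6], observable=[0, 1, 2, 3])
```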
where O is a set of the indices of observable nodes, si is the current color of the i-th node, and gi is the target color of the
i-th node in this episode. Success is determined if all colors of observable nodes are the same as the target. During training,
all 10 nodes are observable, i.e., O = {0, · · · , 9}. In downstream tasks, the root color is set to induce the LCG, and the
agent receives noisy observations for a subset of nodes, aiming to match the colors of the rest of the observable nodes. As
shown in Fig. 8, noisy nodes are spurious for predicting the colors of other nodes under the LCG. Thus, an agent capable of reasoning about fine-grained causal relationships would generalize well in downstream tasks. Note that the transition dynamics of the environment are the same in training and in downstream tasks. To create noisy observations, we use noise sampled from $\mathcal{N}(0, \sigma^2)$, similar to Wang et al. (2022), where the noise is multiplied by the one-hot encoding representing color during the test. In our experiments, we use $\sigma = 100$.
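The noisy-observation construction can be sketched as follows (a minimal NumPy illustration; the function name is ours). Because the Gaussian noise multiplies the one-hot vector elementwise, only the active color entry is corrupted.

```python
import numpy as np

rng = np.random.default_rng(0)

def noisy_one_hot(color_idx, num_colors, sigma=100.0):
    """Corrupt a one-hot color encoding by elementwise multiplicative
    Gaussian noise, as described for the test-time observations (sketch)."""
    one_hot = np.zeros(num_colors)
    one_hot[color_idx] = 1.0
    return one_hot * rng.normal(0.0, sigma, size=num_colors)

obs = noisy_one_hot(2, 5)
```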
As the root color determines the local causal graph in both settings, the root node is always observable to the agent during
the test. The root colors of the initial state and the goal state are the same, inducing the local causal graph. As the root color
can be changed by the action during the test, this may pose a challenge in evaluating the agent’s reasoning of local causal
relationships. This can be addressed by modifying the initial distribution of CEM to exclude the action on the root node and
only act on the other nodes during the test. Nevertheless, we observe that restricting the action on the root during the test
has little impact on the behavior of any model, and we find that this is because the agent rarely changes the root color as it
already matches the goal color in the initial state.
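The restriction of the CEM initial distribution described above can be illustrated with a minimal cross-entropy method sketch. All names and hyperparameters here are ours, not the planner configuration used in the experiments.

```python
import numpy as np

rng = np.random.default_rng(0)

def cem_plan(cost_fn, dim, forbidden_mask, iters=20, pop=200, elite=20):
    """Minimal cross-entropy method planner. `forbidden_mask` zeroes the
    sampling distribution on restricted action dimensions (e.g., actions on
    the root node), mirroring the modification described above. Sketch only."""
    mu, std = np.zeros(dim), np.ones(dim)
    for _ in range(iters):
        samples = rng.normal(mu, std, size=(pop, dim))
        samples[:, forbidden_mask] = 0.0          # exclude restricted actions
        costs = np.array([cost_fn(s) for s in samples])
        elites = samples[np.argsort(costs)[:elite]]
        mu, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mu

# Toy cost: reach a target on the unrestricted dimensions.
target = np.array([0.0, 1.5, -0.5])
mask = np.array([True, False, False])             # dimension 0 is restricted
plan = cem_plan(lambda a: ((a - target) ** 2).sum(), 3, mask)
```

The masked dimension stays exactly zero throughout, so the planner can never act on it, which is the intended effect of excluding the root-node action from the initial distribution.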
C.1.2. Magnetic
In this environment, there are two objects on a table, a moving ball and a box, each colored either red or black, as shown in Fig. 3(b). The red color indicates that the object is magnetic. In other words, when both are colored red, a magnetic force is applied and the ball moves toward the box. If either object is colored black, the ball does not move, since the box has no influence on the ball.
The state consists of the color and (x, y) position of each object, and the (x, y, z) position of the end-effector of the robot arm, where the color is given as a 2-dimensional one-hot encoding. The action is a 3-dimensional vector that moves the robot arm.
The causal graph of the Magnetic environment is shown in Fig. 9(a). LCGs under magnetic and non-magnetic context are
shown in Figs. 9(b) and 9(c), respectively. The table in our setup has a width of 0.9 and a length of 0.6, with the y-axis
defined by the width and the x-axis defined by the length. For each episode, the initial positions of a moving ball and a box
are randomly sampled within the range of the table.
The task is to move the robot arm to reach the moving ball. Thus, accurately predicting the trajectory of the ball is crucial.
The reward function is defined as:
r = 1 − tanh(5 · ∥eef − g∥1 ), (41)
Fine-Grained Causal Dynamics Learning
where eef ∈ R3 is the current position of the end-effector, g = (bx, by, 0.8) ∈ R3, and (bx, by) is the current position of the moving ball. Success is determined if the distance is smaller than 0.05. During the test, the color of one of the objects is black and the box is located at a position unseen during training. Specifically, the box position is sampled from N(0, σ2) during the test. Note that the box can then be located outside of the table, which never happens during training. In our experiments, we use σ = 100.
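The reward in Eq. (41) and the success criterion can be sketched as below; this is an illustrative reimplementation under our reading of the text (we assume the success distance is the same L1 distance used in the reward), with hypothetical function and variable names:

```python
import numpy as np

def magnetic_reward(eef, ball_xy, success_thresh=0.05):
    """Reward for the reaching task, following Eq. (41).

    `eef` is the 3-D end-effector position; the goal g sits directly
    above the ball at height 0.8. Returns (reward, success flag).
    """
    g = np.array([ball_xy[0], ball_xy[1], 0.8])
    dist = np.abs(eef - g).sum()          # L1 distance ||eef - g||_1
    reward = 1.0 - np.tanh(5.0 * dist)    # r = 1 - tanh(5 * dist)
    return reward, dist < success_thresh
```

The tanh shaping saturates quickly, so accurate prediction of the ball's trajectory matters most when the end-effector is already close to the goal.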
[Figure 10 schematic: (a) Local Causal Graph Inference, with an auxiliary network (top) and an encoder, decoder, and dynamics model (bottom), each mapping (s, a) to s′; (b) Masked Prediction with LCG.]
Figure 10. Comparison of the sample-specific inference of NCD (top) and quantization-based inference of our method (bottom).
C.4.2. BACKPROPAGATION
We now describe how each component of our method is updated by the training objective in Eq. (6). First, Lpred updates the encoder genc(s, a), the decoder gdec(e), and the dynamics model p̂. Recall that A ∼ gdec(e); thus, backpropagation from A in Lpred updates the quantization decoder gdec through e. During the backward pass in Eq. (5), gradients are copied from e (the input of gdec) to h (the output of genc), following VQ-VAE (Van Den Oord et al., 2017). By doing so, Lpred also updates the quantization encoder genc and h. Second, Lquant updates genc and the codebook C. We note that Lpred also affects the learning of the codebook C, since h is updated with Lpred. The rationale behind this trick of VQ-VAE is that the gradient ∇e Lpred can guide the encoder genc to change its output h = genc(s, a) to lower the prediction loss Lpred, altering the
Figure 11. (a,b) Codebook histogram on (a) ID states during training and (b) OOD states during the test in Chemical (full-fork). (c) True
causal graph of the fork structure. (d) Learned LCG corresponding to the most used code in (b).
quantization (i.e., the assignment of the cluster) in the next forward pass. A larger prediction loss (which implies that the sample (s, a) is assigned to the wrong cluster) induces a larger change in h and is consequently more likely to cause a re-assignment of the cluster.
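The quantization step and the VQ-VAE objective described above can be sketched as follows. This is a simplified NumPy illustration of the forward pass and loss terms only (the straight-through gradient copy is noted in comments, since it requires an autograd framework); function names are ours:

```python
import numpy as np

def quantize(h, codebook):
    """Nearest-neighbor assignment of the encoder output h = g_enc(s, a)
    to a codebook entry e (forward pass of the vector quantization).

    In the backward pass, following VQ-VAE, the gradient of L_pred w.r.t.
    e is copied to h unchanged (straight-through), so the encoder is also
    trained by the prediction loss.
    """
    dists = np.linalg.norm(codebook - h, axis=1)  # distance to each code
    k = int(np.argmin(dists))                     # chosen cluster index
    return codebook[k], k

def vq_loss(h, e, beta=0.25):
    """L_quant: the codebook term pulls the assigned code e toward h,
    and the commitment term (weighted by beta) pulls h toward e."""
    return np.sum((e - h) ** 2) + beta * np.sum((h - e) ** 2)
```

When a large ∇e Lpred shifts h, the nearest code in `quantize` may change, which is exactly the cluster re-assignment mechanism discussed above.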
C.4.3. H YPERPARAMETERS
For all experiments, we fix the codebook size K = 16, the regularization coefficient λ = 0.001, and the commitment coefficient β = 0.25, as we found that the performance did not vary much for any K > 2, λ ∈ {10−4, 10−3, 10−2}, and β ∈ {0.1, 0.25}.
We observe that the training of the latent codebook with vector quantization is often unstable when K = 2. We demonstrate success (Fig. 17) and failure (Fig. 18) cases of our method with a quantization degree of 2. In the failure case, the embeddings frequently fluctuate between the two codes, resulting in both codes corresponding to the global causal graph and failing to capture the LCG, as shown in Fig. 18.
Figure 16. More fine-grained LCGs learned by our method with a quantization degree of 16 in Magnetic.
D. Additional Discussions
D.1. Difference from Sample-based Inference
Sample-based inference methods, e.g., NCD (Hwang et al., 2023) for LCGs or ACD (Löwe et al., 2022) for CGs, can be seen as learning causal graphs with gated edges. They learn a function that maps each sample to an adjacency matrix in which each entry is a binary variable indicating whether the corresponding edge is on or off under the current state. The critical difference from ours is that the LCGs learned by sample-based inference methods are unbounded and black-box.
Specifically, it is hard to understand which local structures and contexts are identified, since they can only be examined by observing the inference outcomes over all samples (i.e., black-box). Also, there is no practical or theoretical guarantee that
it outputs the same graph for states within the same context, since the output of the function is unbounded. In contrast, our method learns a finite set of LCGs in which the contexts are explicitly identified by latent clustering. In other words, the outcome is bounded (one of the K graphs is inferred) and the contexts are more interpretable.
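The contrast above can be made concrete with a minimal sketch; both functions and their names are hypothetical illustrations of the two inference styles, not the authors' implementations:

```python
import numpy as np

def sample_based_lcg(f, s, a):
    """Sample-based inference (e.g., NCD-style): an arbitrary learned
    function f maps each (s, a) directly to an adjacency matrix, so the
    set of possible outputs is unbounded and hard to enumerate."""
    return f(s, a)

def quantized_lcg(encode, graphs, s, a):
    """Quantization-based inference (ours): (s, a) is first clustered to
    a code index k, and the LCG is one of the K stored graphs.
    `graphs` has shape (K, d, d), so outputs are bounded and can be
    inspected by enumerating the K graphs."""
    k = encode(s, a)
    return graphs[k]
```

With the quantized form, checking which local structures were identified reduces to inspecting the K stored adjacency matrices, rather than evaluating the inference function over all samples.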
For the robustness of the model and a principled understanding of the fine-grained structures, such guarantees and interpretability are crucial, and we demonstrate the improved robustness of our method compared to prior sample-based inference methods. On the other hand, sample-based inference and local edge-switch methods have strengths in their simple design and efficiency, and the signals from such local edge switches are known to enhance exploration in RL (Seitzer et al., 2021; Wang et al., 2023). For practitioners, the choice depends on their purpose, e.g., whether their primary interest lies in robustness and a principled understanding of the fine-grained structures.
[Three panels: episode reward vs. environment steps (×10^5).]
Figure 19. Learning curves during training as measured by the episode reward.
[Three panels: episode reward (top) and success ratio (bottom) vs. environment steps (×10^5).]
Figure 20. Learning curves on downstream tasks in Chemical (full-fork), as measured by the episode reward (top) and success rate (bottom).
[Three panels: episode reward (top) and success ratio (bottom) vs. environment steps (×10^5).]
Figure 21. Learning curves on downstream tasks in Chemical (full-chain), as measured by the episode reward (top) and success rate (bottom).