Fine-Grained Causal Dynamics Learning With Quantization For Improving Robustness in Reinforcement Learning


Inwoo Hwang¹, Yunhyeok Kwak¹, Suhyung Choi¹, Byoung-Tak Zhang¹, Sanghack Lee¹ ²

¹AI Institute, Seoul National University. ²Graduate School of Data Science, Seoul National University. Correspondence to: Sanghack Lee <sanghack@snu.ac.kr>, Byoung-Tak Zhang <btzhang@snu.ac.kr>.

Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).

arXiv:2406.03234v1 [cs.LG] 5 Jun 2024

Abstract

Causal dynamics learning has recently emerged as a promising approach to enhancing robustness in reinforcement learning (RL). Typically, the goal is to build a dynamics model that makes predictions based on the causal relationships among the entities. Although causal connections often manifest only under certain contexts, existing approaches overlook such fine-grained relationships and lack a detailed understanding of the dynamics. In this work, we propose a novel dynamics model that infers fine-grained causal structures and employs them for prediction, leading to improved robustness in RL. The key idea is to jointly learn the dynamics model with a discrete latent variable that quantizes the state-action space into subgroups. This leads to recognizing meaningful contexts that display sparse dependencies, where causal structures are learned for each subgroup throughout the training. Experimental results demonstrate the robustness of our method to unseen states and locally spurious correlations in downstream tasks where fine-grained causal reasoning is crucial. We further illustrate the effectiveness of our subgroup-based approach with quantization in discovering fine-grained causal relationships compared to prior methods.

[Figure 1: panels (a), (b), (c)]
Figure 1. (a) Previous causal dynamics models infer the global causal structure of the transition dynamics. (b) Existing approaches to discovering fine-grained relationships examine individual samples. (c) Our approach quantizes the state-action space into subgroups and infers causal relationships specific to each subgroup.

1. Introduction

Model-based reinforcement learning (MBRL) has showcased its capability of solving various sequential decision-making problems (Kaiser et al., 2020; Schrittwieser et al., 2020). Since learning an accurate and robust dynamics model is crucial in MBRL, recent works incorporate the causal relationships between the environmental variables, such as objects and the agent, into dynamics learning (Wang et al., 2022; Ding et al., 2022). Unlike traditional dense models that employ all state and action variables to predict the future state, causal dynamics models infer the causal structure of the transition dynamics and make predictions based on it. Consequently, they are more robust to unseen states because they discard spurious dependencies.

Our motivation stems from the observation that, in many practical scenarios, causal connections manifest only under certain contexts. Consider autonomous driving, where recognizing the traffic signal is crucial for safety (e.g., stopping at red lights). However, in the presence of a pedestrian on the road, the vehicle must stop even at a green light, ignoring the signal; i.e., the traffic signal becomes locally spurious. Therefore, such fine-grained causal reasoning will be crucial to the robustness of MBRL in real-world deployment.

Fine-grained causal relationships can be understood with local independence between the variables, which holds under certain contexts but does not hold in general (Boutilier et al., 2013). Our goal is to incorporate them into dynamics modeling by capturing meaningful contexts that exhibit sparser dependencies than the entire domain. Unfortunately, prior causal dynamics models examining global independence (Fig. 1-(a)) cannot harness them. On the other hand, existing methods for discovering fine-grained relationships have focused on examining sample-specific dependencies (Pitis et al., 2020; Hwang et al., 2023) (Fig. 1-(b)). However,
it is unclear under which circumstances the inferred dependencies hold, making them hard to interpret and challenging to generalize to unseen states.

In this work, we propose a dynamics model that infers fine-grained causal structures and employs them for prediction, leading to improved robustness in MBRL. For this, we establish a principled way to examine fine-grained causal relationships based on the quantization of the state-action space. Importantly, this provides a clear understanding of meaningful contexts displaying sparse dependencies (Fig. 1-(c)). However, this involves optimizing the regularized maximum likelihood score over the quantization, which is generally intractable. To this end, we present a practical differentiable method that jointly learns the dynamics model and a discrete latent variable that decomposes the state-action space into subgroups by utilizing vector quantization (Van Den Oord et al., 2017). Theoretically, we show that joint optimization leads to identifying meaningful contexts and fine-grained causal structures.

We evaluate our method on both discrete and continuous control environments where fine-grained causal reasoning is crucial. Experimental results demonstrate the superior robustness of our approach to locally spurious correlations and unseen states in downstream tasks compared to prior causal/non-causal approaches. Finally, we illustrate that our method infers fine-grained relationships in a more effective and robust manner compared to sample-specific approaches.

Our contributions are summarized as follows.

• We establish a principled way to examine fine-grained causal relationships based on the quantization of the state-action space, which offers an identifiability guarantee and better interpretability.
• We present a theoretically grounded and practical approach to dynamics learning that infers fine-grained causal relationships by utilizing vector quantization.
• We empirically demonstrate that an agent capable of fine-grained causal reasoning is more robust to locally spurious correlations and generalizes well to unseen states compared to past causal/non-causal approaches.

2. Preliminaries

We first introduce the notations and terminologies. Then, we examine related works on causal dynamics learning for RL and fine-grained causal relationships.

2.1. Background

Structural causal model. We adopt the framework of a structural causal model (SCM) (Pearl, 2009) to understand the relationships among variables. An SCM $\mathcal{M}$ is defined as a tuple $\langle V, U, F, P(U) \rangle$, where $V = \{X_1, \dots, X_d\}$ is a set of endogenous variables and $U$ is a set of exogenous variables. A set of functions $F = \{f_1, \dots, f_d\}$ determines how each variable is generated: $X_j = f_j(Pa(j), U_j)$, where $Pa(j) \subseteq V \setminus \{X_j\}$ is the set of parents of $X_j$ and $U_j \subseteq U$. An SCM $\mathcal{M}$ induces a directed acyclic graph (DAG) $\mathcal{G} = (V, E)$, i.e., a causal graph (CG) (Peters et al., 2017), where $V = \{1, \dots, d\}$ and $E \subseteq V \times V$ are the sets of nodes and edges, respectively. Each edge $(i, j) \in E$ denotes a direct causal relationship from $X_i$ to $X_j$. An SCM entails the conditional independence relationship of each variable (namely, the local Markov property): $X_i \perp\!\!\!\perp ND(X_i) \mid Pa(X_i)$, where $ND(X_i)$ denotes the non-descendants of $X_i$, which can be read off from the corresponding causal graph.

Factored Markov Decision Process. A Markov Decision Process (MDP) (Sutton & Barto, 2018) is defined as a tuple $\langle \mathcal{S}, \mathcal{A}, T, r, \gamma \rangle$, where $\mathcal{S}$ is a state space, $\mathcal{A}$ is an action space, $T: \mathcal{S} \times \mathcal{A} \to \mathcal{P}(\mathcal{S})$ is the transition dynamics, $r$ is a reward function, and $\gamma$ is a discount factor. We consider a factored MDP (Kearns & Koller, 1999) where the state and action spaces are factorized as $\mathcal{S} = \mathcal{S}_1 \times \dots \times \mathcal{S}_N$ and $\mathcal{A} = \mathcal{A}_1 \times \dots \times \mathcal{A}_M$. The transition dynamics is factorized as $p(s' \mid s, a) = \prod_j p(s'_j \mid s, a)$, where $s = (s_1, \dots, s_N)$ and $a = (a_1, \dots, a_M)$.

Assumptions and notations. We are concerned with an SCM associated with the transition dynamics in a factored MDP where the states are fully observable. To properly identify the causal relationships, we make the assumptions standard in the field, namely, the Markov property (Pearl, 2009), faithfulness (Peters et al., 2017), causal sufficiency (Spirtes et al., 2000), and that causal connections only appear within consecutive time steps. With these assumptions, we consider a bipartite causal graph $\mathcal{G} = (V, E)$, which consists of the set of nodes $V = \mathcal{X} \cup \mathcal{Y}$ and the set of edges $E \subseteq \mathcal{X} \times \mathcal{Y}$, where $\mathcal{X} = \{S_1, \dots, S_N, A_1, \dots, A_M\}$ and $\mathcal{Y} = \{S'_1, \dots, S'_N\}$. $Pa(j)$ denotes the parent variables of $S'_j$. The conditional independence

$$S'_j \perp\!\!\!\perp \mathcal{X} \setminus Pa(j) \mid Pa(j), \tag{1}$$

entailed by the causal graph $\mathcal{G}$ represents the causal structure of the transition dynamics $p(s' \mid s, a) = \prod_j p(s'_j \mid Pa(j))$.

Dynamics modeling. The traditional way is to use dense dependencies for dynamics modeling: $\prod_j p(s'_j \mid s, a)$. Causal dynamics models (Wang et al., 2021; 2022; Ding et al., 2022) examine the causal structure $\mathcal{G}$ to employ only relevant dependencies: $p(s' \mid s, a; \mathcal{G}) = \prod_j p(s'_j \mid Pa(j))$ (Fig. 1-(a)). Consequently, they are more robust to spurious correlations and unseen states.

2.2. Related Work

Causal dynamics models in RL. There is a growing body of literature on the intersection of causality and RL

(De Haan et al., 2019; Buesing et al., 2019; Zhang et al., 2020a; Sontakke et al., 2021; Schölkopf et al., 2021; Zholus et al., 2022; Zhang et al., 2020b). One focus is dynamics learning, which involves the causal structure of the transition dynamics (Li et al., 2020; Yao et al., 2022; Bongers et al., 2018; Wang et al., 2022; Ding et al., 2022; Feng et al., 2022; Huang et al., 2022) (broader literature on causal reasoning in RL is discussed in Appendix A.1). Recent works proposed causal dynamics models that make robust predictions based on the causal dependencies (Fig. 1-(a)), utilizing conditional independence tests (Ding et al., 2022) or conditional mutual information (Wang et al., 2022) to infer the causal graph in a factored MDP. However, prior methods cannot harness fine-grained causal relationships that provide a more detailed understanding of the dynamics. In contrast, our work aims to discover and incorporate them into dynamics modeling, demonstrating that fine-grained causal reasoning leads to improved robustness in MBRL.

Discovering fine-grained causal relationships. In the context of RL, a fine-grained structure of the environment dynamics has been leveraged in various ways, e.g., for data augmentation (Pitis et al., 2022), efficient planning (Hoey et al., 1999; Chitnis et al., 2021), or exploration (Seitzer et al., 2021). For this, previous works often exploited domain knowledge (Pitis et al., 2022) or the true dynamics model (Chitnis et al., 2021). However, such prior knowledge is often unavailable in the context of dynamics learning. Existing methods for discovering fine-grained relationships examine the gradient (Wang et al., 2023) or attention score (Pitis et al., 2020) of each sample (Fig. 1-(b)). However, such sample-specific approaches lack an understanding of under which circumstances the inferred dependencies hold, and it is unclear whether they can generalize to unseen states.

In the field of causality, fine-grained causal relationships have been widely studied, e.g., context-specific independence (Boutilier et al., 2013; Zhang & Poole, 1999; Poole, 1998; Dal et al., 2018; Tikka et al., 2019; Jamshidi et al., 2023) (see Appendix A.2 for the background). Recently, Hwang et al. (2023) proposed an auxiliary network that examines local independence for each sample. However, it also does not explicitly capture the context where the local independence holds. In contrast to existing approaches relying on sample-specific inference (Löwe et al., 2022; Pitis et al., 2020; Hwang et al., 2023), we propose to examine causal dependencies at a subgroup level through quantization (Fig. 1-(c)), providing a more robust and principled way of discovering fine-grained causal relationships with a theoretical guarantee.

3. Fine-Grained Causal Dynamics Learning

In this section, we first describe a brief background on local independence and the intuition of our approach (Sec. 3.1). We then provide a principled way to examine fine-grained causal relationships (Sec. 3.2). Based on this, we propose a theoretically grounded and practical method for fine-grained causal dynamics modeling (Sec. 3.3). Finally, we provide a theoretical analysis with discussions (Sec. 3.4). All omitted proofs are provided in Appendix B.

3.1. Preliminary

Analogous to the conditional independence explaining the causal relationships between the variables (i.e., Eq. (1)), their fine-grained relationships can be understood with local independence (Hwang et al., 2023). This is written as:

$$S'_j \perp\!\!\!\perp \mathcal{X} \setminus Pa(j; \mathcal{D}) \mid Pa(j; \mathcal{D}), \mathcal{D}, \tag{2}$$

where $\mathcal{D} \subseteq \mathcal{X} = \mathcal{S} \times \mathcal{A}$ is a local subset of the joint state-action space, which we call a context, and $Pa(j; \mathcal{D}) \subseteq \mathcal{X}$ is the set of state and action variables locally relevant for predicting $S'_j$ under $\mathcal{D}$. We provide a formal definition and detailed background of local independence in Appendix A.3. For example, consider a mobile home robot interacting with various objects ($Pa(j)$). Under the context of the door being closed ($\mathcal{D}$), only objects within the same room ($Pa(j; \mathcal{D})$) are relevant. On the other hand, all objects remain relevant under the context of the door being open. We say that a context is meaningful if it displays sparse dependencies: $Pa(j; \mathcal{D}) \subsetneq Pa(j)$, e.g., door closed. We are concerned with the subgraph of the (global) causal graph $\mathcal{G}$ as a graphical representation of such local dependencies.

Definition 1. The local subgraph of the causal graph (LCG) on $\mathcal{D} \subseteq \mathcal{X}$ is $\mathcal{G}_{\mathcal{D}} = (V, E_{\mathcal{D}})$, where $E_{\mathcal{D}} = \{(i, j) \mid i \in Pa(j; \mathcal{D})\}$. (Footnote 1: For brevity, we will henceforth refer to it as the local causal graph.)

The LCG $\mathcal{G}_{\mathcal{D}}$ represents the causal structure of the transition dynamics specific to a certain context $\mathcal{D}$. It is useful for our approach to fine-grained dynamics modeling; e.g., it is sufficient to consider only objects in the same room when the door is closed. In contrast, prior causal dynamics models consider all objects under any circumstances (Fig. 1-(a)). Importantly, such information (e.g., $\mathcal{D}$ and $\mathcal{G}_{\mathcal{D}}$) is not known in advance, and it is our goal to discover it. For this, existing sample-specific approaches have focused on inferring the LCG directly from individual samples (Pitis et al., 2020; Hwang et al., 2023) (Fig. 1-(b)). However, it is unclear under which context the inferred dependencies hold.

Our approach is to quantize the state-action space into subgroups and examine causal structures on each subgroup (Fig. 1-(c)). This makes it clear that each inferred LCG represents fine-grained causal relationships under the corresponding subgroup. We now proceed to describe a principled way to discover LCGs based on quantization.
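As an illustration of Eq. (2) and Definition 1 (a toy construction of ours, not taken from the paper), the sketch below encodes a deterministic version of the door example with three hypothetical binary inputs and brute-forces the local parent set $Pa(j; \mathcal{D})$ of one next-state variable. A variable is counted as a local parent if changing it, with the others held fixed, changes $s'_j$ for some pair of assignments that both lie in the context. This single-variable probe is only adequate for deterministic mechanisms and for contexts that fix some variables (as the door contexts do); it is not a general independence test.

```python
from itertools import product

# Hypothetical binary inputs for one next-state variable s'_j (door example):
# door: 0 = closed, 1 = open; obj_same / obj_other: objects in the same / other room.
VARIABLES = ("door", "obj_same", "obj_other")


def next_state(door, obj_same, obj_other):
    """Toy mechanism: the other room's object matters only when the door is open."""
    if door == 1:
        return obj_same + obj_other
    return obj_same


def local_parents(context):
    """Brute-force Pa(j; D) for the toy mechanism.

    `context` is a predicate over assignments and plays the role of D.
    """
    assignments = [dict(zip(VARIABLES, values))
                   for values in product((0, 1), repeat=len(VARIABLES))]
    parents = set()
    for x in assignments:
        if not context(x):
            continue
        for var in VARIABLES:
            y = dict(x, **{var: 1 - x[var]})  # flip one variable, keep the rest
            if context(y) and next_state(**x) != next_state(**y):
                parents.add(var)
    return parents


def lcg_edges(context, target="s'_j"):
    """Edge set E_D = {(i, j) | i in Pa(j; D)} from Definition 1."""
    return {(parent, target) for parent in local_parents(context)}


def everywhere(x):   # D = X: recovers the global parent set Pa(j)
    return True


def door_closed(x):
    return x["door"] == 0


def door_open(x):
    return x["door"] == 1


print(sorted(local_parents(everywhere)))   # ['door', 'obj_other', 'obj_same']
print(sorted(local_parents(door_closed)))  # ['obj_same']
print(sorted(local_parents(door_open)))    # ['obj_other', 'obj_same']
```

In this toy, "door closed" is a meaningful context in the sense above, since its local parent set is a strict subset of the global one (`local_parents(door_closed) < local_parents(everywhere)` is `True`), whereas "door open" keeps both objects relevant.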

[Figure 2: (a) Local Causal Graph Inference; (b) Masked Prediction with LCG. Diagram components: Encoder, Quantization, Codebook, Decoder, Local causal graphs, Dynamics Model; inputs (s, a), output s′; legend: state or action variables, masked variables.]

Figure 2. Overall framework. (a) For each sample (s, a), our method determines the subgroup to which the sample belongs through quantization and infers the local causal graph (LCG) that represents fine-grained causal relationships specific to the corresponding subgroup. (b) The dynamics model predicts the future state based on the inferred LCG. All components (e.g., dynamics model and codebook) are jointly learned throughout the training in an end-to-end manner.
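The pipeline in Fig. 2, together with the value of the objective in Eq. (6), can be sketched in dependency-free Python as follows. This is a hedged illustration under our own assumptions, not the authors' implementation: the encoder is a hypothetical linear map, the decoder is replaced by a fixed table of Bernoulli logits per code (`edge_logits`), per-variable features are given directly, and the log-likelihood of the masked dynamics model is passed in as a number. Sampling uses the two-class form of the Gumbel-Softmax relaxation (sometimes called Gumbel-sigmoid or binary Concrete). Without an autodiff framework only forward values are computed, so the stop-gradient operator has no visible effect here.

```python
import math
import random


def encode(x, weight):
    """Hypothetical linear encoder g_enc: h = W x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def quantize(h, codebook):
    """Eq. (5): index z and code e_z of the nearest code (squared L2 distance)."""
    z = min(range(len(codebook)),
            key=lambda j: sum((hi - ci) ** 2 for hi, ci in zip(h, codebook[j])))
    return z, codebook[z]


def relaxed_bernoulli(logit, temperature, rng):
    """Binary Gumbel-Softmax sample in [0, 1]: sigmoid((logit + L) / T), where
    L = log(u) - log(1 - u) is logistic noise (a difference of two Gumbels)."""
    u = rng.uniform(1e-9, 1 - 1e-9)
    noise = math.log(u) - math.log(1 - u)
    return 1.0 / (1.0 + math.exp(-(logit + noise) / temperature))


def sample_adjacency(logits, temperature=0.5, rng=None):
    """A ~ g_dec(e), shape (N + M) x N: rows are inputs, columns next-state vars.

    With an rng, threshold a relaxed sample at 0.5 (the hard value a
    straight-through estimator would use in the forward pass); with
    rng=None, return the Bernoulli mode 1[logit > 0].
    """
    if rng is None:
        return [[1 if logit > 0 else 0 for logit in row] for row in logits]
    return [[1 if relaxed_bernoulli(logit, temperature, rng) > 0.5 else 0
             for logit in row] for row in logits]


def masked_features(features, adjacency, j):
    """Zero the feature vectors of inputs not selected by column A^(j)."""
    return [f if adjacency[i][j] else [0.0] * len(f) for i, f in enumerate(features)]


def total_loss(log_lik, adjacency, h, e, lam, beta):
    """Value of Eq. (6) for one sample; log_lik = log p_hat(s' | s, a; A).

    sg[.] only blocks gradients, so ||sg[h] - e||^2 and ||h - sg[e]||^2 have the
    same value; in an autodiff framework they update the code and the encoder.
    """
    l_pred = -log_lik + lam * sum(sum(row) for row in adjacency)  # ||A||_1
    sq_dist = sum((hi - ei) ** 2 for hi, ei in zip(h, e))
    l_quant = sq_dist + beta * sq_dist
    return l_pred + l_quant


# Toy usage: N = 2 state variables, M = 1 action, K = 2 codes (all values hypothetical).
codebook = [[0.0, 0.0], [1.0, 1.0]]
edge_logits = {0: [[4.0, -4.0], [-4.0, 4.0], [4.0, 4.0]],   # sparse LCG for code 0
               1: [[4.0, 4.0], [4.0, 4.0], [4.0, 4.0]]}     # dense LCG for code 1
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
x = [0.1, 0.2, 0.7]                                          # (s_1, s_2, a_1)

h = encode(x, W)                      # [0.1, 0.2]
z, e = quantize(h, codebook)          # z == 0
A = sample_adjacency(edge_logits[z])  # [[1, 0], [0, 1], [1, 1]]
features = [[xi, 1.0] for xi in x]    # hypothetical per-variable features
inputs_for_s1 = masked_features(features, A, 0)  # s_2's features zeroed
```

During training one would pass a seeded `random.Random` as `rng` so that the graph is sampled rather than taken at its mode. Masking per-variable features rather than raw inputs mirrors the text's remark that feature masking was more stable than input masking.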

3.2. Score for Decomposition and Graphs

Let us consider an arbitrary decomposition $\{\mathcal{E}_z\}_{z=1}^K$ of the state-action space $\mathcal{X}$, where $K$ is the degree of the quantization. The transition dynamics can be decomposed as:

$$\begin{aligned} p(s'_j \mid s, a) &= \sum_z p(s'_j \mid s, a, z)\, p(z \mid s, a) \\ &= \sum_z p(s'_j \mid Pa(j; \mathcal{E}_z), z)\, p(z \mid s, a), \end{aligned} \tag{3}$$

where $p(z \mid s, a) = 1$ if $(s, a) \in \mathcal{E}_z$ (and 0 otherwise). This illustrates our approach to fine-grained dynamics modeling, employing only locally relevant dependencies according to $\mathcal{G}_{\mathcal{E}_z}$ on each subgroup $\mathcal{E}_z$. We now aim to learn each LCG $\mathcal{G}_{\mathcal{E}_z}$ based on Eq. (3). Specifically, we consider the regularized maximum likelihood score $\mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ of the graphs $\{\mathcal{G}_z\}_{z=1}^K$ and the decomposition $\{\mathcal{E}_z\}_{z=1}^K$, which is defined as:

$$\sup_{\phi} \mathbb{E}_{p(s, a, s')}\left[\log \hat{p}(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi) - \lambda |\mathcal{G}_z|\right], \tag{4}$$

where $\phi$ denotes the parameters of the dynamics model $\hat{p}$, which employs the graph $\mathcal{G}_z$ for prediction on the corresponding subgroup $\mathcal{E}_z$. We now show that graphs that maximize the score faithfully represent causal dependencies on each subgroup.

Theorem 1 (Identifiability of LCGs). With Assumptions 1 to 4, let $\{\hat{\mathcal{G}}_z\} \in \arg\max \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ for $\lambda > 0$ small enough. Then, each $\hat{\mathcal{G}}_z$ is the true LCG on $\mathcal{E}_z$, i.e., $\hat{\mathcal{G}}_z = \mathcal{G}_{\mathcal{E}_z}$.

Given the subgroups, the corresponding LCGs can be recovered by score maximization. Therefore, it provides a principled way to discover LCGs, which is valid for any quantization.

Unfortunately, not every quantization is useful for fine-grained dynamics modeling; e.g., if the space is divided into lights on and lights off, the model still needs to consider all objects under both circumstances. Thus, it is crucial for quantization to capture meaningful contexts displaying sparse dependencies. Such useful quantization allows sparser dynamics modeling, i.e., a higher score in Eq. (4). Therefore, the decomposition is now also a learning objective towards maximizing Eq. (4), i.e., $\{\mathcal{G}^*_z, \mathcal{E}^*_z\} \in \arg\max \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$. However, a naive optimization with respect to the decomposition is generally intractable. Thus, we devise a practical method allowing joint training with the dynamics model.

3.3. Fine-Grained Causal Dynamics Learning with Quantization

We propose a practical differentiable method that allows joint optimization of Eq. (4) over the dynamics model $\hat{p}$, the decomposition $\{\mathcal{E}_z\}$, and the graphs $\{\mathcal{G}_z\}$ in an end-to-end manner. The key component is a discrete latent codebook $C = \{e_z\}$, where each code $e_z$ represents the pair of a subgroup $\mathcal{E}_z$ and a graph $\mathcal{G}_z$. The codebook learning is differentiable, and these pairs are learned throughout the training together with the dynamics model. The overall framework is shown in Fig. 2.

Quantization. The encoder $g_{\text{enc}}$ maps each sample $(s, a)$ into a latent embedding $h$, which is then quantized to the nearest prototype vector $e$ (i.e., code) in the codebook $C = \{e_1, \dots, e_K\}$, following Van Den Oord et al. (2017):

$$e = e_z, \quad \text{where } z = \arg\min_{j \in [K]} \|h - e_j\|_2. \tag{5}$$

This induces the subgroups, since each sample corresponds to exactly one of the codes, i.e., each code $e_z$ represents the subgroup $\mathcal{E}_z = \{(s, a) \mid e = e_z\}$. Thus, this corresponds to the term $p(z \mid s, a)$ in Eq. (3). In other words, the codebook $C$ serves as a proxy for the decomposition $\{\mathcal{E}_z\}_{z=1}^K$.

Local causal graphs. The quantized embedding $e$ is then decoded to an adjacency matrix $A \in \{0, 1\}^{(N+M) \times N}$. The output of the decoder $g_{\text{dec}}$ is the parameters of Bernoulli distributions from which the matrix is sampled: $A \sim g_{\text{dec}}(e)$. In other words, each code $e_z$ corresponds to the matrix $A_z$ that represents the graph $\mathcal{G}_z$. To properly backpropagate gradients, we adopt the Gumbel-Softmax reparametrization trick (Jang et al., 2017; Maddison et al., 2017).

Dynamics learning. The dynamics model $\hat{p}$ employs the
matrix $A$ for prediction: $\sum_j \log \hat{p}(s'_j \mid s, a; A^{(j)})$, where $A^{(j)} \in \{0, 1\}^{N+M}$ is the $j$-th column of $A$. Each entry of $A^{(j)}$ indicates whether the corresponding state or action variable will be used to predict the next state $s'_j$. This corresponds to the term $p(s'_j \mid Pa(j; \mathcal{E}_z), z)$ in Eq. (3). For the implementation, we mask out the features of unused variables according to $A$. We found that this is more stable than input masking (Brouillard et al., 2020).

Training objective. We employ a regularization loss $\lambda \cdot \|A\|_1$ to induce a sparse LCG, where $\lambda$ is a hyperparameter. To update the codebook, we use a quantization loss (Van Den Oord et al., 2017). The training objective is as follows:

$$\mathcal{L}_{\text{total}} = \underbrace{-\log \hat{p}(s' \mid s, a; A) + \lambda \cdot \|A\|_1}_{\mathcal{L}_{\text{pred}}} + \underbrace{\|\text{sg}[h] - e\|_2^2 + \beta \cdot \|h - \text{sg}[e]\|_2^2}_{\mathcal{L}_{\text{quant}}}. \tag{6}$$

Here, $\mathcal{L}_{\text{pred}}$ is the masked prediction loss with regularization. $\mathcal{L}_{\text{quant}}$ is the quantization loss, where $\text{sg}[\cdot]$ is a stop-gradient operator and $\beta$ is a hyperparameter. Specifically, $\|\text{sg}[h] - e\|_2^2$ moves each code toward the center of the embeddings assigned to it, and $\beta \cdot \|h - \text{sg}[e]\|_2^2$ encourages the encoder to output embeddings close to the codes. This allows us to jointly train the dynamics model and the codebook in an end-to-end manner. Intuitively, vector quantization clusters the samples under a similar context and reconstructs the LCG for each cluster. The rationale is that any error in the graph $\mathcal{G}_z$ or the clustering $\mathcal{E}_z$ would lead to prediction error of the dynamics model. We provide the details of our model in Appendix C.4.

We note that prior works on learning a discrete latent codebook have mostly focused on the reconstruction of the observation (Van Den Oord et al., 2017; Ozair et al., 2021). To the best of our knowledge, our work is the first to utilize vector quantization for discovering diverse causal structures.

Discussion on codebook collapse. It is well known that training a discrete latent codebook with vector quantization often suffers from codebook collapse, where many codes learn the same output and converge to a trivial solution. To mitigate this, we employ exponential moving averages (EMA) to update the codebook, following Van Den Oord et al. (2017). In practice, we found that the training was relatively stable for any choice of the codebook size $K > 2$. In our experiments, we simply fixed it to 16 across all environments, since all sizes performed comparably well, as we will demonstrate in Sec. 4.2.

3.4. Theoretical Analysis and Discussions

So far, we have described how our method learns the decomposition and LCGs through the discrete latent codebook $C$ as a proxy. Our method can be viewed as a practical approach towards the maximization of $\mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$, since $\mathcal{L}_{\text{pred}}$ corresponds to Eq. (4) and $\mathcal{L}_{\text{quant}}$ is a mean squared error in the latent space, which can be minimized to 0. In this section, we provide its implications and discussions.

Proposition 1. Let $\{\mathcal{G}^*_z, \mathcal{E}^*_z\} \in \arg\max \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ for $\lambda > 0$ small enough, with Assumptions 1 to 5. Then, (i) each $\mathcal{G}^*_z$ is the true LCG on $\mathcal{E}^*_z$, and (ii) $\mathbb{E}[|\mathcal{G}^*_z|] \le \mathbb{E}[|\mathcal{G}_z|]$, where $\{\mathcal{G}_z\}$ are LCGs on an arbitrary decomposition $\{\mathcal{E}_z\}_{z=1}^K$.

In other words, the decomposition that maximizes the score is optimal in terms of $\mathbb{E}[|\mathcal{G}_z|] = \sum_z p(\mathcal{E}_z) |\mathcal{G}_z|$. This is an important property involving the contexts that are more likely (i.e., large $p(\mathcal{E})$) and more meaningful (i.e., sparse $\mathcal{G}_{\mathcal{E}}$). Therefore, Prop. 1 implies that score maximization leads to the most fine-grained understanding of the dynamics achievable with the quantization degree $K$.

We now illustrate how the optimal decomposition $\{\mathcal{E}^*_z\}_{z=1}^K$ in Prop. 1, with a sufficient quantization degree, identifies important contexts $\mathcal{D}$ (e.g., door closed) that display fine-grained causal relationships. We say the context $\mathcal{D}$ is canonical if $\mathcal{G}_{\mathcal{F}} = \mathcal{G}_{\mathcal{D}}$ for any $\mathcal{F} \subset \mathcal{D}$.

Theorem 2 (Identifiability of contexts). Let $\{\mathcal{G}^*_z, \mathcal{E}^*_z\} \in \arg\max \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ for $\lambda > 0$ small enough, with Assumptions 1 to 5. Suppose $\mathcal{X} = \bigcup_{m \in [H]} \mathcal{D}_m$, where $\mathcal{G}_{\mathcal{D}_m}$ is distinct for all $m \in [H]$, and $\mathcal{D}_1, \dots, \mathcal{D}_H$ are disjoint and canonical. Suppose $K \ge H$. Then, for all $m \in [H]$, there exists $I_m \subset [K]$ such that $\mathcal{D}_m = \bigcup_{z \in I_m} \mathcal{E}^*_z$ almost surely.

In other words, the joint optimization of Eq. (4) over the quantization and the dynamics model, with a sufficient quantization degree, perfectly captures the meaningful contexts that exist in the system (Thm. 2) and recovers the corresponding LCGs (Prop. 1-(i)), thereby leading to a fine-grained understanding of the dynamics. Our method described in the previous section serves as a practical approach toward this goal.

Discussion on the codebook size. Thm. 2 also implies that the identification of the meaningful contexts is agnostic to the quantization degree $K$, as long as $K \ge H$. In Sec. 4.2, we demonstrate that our method works reasonably well for various quantization degrees in practice. We note that determining the minimal sufficient quantization degree $H$ is not our primary focus, because over-parametrizing the quantization incurs only a small memory cost for additional codebook vectors in practice. Note that even if $K < H$, Prop. 1-(ii) guarantees that our method would still discover meaningful fine-grained causal relationships, optimal in terms of $\mathbb{E}[|\mathcal{G}_z|]$.

Relationship to past approaches. To better understand our approach, we draw connections to (i) prior causal dynamics models and (ii) sample-specific approaches to discovering fine-grained dependencies. First, our method with the quantization degree $K = 1$ degenerates to prior causal dynamics models (Wang et al., 2022; Ding et al., 2022): it would
discover (global) causal dependencies (i.e., a special case of Thm. 1) but cannot harness fine-grained relationships. Second, our method without quantization ($K \to \infty$) reverts to sample-specific approaches, e.g., the auxiliary network that infers local independence directly from each sample (Hwang et al., 2023). As described earlier, it is unclear under which context the inferred dependencies hold. In Sec. 4.2, we demonstrate that this makes their inferences often inconsistent within the same context and prone to overfitting, while our approach with quantization infers fine-grained causal relationships in a more effective and robust manner.
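The quantity $\mathbb{E}[|\mathcal{G}_z|] = \sum_z p(\mathcal{E}_z)|\mathcal{G}_z|$ behind Prop. 1-(ii), and the $K = 1$ special case, can be made concrete with a few lines of arithmetic. The sketch below uses hypothetical edge counts for a Magnetic-like system (10 edges in the global CG, 8 under the non-magnetic context); these are illustrative assumptions, not numbers from the paper. Because the penalty in Eq. (4) is $-\lambda|\mathcal{G}_z|$ inside the expectation, it contributes $-\lambda\,\mathbb{E}[|\mathcal{G}_z|]$ to the score, so among decompositions whose true LCGs fit the data equally well, the one with the smaller expected graph size scores higher.

```python
def expected_graph_size(decomposition):
    """E[|G_z|] = sum_z p(E_z) * |G_z| for a list of (p(E_z), |G_z|) pairs."""
    total_mass = sum(p for p, _ in decomposition)
    if abs(total_mass - 1.0) > 1e-9:
        raise ValueError("subgroup probabilities must sum to 1")
    return sum(p * num_edges for p, num_edges in decomposition)


# Hypothetical edge counts (illustrative only): 10 edges in the global CG,
# 8 under the non-magnetic context, where the box no longer affects the ball.
GLOBAL_EDGES, NON_MAGNETIC_EDGES = 10, 8

decompositions = {
    # K = 1: a single subgroup, i.e., a global causal dynamics model.
    "K=1": [(1.0, GLOBAL_EDGES)],
    # A split that ignores the relevant context (e.g., lights on / off).
    "uninformative": [(0.5, GLOBAL_EDGES), (0.5, GLOBAL_EDGES)],
    # A split aligned with the non-magnetic / magnetic contexts.
    "aligned": [(0.5, NON_MAGNETIC_EDGES), (0.5, GLOBAL_EDGES)],
    # K > H: the non-magnetic context spread over two codes (a union of E_z).
    "over-split": [(0.25, NON_MAGNETIC_EDGES), (0.25, NON_MAGNETIC_EDGES),
                   (0.5, GLOBAL_EDGES)],
}

for name, decomposition in decompositions.items():
    print(name, expected_graph_size(decomposition))
# K=1 10.0
# uninformative 10.0
# aligned 9.0
# over-split 9.0
```

The uninformative split gains nothing over $K = 1$, while the aligned split and its over-split refinement reach the same smaller value, which mirrors the statement that identifying the contexts does not depend on $K$ as long as $K \ge H$.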

4. Experiments
In this section, we evaluate our method, coined Fine-Grained
Causal Dynamics Learning (FCDL), to investigate the fol-
lowing questions: (1) Does our method improve robustness
in MBRL (Tables 1 and 2)? (2) Does our method discover
(b) Magnetic
fine-grained causal relationships and capture meaningful
contexts (Figs. 5 to 7)? (3) Is our method more effective Figure 3. Illustrations for each environment. (a) In Chemical, col-
and robust compared to sample-specific approaches (Figs. 6 ors change by the action according to the underlying causal graph.
and 7)? (4) How does the degree of quantization affect (b) In Magnetic, the red object exhibits magnetism.
performance (Fig. 7 and Table 3)?
fine-grained causal reasoning would generalize well since
4.1. Experimental Setup corrupted nodes are locally spurious to predict other nodes.
The environments are designed to exhibit fine-grained causal Magnetic. We designed a robot arm manipulation environ-
relationships under a particular context D. The state vari- ment based on the Robosuite framework (Zhu et al., 2020).
ables (e.g., position, velocity) are fully observable, follow- There is a moving ball and a box on the table, colored red
ing prior works (Ding et al., 2022; Wang et al., 2022; Seitzer or black (Fig. 3(b)). Red color indicates that the object is
et al., 2021; Pitis et al., 2020; 2022). Experimental details magnetic, and attracts the other magnetic object. For exam-
are provided in Appendix C.2 ple, when both are red, magnetic force will be applied, and
the ball will move toward the box. Otherwise, i.e., under the
4.1.1. E NVIRONMENTS non-magnetic context, the box would have no influence on
Chemical (Ke et al., 2021). It is a widely used benchmark the ball. The color and position of the objects are randomly
for systematic evaluation of causal reasoning in RL. There are 10 nodes, each colored with one of 5 colors. According to the underlying causal graph, an action changes the colors of the intervened node's descendants, as depicted in Fig. 3(a). The task is to match the color of each node to the given target. We designed two settings, named full-fork and full-chain. In both settings, the underlying CG is full. When the color of the root node is red ($\mathcal{D}$), the colors change according to fork or chain, respectively ($\mathcal{G}_{\mathcal{D}}$). For example, in full-fork, all parent nodes other than the root become irrelevant under this context. Otherwise ($\mathcal{D}^c$), the transition respects the graph full (i.e., $\mathcal{G}_{\mathcal{D}^c} = \mathcal{G}$). During the test, the root color is set to red, and the LCG (fork or chain) is activated. Here, the agent receives a noisy observation for some nodes, and the task is to match the colors of the remaining clean nodes, as depicted in Appendix C.1 (Fig. 8). The agent capable of

initialized for each episode, i.e., each episode is under the magnetic or non-magnetic context during training. The task is to reach the ball by predicting its trajectory. In this environment, the non-magnetic context $\mathcal{D}$ displays sparse dependencies ($\mathcal{G}_{\mathcal{D}} \subsetneq \mathcal{G}$) because the box no longer influences the ball under this context. In contrast, all causal dependencies remain the same under the magnetic context $\mathcal{D}^c$, i.e., $\mathcal{G}_{\mathcal{D}^c} = \mathcal{G}$. The CG and LCGs are shown in Appendix C.1 (Fig. 9). During the test, one of the objects is black, and the box is located at an unseen position. Under this non-magnetic context, the box becomes locally spurious; thus, an agent aware of fine-grained causal relationships would generalize well to unseen out-of-distribution (OOD) states.

² Our code is publicly available at https://github.com/iwhwang/Fine-Grained-Causal-RL.

4.1.2. EXPERIMENTAL DETAILS

Baselines. We first consider dense models, i.e., a monolithic network implemented as an MLP, which learns $p(s' \mid s, a)$, and a modular network with a separate network for each variable, $\prod_j p(s'_j \mid s, a)$. We also include a graph neural
network (GNN) (Kipf et al., 2020), which learns relational information, and NPS (Goyal et al., 2021a), which learns sparse and modular dynamics. Causal models, including CDL (Wang et al., 2022) and GRADER (Ding et al., 2022), infer the causal structure for dynamics learning: $\prod_j p(s_j \mid Pa(j))$. We also consider an oracle model, which leverages the ground-truth (global) causal graph. Finally, we compare to NCD (Hwang et al., 2023), a sample-specific approach that examines local independence for each sample.

Planning algorithm. For all baselines and our method, we use model predictive control (MPC) (Camacho & Alba, 2013), which selects actions based on the predictions of the learned dynamics model. Specifically, we use the cross-entropy method (CEM) (Rubinstein & Kroese, 2004), which iteratively generates and optimizes action sequences.

Implementation. For our method, we set the hyperparameters K = 16, λ = 0.001, and β = 0.25 in all experiments. All methods have a similar model capacity for a fair comparison. For evaluation, we ran 10 test episodes every 40 training episodes. The results are averaged over eight different runs. All learning curves are shown in Appendix C.5.

4.2. Results

Downstream task performance (Table 1, Fig. 4). All methods show similar performance on in-distribution (ID) states in training. However, dense models suffer from OOD states in the downstream tasks. Causal models are generally more robust than dense models, as they infer the causal graph and discard spurious dependencies. NCD, a sample-specific approach to inferring fine-grained dependencies, performs better than causal models on a few downstream tasks, but not always. In contrast, our method consistently outperforms the baselines across all downstream tasks. This empirically validates our hypothesis that fine-grained causal reasoning leads to improved robustness in MBRL.

Table 1. Average episode reward on training and downstream tasks in each environment. In Chemical, n denotes the number of noisy nodes in downstream tasks.
Methods | Chemical (full-fork) Train (n = 0) | Chemical (full-fork) Test (n = 2) | Chemical (full-fork) Test (n = 4) | Chemical (full-fork) Test (n = 6) | Chemical (full-chain) Train (n = 0) | Chemical (full-chain) Test (n = 2) | Chemical (full-chain) Test (n = 4) | Chemical (full-chain) Test (n = 6) | Magnetic Train | Magnetic Test
MLP | 19.00±0.83 | 6.49±0.48 | 5.93±0.71 | 6.84±1.17 | 17.91±0.87 | 7.39±0.65 | 6.63±0.58 | 6.78±0.93 | 8.37±0.74 | 0.86±0.45
Modular | 18.55±1.00 | 6.05±0.70 | 5.65±0.50 | 6.43±1.00 | 17.37±1.63 | 6.61±0.63 | 7.01±0.55 | 7.04±1.07 | 8.45±0.80 | 0.88±0.52
GNN (Kipf et al., 2020) | 18.60±1.19 | 6.61±0.92 | 6.15±0.74 | 6.95±0.78 | 16.97±1.85 | 6.89±0.28 | 6.38±0.28 | 6.56±0.53 | 8.53±0.83 | 0.92±0.51
NPS (Goyal et al., 2021a) | 7.71±1.22 | 5.82±0.83 | 5.75±0.57 | 5.54±0.80 | 8.20±0.54 | 6.92±1.03 | 6.88±0.79 | 6.80±0.39 | 3.13±1.00 | 0.91±0.69
CDL (Wang et al., 2022) | 18.95±1.40 | 9.37±1.33 | 8.23±0.40 | 9.50±1.18 | 17.95±0.83 | 8.71±0.55 | 8.65±0.38 | 10.23±0.50 | 8.75±0.69 | 1.10±0.67
GRADER (Ding et al., 2022) | 18.65±0.98 | 9.27±1.31 | 8.79±0.65 | 10.61±1.31 | 17.71±0.54 | 8.69±0.56 | 8.75±0.80 | 10.14±0.33 | - | -
Oracle | 19.64±1.18 | 7.83±0.87 | 8.04±0.62 | 9.66±0.21 | 17.79±0.76 | 8.47±0.69 | 8.85±0.78 | 10.29±0.37 | 8.42±0.86 | 0.95±0.55
NCD (Hwang et al., 2023) | 19.30±0.95 | 10.95±1.63 | 9.11±0.63 | 10.32±0.93 | 18.27±0.27 | 9.60±1.52 | 8.86±0.23 | 10.32±0.37 | 8.48±0.70 | 1.31±0.77
FCDL (Ours) | 19.28±0.87 | 15.27±2.53 | 14.73±1.68 | 13.62±2.56 | 17.22±0.61 | 13.36±3.60 | 12.35±3.23 | 12.00±1.21 | 8.52±0.74 | 4.81±3.01

[Figure 4 plots. Panels: Chemical (full-fork) (n = 2); Magnetic. Axes: episode reward vs. environment step (×10^5). Legend: MLP, Modular, GNN, NPS, GRADER, Oracle, NCD, Ours.]
Figure 4. Learning curves on downstream tasks as measured by the average episode reward. Lines and shaded areas represent the mean and standard deviation, respectively.

[Figure 5 panels: (a) ID (all), (b) ID (fork), (c) OOD (fork), (d) True LCG (fork), (e) Learned LCG (fork).]
Figure 5. (Top) Codebook histogram of the sample allocations to each of the codes on (a) all ID states, (b) ID states under fork, and (c) OOD states under fork. (Bottom) (d) True LCG (fork). (e) Learned LCG corresponding to the most frequently allocated code in (b) and (c).

Prediction accuracy (Table 2). To better understand the robustness of our method in downstream tasks, we investigate the prediction accuracy on ID and OOD states over the clean nodes in Chemical. As described earlier, noisy nodes are irrelevant for predicting the clean nodes under the LCG (i.e., fork or chain); thus, they are locally spurious on OOD states in downstream tasks. While all methods perform reasonably well on ID states, dense models show a significant performance drop in the presence of noisy
variables, merely above 20%, which is the expected accuracy of random guessing among 5 colors. As expected, causal dynamics models tend to be more robust than dense models, but they still suffer from OOD states. NCD is more robust than causal models when n = 2, but eventually becomes similar to them as the number of noisy nodes increases. In contrast, our method outperforms the baselines by a large margin across all downstream tasks, which demonstrates its effectiveness and robustness in fine-grained causal reasoning.

Table 2. Prediction accuracy on ID (n = 0) and OOD (n = 2, 4, 6) states in Chemical environment.
Setting | n | MLP | Modular | GNN | NPS | CDL | GRADER | Oracle | NCD | FCDL (Ours)
full-fork | n = 0 | 88.31±1.58 | 89.24±1.52 | 88.81±1.44 | 58.34±2.08 | 89.22±1.67 | 87.75±1.64 | 89.63±1.62 | 90.07±1.22 | 89.46±1.40
full-fork | n = 2 | 31.11±1.69 | 26.53±3.45 | 36.29±3.45 | 40.56±4.61 | 35.59±1.85 | 37.93±1.06 | 33.87±1.34 | 41.60±5.08 | 66.44±12.22
full-fork | n = 4 | 30.44±2.28 | 24.73±5.61 | 25.80±3.48 | 26.81±4.37 | 35.82±1.40 | 38.94±1.63 | 36.48±1.80 | 37.47±2.13 | 58.49±10.20
full-fork | n = 6 | 32.39±1.76 | 26.73±8.31 | 21.58±3.44 | 23.02±4.27 | 42.22±1.39 | 45.74±2.25 | 42.47±0.75 | 42.27±1.82 | 49.09±4.77
full-chain | n = 0 | 84.38±1.31 | 85.92±1.15 | 85.41±1.84 | 58.48±2.81 | 86.85±1.47 | 84.24±1.22 | 85.76±1.56 | 85.63±1.01 | 86.07±1.62
full-chain | n = 2 | 28.66±3.65 | 25.24±4.68 | 29.22±3.39 | 38.73±2.63 | 34.90±1.59 | 36.82±3.12 | 34.63±1.78 | 40.04±6.21 | 60.34±12.10
full-chain | n = 4 | 26.52±4.26 | 24.94±4.81 | 23.28±4.98 | 27.69±4.28 | 36.52±1.72 | 37.41±2.84 | 38.31±2.48 | 37.47±2.98 | 56.64±9.40
full-chain | n = 6 | 24.15±4.17 | 25.09±5.91 | 20.53±6.96 | 24.45±3.84 | 42.06±1.29 | 43.48±4.14 | 42.87±2.08 | 41.19±1.66 | 53.29±6.63

Recognizing important contexts and fine-grained causal relationships (Fig. 5). To illustrate the fine-grained causal reasoning of our method, we closely examine the behavior of our model with the quantization degree K = 4 in Chemical (full-fork, n = 2). Recall that each code corresponds to a pair of a subgroup and an LCG. Fig. 5(a) shows how ID samples in the batch are allocated to the four codes. Interestingly, ID samples corresponding to the LCG fork are all allocated to the last code (Fig. 5(b)), i.e., the subgroup corresponding to the last code identifies this context. Furthermore, the LCG decoded from this code (Fig. 5(e)) accurately captures the true fork structure (Fig. 5(d)). This demonstrates that our method successfully recognizes meaningful contexts and fine-grained causal relationships. Notably, Fig. 5(c) shows that most of the OOD samples under fork are correctly allocated to the last code. This illustrates the robustness of our method, i.e., its inference is consistent between ID and OOD states. Additional examples, including the visualization of all the learned LCGs from all codes, are provided in Appendix C.5.

Inferred LCGs compared to sample-specific approach (Fig. 6). We investigated the effectiveness and robustness of our method in fine-grained causal reasoning compared to the sample-specific approach. For this, we examine the inferred LCGs in Magnetic, where the true LCGs and CG are shown in Appendix C.1 (Fig. 9). First, our method accurately learns the LCG under the non-magnetic context (Fig. 6(b)). On the other hand, the LCG inferred by NCD is rather inaccurate (Fig. 6(c)), including some locally spurious dependencies (3 of the 6 red boxes). Furthermore, its inference is inconsistent between ID and OOD states in the same non-magnetic context and completely fails on OOD states (Fig. 6(d)). This demonstrates that our approach is more effective and robust in discovering fine-grained causal relationships.

[Figure 6 panels: (a) FCDL (CG), (b) FCDL (LCG), (c) NCD (LCG, ID), (d) NCD (LCG, OOD).]
Figure 6. Red boxes indicate edges included in the global CG but not in the LCG under the non-magnetic context. (a) CG inferred by our method. (b-d) LCG under the non-magnetic context inferred by (b) our method, and by NCD on (c) ID and (d) OOD states.

Evaluation of local causal discovery (Fig. 7). We evaluate our method and NCD using structural Hamming distance (SHD) in Magnetic. For each sample, we compare the inferred LCG with the true LCG based on the magnetic/non-magnetic context, and the SHD scores are averaged over the data samples in the evaluation batch. As expected, our method infers fine-grained relationships more accurately and maintains better performance on OOD states across various quantization degrees, which validates its effectiveness and robustness compared to NCD. Lastly, we note that our method with the quantization degree K = 1 would learn only a single CG over the entire data domain, as shown in Fig. 6(a). This explains its mean SHD score of 6 on non-magnetic samples in Fig. 7, since the CG includes six redundant edges in the non-magnetic context (i.e., red boxes in Fig. 9).

Ablation on the quantization degree (Table 3). Finally, we observe that our method works reasonably well across various quantization degrees on all downstream tasks in
Chemical (full-fork). Our method consistently outperforms the prior causal dynamics model (CDL) and the sample-specific approach (NCD), which corroborates the results in Fig. 7. During our experiments, we found that training was relatively stable for any quantization degree K > 2. We also found that instability often occurs under K = 2, where the samples frequently fluctuate between two prototype vectors, resulting in codebook collapse. This is also reflected in Table 3, where the performance of K = 2 is generally worse than that of other choices of K. We speculate that over-parametrization of quantization could alleviate such fluctuation in general.

Table 3. Ablation on the quantization degree.
Methods | Chemical (full-fork) (n = 2) | Chemical (full-fork) (n = 4) | Chemical (full-fork) (n = 6)
CDL | 9.37±1.33 | 8.23±0.40 | 9.50±1.18
NCD | 10.95±1.63 | 9.11±0.63 | 10.32±0.93
FCDL (K = 2) | 13.44±5.41 | 12.86±5.58 | 12.99±5.27
FCDL (K = 4) | 15.73±4.13 | 16.50±3.40 | 12.40±2.81
FCDL (K = 8) | 14.95±1.16 | 15.03±2.61 | 13.42±2.67
FCDL (K = 16) | 15.27±2.53 | 14.73±1.68 | 13.62±2.56
FCDL (K = 32) | 16.12±1.43 | 14.35±1.37 | 14.79±2.13

[Figure 7 plots. Panels: ID (all); ID (non-magnetic); OOD (non-magnetic). Axes: SHD vs. quantization degree (1, 2, 4, 8, 16, 32). Legend: FCDL, NCD.]
Figure 7. Evaluation of local causal discovery in Magnetic environment.

5. Discussions and Future Works

High-dimensional observation. The factorization of the state space is natural in many real-world domains (e.g., healthcare, recommender systems, social science, economics) where discovering causal relationships is an important problem. Extending our framework to images would require extracting causal factors from pixels (Schölkopf et al., 2021), which is orthogonal to our work and could be combined with it.

Scalability and stability in training. Vector quantization (VQ) is a well-established component in generative models, where the quantization degree is usually very high (e.g., K = 512, 1024), yet it effectively captures diverse visual features. Its scalability is further showcased on complex large-scale datasets (Razavi et al., 2019). In this sense, we believe our framework could extend to complex real-world environments. For training stability, techniques have recently been proposed to prevent codebook collapse, such as codebook reset (Williams et al., 2020) and stochastic quantization (Takida et al., 2022). We expect that such techniques and tricks could be incorporated into our framework.

Conditional independence test (CIT). A CIT is an effective tool for understanding causal relationships, although it is often computationally costly. Our method may use it to further calibrate the learned LCGs, e.g., by applying a CIT to each subgroup after training, which we defer to future work.

Domain knowledge. Our method could leverage prior information on important contexts displaying sparse dependencies, if available. While our method does not rely on such domain knowledge, it would still be useful for discovering fine-grained relationships more efficiently (e.g., Thm. 1).

Implications for real-world scenarios. We believe our work has potential implications in many practical applications, since context-dependent causal relationships are prevalent in real-world scenarios. For example, in healthcare, a dynamic treatment regime is the task of determining a sequence of decision rules (e.g., treatment type, drug dosage) based on the patient's health status, where many pathological factors are known to involve fine-grained causal relationships (Barash & Friedman, 2001; Edwards & Toma, 1985). Our experiments illustrate that existing causal/non-causal RL approaches could suffer from locally spurious correlations and fail to generalize in downstream tasks. We believe our work serves as a stepping stone for further investigation into fine-grained causal reasoning of RL systems and their robustness in real-world deployment.

6. Conclusion

We present a novel approach to dynamics learning that infers fine-grained causal relationships, leading to improved robustness of MBRL. We provide a principled way to examine fine-grained dependencies under certain contexts. As a practical approach, our method learns a discrete latent variable that represents pairs of a subgroup and a local causal graph (LCG), allowing joint optimization with the dynamics model. Consequently, our method infers fine-grained causal structures more effectively and robustly than prior approaches. As one of the first steps towards fine-grained causal reasoning in sequential decision-making systems, we hope our work stimulates future research toward this goal.

Impact Statement

In real-world applications, model-based RL requires a large amount of data. As a large-scale dataset may contain sensitive information, it would be advisable to carefully evaluate the models within simulated environments before their real-world deployment.
Acknowledgements

We would like to thank Sujin Jeon and Hyundo Lee for the useful discussions. We also thank anonymous reviewers for their constructive comments. This work was partly supported by the IITP (RS-2021-II212068-AIHub/10%, RS-2021-II211343-GSAI/10%, 2022-0-00951-LBA/10%, 2022-0-00953-PICA/20%), NRF (RS-2023-00211904/20%, RS-2023-00274280/10%, RS-2024-00353991/10%), and KEIT (RS-2024-00423940/10%) grant funded by the Korean government.

References

Acid, S. and de Campos, L. M. Searching for bayesian network structures in the space of restricted acyclic partially directed graphs. Journal of Artificial Intelligence Research, 18:445–490, 2003.
Barash, Y. and Friedman, N. Context-specific bayesian clustering for gene expression data. In Proceedings of the fifth annual international conference on Computational biology, pp. 12–21, 2001.
Bareinboim, E., Forney, A., and Pearl, J. Bandits with unobserved confounders: A causal approach. Advances in Neural Information Processing Systems, 28, 2015.
Bica, I., Jarrett, D., and van der Schaar, M. Invariant causal imitation learning for generalizable policies. Advances in Neural Information Processing Systems, 34:3952–3964, 2021.
Bongers, S., Blom, T., and Mooij, J. M. Causal modeling of dynamical systems. arXiv preprint arXiv:1803.08784, 2018.
Boutilier, C., Friedman, N., Goldszmidt, M., and Koller, D. Context-specific independence in bayesian networks. CoRR, abs/1302.3562, 2013.
Brouillard, P., Lachapelle, S., Lacoste, A., Lacoste-Julien, S., and Drouin, A. Differentiable causal discovery from interventional data. Advances in Neural Information Processing Systems, 33:21865–21877, 2020.
Buesing, L., Weber, T., Zwols, Y., Heess, N., Racaniere, S., Guez, A., and Lespiau, J.-B. Woulda, coulda, shoulda: Counterfactually-guided policy search. In International Conference on Learning Representations, 2019.
Camacho, E. F. and Alba, C. B. Model predictive control. Springer science & business media, 2013.
Chitnis, R., Silver, T., Kim, B., Kaelbling, L., and Lozano-Perez, T. Camps: Learning context-specific abstractions for efficient planning in factored mdps. In Conference on Robot Learning, pp. 64–79. PMLR, 2021.
Chung, J., Gulcehre, C., Cho, K., and Bengio, Y. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555, 2014.
Dal, G. H., Laarman, A. W., and Lucas, P. J. Parallel probabilistic inference by weighted model counting. In International Conference on Probabilistic Graphical Models, pp. 97–108. PMLR, 2018.
De Haan, P., Jayaraman, D., and Levine, S. Causal confusion in imitation learning. Advances in Neural Information Processing Systems, 32, 2019.
Ding, W., Lin, H., Li, B., and Zhao, D. Generalizing goal-conditioned reinforcement learning with variational causal reasoning. In Advances in Neural Information Processing Systems, 2022.
Edwards, D. and Toma, H. A fast procedure for model search in multidimensional contingency tables. Biometrika, 72:339–351, 1985.
Feng, F. and Magliacane, S. Learning dynamic attribute-factored world models for efficient multi-object reinforcement learning. In Advances in Neural Information Processing Systems, volume 36, pp. 19117–19144, 2023.
Feng, F., Huang, B., Zhang, K., and Magliacane, S. Factored adaptation for non-stationary reinforcement learning. In Advances in Neural Information Processing Systems, 2022.
Goyal, A., Didolkar, A. R., Ke, N. R., Blundell, C., Beaudoin, P., Heess, N., Mozer, M. C., and Bengio, Y. Neural production systems. In Advances in Neural Information Processing Systems, 2021a.
Goyal, A., Lamb, A., Gampa, P., Beaudoin, P., Blundell, C., Levine, S., Bengio, Y., and Mozer, M. C. Factorizing declarative and procedural knowledge in structured, dynamical environments. In International Conference on Learning Representations, 2021b.
Goyal, A., Lamb, A., Hoffmann, J., Sodhani, S., Levine, S., Bengio, Y., and Schölkopf, B. Recurrent independent mechanisms. In International Conference on Learning Representations, 2021c.
Hoey, J., St-Aubin, R., Hu, A., and Boutilier, C. Spudd: stochastic planning using decision diagrams. In Proceedings of the Fifteenth conference on Uncertainty in artificial intelligence, pp. 279–288, 1999.
Huang, B., Lu, C., Leqi, L., Hernández-Lobato, J. M., Glymour, C., Schölkopf, B., and Zhang, K. Action-sufficient state representation learning for control with structural constraints. In International Conference on Machine Learning, pp. 9260–9279. PMLR, 2022.
Hwang, I., Kwak, Y., Song, Y.-J., Zhang, B.-T., and Lee, S. On discovery of local independence over continuous variables via neural contextual decomposition. In Conference on Causal Learning and Reasoning, pp. 448–472. PMLR, 2023.
Jamshidi, F., Akbari, S., and Kiyavash, N. Causal imitability under context-specific independence relations. In Advances in Neural Information Processing Systems, volume 36, pp. 26810–26830, 2023.
Jang, E., Gu, S., and Poole, B. Categorical reparameterization with gumbel-softmax. In International Conference on Learning Representations, 2017.
Kaiser, Ł., Babaeizadeh, M., Miłos, P., Osiński, B., Campbell, R. H., Czechowski, K., Erhan, D., Finn, C., Kozakowski, P., Levine, S., et al. Model based reinforcement learning for atari. In International Conference on Learning Representations, 2020.
Ke, N. R., Didolkar, A. R., Mittal, S., Goyal, A., Lajoie, G., Bauer, S., Rezende, D. J., Mozer, M. C., Bengio, Y., and Pal, C. Systematic evaluation of causal discovery in visual model based reinforcement learning. In Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 2), 2021.
Kearns, M. and Koller, D. Efficient reinforcement learning in factored mdps. In IJCAI, volume 16, pp. 740–747, 1999.
Killian, T. W., Ghassemi, M., and Joshi, S. Counterfactually guided policy transfer in clinical settings. In Conference on Health, Inference, and Learning, pp. 5–31. PMLR, 2022.
Kipf, T., van der Pol, E., and Welling, M. Contrastive learning of structured world models. In International Conference on Learning Representations, 2020.
Kumor, D., Zhang, J., and Bareinboim, E. Sequential causal imitation learning with unobserved confounders. In Advances in Neural Information Processing Systems, volume 34, pp. 14669–14680, 2021.
Lee, S. and Bareinboim, E. Structural causal bandits: Where to intervene? In Advances in Neural Information Processing Systems, volume 31, 2018.
Lee, S. and Bareinboim, E. Characterizing optimal mixed policies: Where to intervene and what to observe. In Advances in Neural Information Processing Systems, volume 33, pp. 8565–8576, 2020.
Li, M., Zhang, J., and Bareinboim, E. Causally aligned curriculum learning. In The Twelfth International Conference on Learning Representations, 2024.
Li, Y., Torralba, A., Anandkumar, A., Fox, D., and Garg, A. Causal discovery in physical systems from videos. Advances in Neural Information Processing Systems, 33:9180–9192, 2020.
Löwe, S., Madras, D., Zemel, R., and Welling, M. Amortized causal discovery: Learning to infer causal graphs from time-series data. In Conference on Causal Learning and Reasoning, pp. 509–525. PMLR, 2022.
Lu, C., Schölkopf, B., and Hernández-Lobato, J. M. Deconfounding reinforcement learning in observational settings. arXiv preprint arXiv:1812.10576, 2018.
Lu, C., Huang, B., Wang, K., Hernández-Lobato, J. M., Zhang, K., and Schölkopf, B. Sample-efficient reinforcement learning via counterfactual-based data augmentation. arXiv preprint arXiv:2012.09092, 2020.
Lyle, C., Zhang, A., Jiang, M., Pineau, J., and Gal, Y. Resolving causal confusion in reinforcement learning via robust exploration. In Self-Supervision for Reinforcement Learning Workshop-ICLR, volume 2021, 2021.
Maddison, C. J., Mnih, A., and Teh, Y. W. The concrete distribution: A continuous relaxation of discrete random variables. In International Conference on Learning Representations, 2017.
Madumal, P., Miller, T., Sonenberg, L., and Vetere, F. Explainable reinforcement learning through a causal lens. In Proceedings of the AAAI conference on artificial intelligence, pp. 2493–2500, 2020.
Mesnard, T., Weber, T., Viola, F., Thakoor, S., Saade, A., Harutyunyan, A., Dabney, W., Stepleton, T. S., Heess, N., Guez, A., et al. Counterfactual credit assignment in model-free reinforcement learning. In International Conference on Machine Learning, pp. 7654–7664. PMLR, 2021.
Mutti, M., De Santi, R., Rossi, E., Calderon, J. F., Bronstein, M., and Restelli, M. Provably efficient causal model-based reinforcement learning for systematic generalization. In Proceedings of the AAAI Conference on Artificial Intelligence, pp. 9251–9259, 2023.
Nair, S., Zhu, Y., Savarese, S., and Fei-Fei, L. Causal induction from visual observations for goal directed tasks. arXiv preprint arXiv:1910.01751, 2019.
Oberst, M. and Sontag, D. Counterfactual off-policy evaluation with gumbel-max structural causal models. In International Conference on Machine Learning, pp. 4881–4890. PMLR, 2019.
Ozair, S., Li, Y., Razavi, A., Antonoglou, I., Van Den Oord, A., and Vinyals, O. Vector quantized models for planning. In International Conference on Machine Learning, pp. 8302–8313. PMLR, 2021.
Pearl, J. Causality. Cambridge university press, 2009.
Peters, J., Janzing, D., and Schölkopf, B. Elements of causal inference: foundations and learning algorithms. The MIT Press, 2017.
Pitis, S., Creager, E., and Garg, A. Counterfactual data augmentation using locally factored dynamics. Advances in Neural Information Processing Systems, 33, 2020.
Pitis, S., Creager, E., Mandlekar, A., and Garg, A. MocoDA: Model-based counterfactual data augmentation. In Advances in Neural Information Processing Systems, 2022.
Poole, D. Context-specific approximation in probabilistic inference. In Proceedings of the Fourteenth conference on Uncertainty in artificial intelligence, pp. 447–454, 1998.
Ramsey, J., Spirtes, P., and Zhang, J. Adjacency-faithfulness and conservative causal inference. In Proceedings of the Twenty-Second Conference on Uncertainty in Artificial Intelligence, pp. 401–408, 2006.
Razavi, A., van den Oord, A., and Vinyals, O. Generating diverse high-fidelity images with vq-vae-2. In Advances in Neural Information Processing Systems, volume 32, 2019.
Rezende, D. J., Danihelka, I., Papamakarios, G., Ke, N. R., Jiang, R., Weber, T., Gregor, K., Merzic, H., Viola, F., Wang, J., et al. Causally correct partial models for reinforcement learning. arXiv preprint arXiv:2002.02836, 2020.
Rubinstein, R. Y. and Kroese, D. P. The cross-entropy method: a unified approach to combinatorial optimization, Monte-Carlo simulation, and machine learning, volume 133. Springer, 2004.
Schölkopf, B., Locatello, F., Bauer, S., Ke, N. R., Kalchbrenner, N., Goyal, A., and Bengio, Y. Toward causal representation learning. Proceedings of the IEEE, 109(5):612–634, 2021.
Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., Guez, A., Lockhart, E., Hassabis, D., Graepel, T., et al. Mastering atari, go, chess and shogi by planning with a learned model. Nature, 588(7839):604–609, 2020.
Seitzer, M., Schölkopf, B., and Martius, G. Causal influence detection for improving efficiency in reinforcement learning. In Advances in Neural Information Processing Systems, 2021.
Sontakke, S. A., Mehrjou, A., Itti, L., and Schölkopf, B. Causal curiosity: Rl agents discovering self-supervised experiments for causal representation learning. In International conference on machine learning, pp. 9848–9858. PMLR, 2021.
Spirtes, P., Glymour, C. N., Scheines, R., and Heckerman, D. Causation, prediction, and search. MIT press, 2000.
Sutton, R. S. and Barto, A. G. Reinforcement learning: An introduction. MIT press, 2018.
Takida, Y., Shibuya, T., Liao, W., Lai, C.-H., Ohmura, J., Uesaka, T., Murata, N., Takahashi, S., Kumakura, T., and Mitsufuji, Y. Sq-vae: Variational bayes on discrete representation with self-annealed stochastic quantization. In International Conference on Machine Learning, pp. 20987–21012. PMLR, 2022.
Tikka, S., Hyttinen, A., and Karvanen, J. Identifying causal effects via context-specific independence relations. Advances in Neural Information Processing Systems, 32:2804–2814, 2019.
Tomar, M., Zhang, A., Calandra, R., Taylor, M. E., and Pineau, J. Model-invariant state abstractions for model-based reinforcement learning. arXiv preprint arXiv:2102.09850, 2021.
Van Den Oord, A., Vinyals, O., et al. Neural discrete representation learning. Advances in neural information processing systems, 30, 2017.
Volodin, S., Wichers, N., and Nixon, J. Resolving spurious correlations in causal models of environments via interventions. arXiv preprint arXiv:2002.05217, 2020.
Wang, Z., Xiao, X., Zhu, Y., and Stone, P. Task-independent causal state abstraction. In Proceedings of the 35th International Conference on Neural Information Processing Systems, Robot Learning workshop, 2021.
Wang, Z., Xiao, X., Xu, Z., Zhu, Y., and Stone, P. Causal dynamics learning for task-independent state abstraction. In International Conference on Machine Learning, pp. 23151–23180. PMLR, 2022.
Wang, Z., Hu, J., Stone, P., and Martín-Martín, R. ELDEN: Exploration via local dependencies. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
Williams, W., Ringer, S., Ash, T., MacLeod, D., Dougherty, J., and Hughes, J. Hierarchical quantized autoencoders. Advances in Neural Information Processing Systems, 33:4524–4535, 2020.

Yang, K., Katcoff, A., and Uhler, C. Characterizing and


learning equivalence classes of causal dags under inter-
ventions. In International Conference on Machine Learn-
ing, pp. 5541–5550. PMLR, 2018.
Yao, W., Chen, G., and Zhang, K. Learning latent causal
dynamics. arXiv preprint arXiv:2202.04828, 2022.

Yoon, J., Wu, Y.-F., Bae, H., and Ahn, S. An investiga-


tion into pre-training object-centric representations for
reinforcement learning. In International Conference on
Machine Learning, pp. 40147–40174. PMLR, 2023.
Zadaianchuk, A., Seitzer, M., and Martius, G. Self-
supervised visual reinforcement learning with object-
centric representations. In International Conference on
Learning Representations, 2021.
Zhang, A., Lipton, Z. C., Pineda, L., Azizzadenesheli, K.,
Anandkumar, A., Itti, L., Pineau, J., and Furlanello, T.
Learning causal state representations of partially observ-
able environments. arXiv preprint arXiv:1906.10437,
2019.
Zhang, A., Lyle, C., Sodhani, S., Filos, A., Kwiatkowska,
M., Pineau, J., Gal, Y., and Precup, D. Invariant causal
prediction for block mdps. In International Conference
on Machine Learning, pp. 11214–11224. PMLR, 2020a.
Zhang, J., Kumor, D., and Bareinboim, E. Causal imita-
tion learning with unobserved confounders. Advances in
neural information processing systems, 33:12263–12274,
2020b.

Zhang, N. L. and Poole, D. On the role of context-specific


independence in probabilistic inference. In 16th Interna-
tional Joint Conference on Artificial Intelligence, IJCAI
1999, Stockholm, Sweden, volume 2, pp. 1288, 1999.
Zholus, A., Ivchenkov, Y., and Panov, A. Factorized world
models for learning causal relationships. In ICLR2022
Workshop on the Elements of Reasoning: Objects, Struc-
ture and Causality, 2022.
Zhu, Y., Wong, J., Mandlekar, A., Martı́n-Martı́n, R., Joshi,
A., Nasiriany, S., and Zhu, Y. robosuite: A modular
simulation framework and benchmark for robot learning.
arXiv preprint arXiv:2009.12293, 2020.

13
Fine-Grained Causal Dynamics Learning

A. Appendix for Preliminary


A.1. Extended Related Work
Recently, incorporating causal reasoning into RL has gained much attention in the community from various perspectives. For
example, causality has been shown to improve off-policy evaluation (Buesing et al., 2019; Oberst & Sontag, 2019), goal-
directed tasks (Nair et al., 2019), credit assignment (Mesnard et al., 2021), robustness (Lyle et al., 2021; Volodin et al., 2020),
policy transfer (Killian et al., 2022), explainability (Madumal et al., 2020), and policy learning with counterfactual data
augmentation (Lu et al., 2020; Pitis et al., 2020; 2022). Causality has also been integrated with bandits (Bareinboim et al.,
2015; Lee & Bareinboim, 2018; 2020), curriculum learning (Li et al., 2024) or imitation learning (Bica et al., 2021; De Haan
et al., 2019; Zhang et al., 2020b; Kumor et al., 2021; Jamshidi et al., 2023) to handle the unobserved confounders and learn
generalizable policies. Another line of work focuses on causal reasoning over high-dimensional visual observations (Lu
et al., 2018; Rezende et al., 2020; Feng et al., 2022; Feng & Magliacane, 2023), e.g., learning sparse and modular dynamics
(Goyal et al., 2021c;b;a), where the representation learning is crucial (Zhang et al., 2019; Sontakke et al., 2021; Tomar et al.,
2021; Schölkopf et al., 2021; Zadaianchuk et al., 2021; Yoon et al., 2023).
Our work falls into the category of incorporating causality into dynamics learning in RL (Mutti et al., 2023), where recent
works have focused on conditional independences between the variables and their global causal relationships (Wang et al.,
2021; 2022; Ding et al., 2022). In contrast, our work incorporates fine-grained causal relationships into dynamics
learning, which remains underexplored in prior work.

A.2. Background on Local Independence Relationship


In this subsection, we provide the background on the local independence relationship. We first describe context-specific
independence (CSI) (Boutilier et al., 2013), which denotes a variable being conditionally independent of others given a
particular context, not the full set of parents in the graph.
Definition 2 (Context-Specific Independence (CSI) (Boutilier et al., 2013), reproduced from Hwang et al. (2023)). $Y$ is said to be contextually independent of $X_B$ given the context $X_A = x_A$ if $P(y \mid x_A, x_B) = P(y \mid x_A)$ holds for all $y \in \mathcal{Y}$ and $x_B \in \mathcal{X}_B$ whenever $P(x_A, x_B) > 0$. This will be denoted by $Y \perp\!\!\!\perp X_B \mid X_A = x_A$.
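For discrete variables, Def. 2 can be checked directly on a conditional probability table. The sketch below is illustrative only: `is_csi` and the noisy-OR-style table `cpt` are hypothetical, and it relies on the fact that if $P(y \mid x_A, x_B)$ is identical for every supported $x_B$, it also equals $P(y \mid x_A)$.

```python
def is_csi(cpt, x_a, xb_values, tol=1e-9):
    """Check Y ⊥⊥ X_B | X_A = x_a on a discrete conditional probability table.

    cpt maps (x_a, x_b) -> {y: P(y | x_a, x_b)}; keys are assumed to exist only
    for configurations with P(x_a, x_b) > 0. If P(y | x_a, x_b) is identical for
    every supported x_b, it also equals P(y | x_a), which is what Def. 2 requires.
    """
    rows = [cpt[(x_a, x_b)] for x_b in xb_values if (x_a, x_b) in cpt]
    if not rows:
        return True  # vacuously true: no supported x_b under this context
    ref = rows[0]
    return all(
        abs(row.get(y, 0.0) - ref.get(y, 0.0)) <= tol
        for row in rows[1:]
        for y in set(ref) | set(row)
    )

# Hypothetical noisy-OR-style table: Y is roughly X_A or X_B.
cpt = {
    (1, 0): {1: 0.9, 0: 0.1},
    (1, 1): {1: 0.9, 0: 0.1},
    (0, 0): {1: 0.1, 0: 0.9},
    (0, 1): {1: 0.9, 0: 0.1},
}
print(is_csi(cpt, x_a=1, xb_values=[0, 1]))  # True: X_B is irrelevant when X_A = 1
print(is_csi(cpt, x_a=0, xb_values=[0, 1]))  # False: X_B matters when X_A = 0
```

The same variable can thus be relevant in one context and irrelevant in another, which is exactly what a single global graph cannot express.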

CSI has been widely studied, especially for discrete variables with low cardinality, e.g., binary variables. Context-set specific independence (CSSI) generalizes the notion of CSI to allow continuous variables.
Definition 3 (Context-Set Specific Independence (CSSI) (Hwang et al., 2023)). Let $X = \{X_1, \dots, X_d\}$ be a non-empty set of the parents of $Y$ in a causal graph, and $\mathcal{E} \subseteq \mathcal{X}$ be an event with a positive probability. $\mathcal{E}$ is said to be a context set which induces context-set specific independence (CSSI) of $X_{A^c}$ from $Y$ if $p(y \mid x_{A^c}, x_A) = p(y \mid x'_{A^c}, x_A)$ holds for every $(x_{A^c}, x_A), (x'_{A^c}, x_A) \in \mathcal{E}$. This will be denoted by $Y \perp\!\!\!\perp X_{A^c} \mid X_A, \mathcal{E}$.

Intuitively, it denotes that the conditional distribution $p(y \mid x) = p(y \mid x_{A^c}, x_A)$ is the same for different values of $x_{A^c}$, for all $x = (x_{A^c}, x_A) \in \mathcal{E}$. In other words, only a subset of the parent variables is sufficient for modeling $p(y \mid x)$ on $\mathcal{E}$.
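For continuous variables, CSSI can at least be probed empirically by sampling pairs of points in $\mathcal{E}$ that share $x_A$ and differ in $x_{A^c}$. The sketch below is a hypothetical toy, not part of the method: `next_position` is a friction-like mechanism that moves an object only when the force exceeds a threshold `MU`, and the mechanism is assumed deterministic (or corrupted by input-independent noise), so comparing noise-free outputs stands in for comparing conditional distributions. A passing sampled check is evidence, not a proof.

```python
import numpy as np

MU = 1.0  # hypothetical friction threshold

def next_position(pos, force):
    """Toy mechanism: the object moves only by the force in excess of MU."""
    excess = np.maximum(np.abs(force) - MU, 0.0)
    return pos + np.sign(force) * excess

def cssi_holds(mechanism, in_event, n_pairs=10_000, seed=0, tol=1e-12):
    """Sampled check of Y ⊥⊥ X_{A^c} | X_A, E with X_A = pos and X_{A^c} = force.

    Draws pairs (pos, f1), (pos, f2) that share X_A, keeps those with both points
    in E, and tests whether the (noise-free) output changes.
    """
    rng = np.random.default_rng(seed)
    pos = rng.uniform(-5.0, 5.0, n_pairs)
    f1 = rng.uniform(-3.0, 3.0, n_pairs)
    f2 = rng.uniform(-3.0, 3.0, n_pairs)
    keep = in_event(pos, f1) & in_event(pos, f2)
    diff = np.abs(mechanism(pos[keep], f1[keep]) - mechanism(pos[keep], f2[keep]))
    return bool(np.all(diff <= tol))

static = lambda pos, f: np.abs(f) <= MU                  # E: force below friction
everywhere = lambda pos, f: np.ones_like(f, dtype=bool)  # E = X

print(cssi_holds(next_position, static))      # True
print(cssi_holds(next_position, everywhere))  # False
```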

A.3. Fine-Grained Causal Relationships in Factored MDP


As mentioned in Sec. 2, we consider a factored MDP where the causal graph is directed bipartite, and we make the standard assumptions in the field needed to properly identify the causal relationships in MBRL (Ding et al., 2022; Wang et al., 2021; 2022; Seitzer et al., 2021; Pitis et al., 2020; 2022).
Assumption 1. We assume the Markov property (Pearl, 2009), faithfulness (Peters et al., 2017), and causal sufficiency (Spirtes et al., 2000).

Recall that $X = \{S_1, \dots, S_N, A_1, \dots, A_M\}$, $Y = \{S'_1, \dots, S'_N\}$, and $Pa(j)$ is the set of parent variables of $S'_j$. Now, we formally define local independence by adapting CSSI to our setting.
Definition 4 (Local Independence). Let $T \subseteq Pa(j)$ and $\mathcal{E} \subseteq \mathcal{X}$ with $p(\mathcal{E}) > 0$. We say the local independence $S'_j \perp\!\!\!\perp X \setminus T \mid T, \mathcal{E}$ holds on $\mathcal{E}$ if $p(s'_j \mid x_{T^c}, x_T) = p(s'_j \mid x'_{T^c}, x_T)$ holds for every $(x_{T^c}, x_T), (x'_{T^c}, x_T) \in \mathcal{E}$.³

It implies that only a subset of the parent variables ($x_T$) is locally relevant on $\mathcal{E}$, and the remaining variables ($x_{T^c}$) are locally irrelevant, i.e., $p(s'_j \mid x)$ is a function of $x_T$ on $\mathcal{E}$. Local independence generalizes conditional independence in the sense that if $S'_j \perp\!\!\!\perp X \setminus T \mid T$ holds, then $S'_j \perp\!\!\!\perp X \setminus T \mid T, \mathcal{E}$ holds for any $\mathcal{E} \subseteq \mathcal{X}$. Throughout the paper, we are concerned with events with positive probability, i.e., $p(\mathcal{E}) > 0$.

³ In subscripts, $T$ denotes the index set of $T$.
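Definition 5. $Pa(j; \mathcal{E})$ is a subset of $Pa(j)$ such that $S'_j \perp\!\!\!\perp X \setminus Pa(j; \mathcal{E}) \mid Pa(j; \mathcal{E}), \mathcal{E}$ holds and $S'_j \not\perp\!\!\!\perp X \setminus T \mid T, \mathcal{E}$ for any $T \subsetneq Pa(j; \mathcal{E})$.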

In other words, $Pa(j; \mathcal{E})$ is a minimal subset of $Pa(j)$ for which the local independence on $\mathcal{E}$ holds. Clearly, $Pa(j; \mathcal{X}) = Pa(j)$, i.e., local independence on $\mathcal{X}$ is equivalent to (global) conditional independence.
An LCG (Def. 1) describes fine-grained causal relationships specific to $\mathcal{E}$. An LCG is always a subgraph of the (global) causal graph, i.e., $\mathcal{G}_{\mathcal{D}} \subseteq \mathcal{G}$, because if a dependency (i.e., an edge) does not exist over the whole domain, it cannot exist under any context. Note that $\mathcal{G}_{\mathcal{X}} = \mathcal{G}$, i.e., local independence and the LCG under $\mathcal{X}$ are equivalent to conditional independence and the CG, respectively.
Analogous to the faithfulness assumption (Peters et al., 2017), which states that no conditional independences other than those entailed by the CG are present, we introduce a similar assumption for LCGs and local independence.
Assumption 2 ($\mathcal{E}$-Faithfulness). For any $\mathcal{E}$, no local independences on $\mathcal{E}$ other than the ones entailed by $\mathcal{G}_{\mathcal{E}}$ are present, i.e., for any $j$, there does not exist any $T$ such that $Pa(j; \mathcal{E}) \setminus T \neq \emptyset$ and $S'_j \perp\!\!\!\perp X \setminus T \mid T, \mathcal{E}$.

Regardless of the $\mathcal{E}$-faithfulness assumption, an LCG always exists because $Pa(j; \mathcal{E})$ always exists. However, such an LCG may not be unique without it (see Hwang et al. (2023, Example 2) for an example). Assumption 2 implies the uniqueness of $Pa(j; \mathcal{E})$ and $\mathcal{G}_{\mathcal{E}}$, and thus it is required to properly identify fine-grained causal relationships between the variables.
Such fine-grained causal relationships are prevalent in the real world. Physical law: to move a static object, a force exceeding frictional resistance must be exerted; otherwise, the object does not move. Logic: consider $A \lor B \lor C$. When $A$ is true, changes in $B$ or $C$ no longer affect the outcome. Biology: in general, smoking has a causal effect on blood pressure. However, one's blood pressure becomes independent of smoking if the ratio of alpha and beta lipoproteins is larger than a certain threshold (Edwards & Toma, 1985).
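The logic example can be turned into a brute-force computation of $Pa(j; \mathcal{E})$ (Def. 5). The sketch below assumes a deterministic mechanism over a small finite domain, where "$p(s'_j \mid x)$ is a function of $x_T$ on $\mathcal{E}$" reduces to checking that equal $x_T$ always gives equal outputs; the helper names are hypothetical, and the exhaustive subset search is only feasible for toy problems.

```python
from itertools import combinations, product

def depends_only_on(f, event, T):
    """True if, restricted to `event`, f(x) is a function of x_T alone."""
    seen = {}
    for x in event:
        key = tuple(x[i] for i in T)
        if seen.setdefault(key, f(x)) != f(x):
            return False
    return True

def local_parents(f, event, n_vars):
    """Brute-force Pa(j; E) for a deterministic mechanism f on a finite event.

    Returns the first smallest index set T such that f is a function of x_T on
    the event; under Assumption 2 the minimal set is unique.
    """
    for size in range(n_vars + 1):
        for T in combinations(range(n_vars), size):
            if depends_only_on(f, event, T):
                return set(T)

f_or = lambda x: x[0] or x[1] or x[2]       # S'_j = A ∨ B ∨ C
domain = list(product([0, 1], repeat=3))    # the whole space X
a_true = [x for x in domain if x[0] == 1]   # event E = {A = 1}
a_false = [x for x in domain if x[0] == 0]  # event E = {A = 0}

print(local_parents(f_or, domain, 3))   # {0, 1, 2}: global parents
print(local_parents(f_or, a_true, 3))   # set(): output is constant on E
print(local_parents(f_or, a_false, 3))  # {1, 2}: only B and C matter
```

Each local parent set is contained in the global one, consistent with an LCG always being a subgraph of the global causal graph.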

B. Appendix for Method and Theoretical Analysis


B.1. Fine-Grained Dynamics Modeling
With an arbitrary decomposition $\{\mathcal{E}_z\}_{z=1}^K$, the true transition dynamics $p(s' \mid s, a)$ can be written as:
$$
p(s' \mid s, a) = \sum_z p(s' \mid s, a, z)\, p(z \mid s, a) = \sum_z \prod_j p(s'_j \mid Pa(j; \mathcal{E}_z), z)\, \mathbb{1}\{(s,a) \in \mathcal{E}_z\}, \tag{7}
$$
where $p(z \mid s, a) = 1$ if $(s, a) \in \mathcal{E}_z$ and $0$ otherwise. This illustrates our approach to dynamics modeling based on fine-grained causal dependencies: $p(s'_j \mid s, a)$ is a function of $Pa(j; \mathcal{E}_z)$ on $\mathcal{E}_z$, and our dynamics model employs the locally relevant dependencies $Pa(j; \mathcal{E}_z)$ for predicting $S'_j$. Our dynamics modeling with some graphs $\{\mathcal{G}_z\}_{z=1}^K$ is:

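$$
\hat p(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi) = \sum_z \hat p(s' \mid s, a; \mathcal{G}_z, \phi_z)\, \mathbb{1}\{(s,a) \in \mathcal{E}_z\} = \sum_z \prod_j \hat p_j\big(s'_j \mid Pa_{\mathcal{G}_z}(j); \phi_z^{(j)}\big)\, \mathbb{1}\{(s,a) \in \mathcal{E}_z\}, \tag{8}
$$
where $\phi_z^{(j)}$ takes $Pa_{\mathcal{G}_z}(j)$ as an input and outputs the parameters of the density function $\hat p_j$, and $\phi := \{\phi_z^{(j)}\}$. We denote $\hat p_{\{\mathcal{G}_z, \mathcal{E}_z\}, \phi} := \hat p(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi)$ and $\hat p_{\mathcal{G}_z, \phi_z} := \hat p(s' \mid s, a; \mathcal{G}_z, \phi_z)$. In other words, $\hat p_{\{\mathcal{G}_z, \mathcal{E}_z\}, \phi}(s' \mid s, a) = \hat p_{\mathcal{G}_z, \phi_z}(s' \mid s, a)$ if $(s, a) \in \mathcal{E}_z$.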
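Because the events $\mathcal{E}_z$ partition the domain, the indicator in Eq. (8) simply selects one graph and one set of heads per sample. The NumPy sketch below evaluates $\log \hat p(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi)$ for a batch under illustrative assumptions: every variable is a scalar, the subgroup index `z` of each sample is given, and each $\phi_z^{(j)}$ is a hypothetical linear-Gaussian head with fixed standard deviation (the actual networks in the paper are MLPs).

```python
import numpy as np

def log_prob_fine_grained(x, s_next, z, masks, W, b, sigma=1.0):
    """log p̂(s'|s,a; {G_z, E_z}, φ) for a batch, following Eq. (8).

    x:      (B, D) concatenated (s, a), one scalar per variable (D = N + M)
    s_next: (B, N) next states
    z:      (B,)   index of the subgroup E_z containing each (s, a)
    masks:  (K, D, N) boolean adjacency; masks[z, i, j] = edge X_i -> S'_j in G_z
    W, b:   (K, N, D), (K, N) parameters of hypothetical linear-Gaussian heads φ_z^(j)
    """
    m = masks[z]                    # (B, D, N): graph of each sample's subgroup
    x_masked = x[:, :, None] * m    # zero out non-parents of each S'_j
    mean = np.einsum("bdn,bnd->bn", x_masked, W[z]) + b[z]
    resid = (s_next - mean) / sigma
    log_pdf = -0.5 * resid**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)
    return log_pdf.sum(axis=1)      # product over j becomes a sum of logs

rng = np.random.default_rng(0)
B, N, M, K = 4, 2, 1, 2
D = N + M
x = rng.normal(size=(B, D))
s_next = rng.normal(size=(B, N))
z = np.array([0, 1, 0, 1])
masks = np.ones((K, D, N), dtype=bool)
masks[1, 0, :] = False              # in G_1, variable 0 is nobody's parent
W = rng.normal(size=(K, N, D))
b = np.zeros((K, N))
lp = log_prob_fine_grained(x, s_next, z, masks, W, b)
x2 = x.copy()
x2[:, 0] += 10.0                    # perturb variable 0 only
lp2 = log_prob_fine_grained(x2, s_next, z, masks, W, b)
print(np.allclose(lp[z == 1], lp2[z == 1]))  # True: irrelevant under G_1
```

Perturbing a variable that is absent from $\mathcal{G}_1$ leaves the likelihood of subgroup-1 samples unchanged, which is the intended effect of predicting from locally relevant parents only.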
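Now, we revisit the score function in Eq. (4):
$$
\begin{aligned}
\mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)
&:= \sup_{\phi}\, \mathbb{E}_{p(s,a,s')}\big[\log \hat p(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi) - \lambda |\mathcal{G}_z|\big] && (9)\\
&= \sup_{\phi}\, \mathbb{E}_{p(s,a)}\Big[\mathbb{E}_{p(s' \mid s,a)}\big[\log \hat p(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi)\big] - \lambda |\mathcal{G}_z|\Big] && (10)\\
&= \sup_{\phi} \sum_z \int_{(s,a) \in \mathcal{E}_z} p(s,a) \Big(\mathbb{E}_{p(s' \mid s,a)}\big[\log \hat p(s' \mid s, a; \mathcal{G}_z, \phi_z)\big] - \lambda |\mathcal{G}_z|\Big) && (11)\\
&= \sup_{\phi} \sum_z \Big[\int_{(s,a) \in \mathcal{E}_z} p(s,a)\, \mathbb{E}_{p(s' \mid s,a)}\big[\log \hat p(s' \mid s, a; \mathcal{G}_z, \phi_z)\big] - \lambda \cdot p(\mathcal{E}_z) \cdot |\mathcal{G}_z|\Big], && (12)
\end{aligned}
$$
where $\hat p(s' \mid s, a; \mathcal{G}_z, \phi_z) = \prod_j \hat p_j(s'_j \mid Pa_{\mathcal{G}_z}(j); \phi_z^{(j)})$, and in Eqs. (9) and (10), $z$ denotes the index of the subgroup containing $(s, a)$.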
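In practice, Eq. (12) is estimated from samples. A minimal sketch, assuming the parameters $\phi$ have already been fitted (so the supremum is replaced by the fitted model) and that each transition's log-likelihood under its own subgroup's model is available; the function names and inputs are hypothetical. The two functions compute the same estimate in the forms of Eq. (12) and Eq. (9).

```python
import numpy as np

def empirical_score(log_lik, z, graph_sizes, lam):
    """Sample estimate of Eq. (12) for already-fitted parameters φ.

    log_lik:     (B,) log p̂(s' | s, a; G_z, φ_z) of each transition under its own subgroup
    z:           (B,) index of the subgroup E_z that contains each (s, a)
    graph_sizes: (K,) number of edges |G_z| of each graph
    lam:         sparsity weight λ
    """
    log_lik = np.asarray(log_lik, dtype=float)
    z = np.asarray(z)
    sizes = np.asarray(graph_sizes, dtype=float)
    p_event = np.bincount(z, minlength=len(sizes)) / len(z)  # empirical p(E_z)
    return log_lik.mean() - lam * np.dot(p_event, sizes)    # λ Σ_z p(E_z)|G_z|

def empirical_score_per_sample(log_lik, z, graph_sizes, lam):
    """Same estimate written as in Eq. (9): mean of log p̂ - λ|G_z| per sample."""
    sizes = np.asarray(graph_sizes, dtype=float)[np.asarray(z)]
    return np.mean(np.asarray(log_lik, dtype=float) - lam * sizes)

ll = [-1.0, -1.0, -1.0, -1.0]
score = empirical_score(ll, z=[0, 0, 1, 1], graph_sizes=[3, 1], lam=0.1)
# -1.0 - 0.1 * (0.5 * 3 + 0.5 * 1) = -1.2
```

The penalty is weighted by how much probability mass each subgroup covers, so a dense graph on a rarely visited subgroup costs less than the same graph on a frequent one.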

B.2. Proof of Thm. 1


Due to the nature of the factored MDP, where the causal graph is directed bipartite, each Markov equivalence class (MEC) constrained under temporal precedence contains a unique causal graph (i.e., a skeleton determines a unique causal graph since temporal precedence fully orients the edges). Given this background, it is known that the causal graph is uniquely identifiable with an oracle conditional independence test (Ding et al., 2022) or a score-based method (Brouillard et al., 2020).
We now show that LCGs are also uniquely identifiable via score maximization. Our proof techniques build upon Brouillard et al. (2020). It is worth noting that they provide the identifiability of the (global) CG up to I-MEC (Yang et al., 2018) by utilizing observational and interventional data. In contrast, our analysis concerns the identifiability of LCGs using only observational data. We start by adopting some assumptions from Brouillard et al. (2020).
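Assumption 3. The ground truth density $p(s' \mid s, a) \in \mathcal{H}(\{\mathcal{G}_z^*, \mathcal{E}_z\})$ for any decomposition $\{\mathcal{E}_z\}$ with corresponding true LCGs $\{\mathcal{G}_z^*\}$, where $\mathcal{H}(\{\mathcal{G}_z^*, \mathcal{E}_z\}) := \{p \mid \exists \phi,\ p = \hat p_{\{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi}\}$. We assume the density $\hat p_{\{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi}$ is strictly positive for all $\phi$.
Definition 6. For a graph $\mathcal{G}$ and $\mathcal{E} \subset \mathcal{X}$, let $\mathcal{F}_{\mathcal{E}}(\mathcal{G})$ be the set of conditional densities $f$ such that $f(s' \mid s, a) = \prod_j f_j(s'_j \mid Pa_{\mathcal{G}}(j))$ for all $(s, a) \in \mathcal{E}$, where each $f_j$ is a conditional density.

Assumption 4. $\left|\mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a)\right| < \infty$.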

Assumption 3 states that the model parametrized by a neural network has sufficient capacity to represent the ground truth density. Assumption 4 is a technical tool for handling score differences, as we will see later.
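Lemma 1. Let $\mathcal{G}_z^*$ be the true LCG on $\mathcal{E}_z$ for all $z$. Then, $\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z^*|\big]$, where $\mathbb{E}\big[|\mathcal{G}_z^*|\big] := \sum_z p(\mathcal{E}_z)\, |\mathcal{G}_z^*|$.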

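Proof. First,
$$
0 \le D_{KL}\big(p \,\|\, \hat p_{\{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi}\big) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \mathbb{E}_{p(s,a,s')} \log \hat p(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi), \tag{13}
$$
where the equality holds because $\mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) < \infty$ by Assumption 4. Therefore,
$$
\sup_{\phi} \mathbb{E}_{p(s,a,s')} \log \hat p(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi) \le \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a). \tag{14}
$$
On the other hand, by Assumption 3, there exists $\phi^*$ such that $p = \hat p_{\{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi^*}$. Hence,
$$
\sup_{\phi} \mathbb{E}_{p(s,a,s')} \log \hat p(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi) \ge \mathbb{E}_{p(s,a,s')} \log \hat p(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi^*) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a). \tag{15}
$$
By Eqs. (14) and (15), we have $\sup_{\phi} \mathbb{E}_{p(s,a,s')} \log \hat p(s' \mid s, a; \{\mathcal{G}_z^*, \mathcal{E}_z\}, \phi) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a)$. Therefore, we have $\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z^*|\big]$.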

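Corollary 1. Let $\mathcal{G}_z^*$ be the true LCG on $\mathcal{E}_z$ for all $z$. Then, $\big|\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K)\big| < \infty$.

Proof. By Lemma 1, $\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z^*|\big]$. Since $\big|\mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a)\big| < \infty$ by Assumption 4 and $|\mathcal{G}_z^*| \le N(N+M)$, we conclude that $\big|\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K)\big| < \infty$.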
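Lemma 2. Let $\mathcal{G}_z^*$ be the true LCG on $\mathcal{E}_z$ for all $z$. Then,
$$
\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K) - \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K) = \inf_{\phi} D_{KL}\big(p \,\|\, \hat p_{\{\mathcal{G}_z, \mathcal{E}_z\}, \phi}\big) + \lambda \sum_z p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big). \tag{16}
$$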

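Proof. First, we can rewrite the score $\mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ as:
$$
\begin{aligned}
\mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)
&= \sup_{\phi} \mathbb{E}_{p(s,a,s')} \log \hat p(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi) - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z|\big] && (17)\\
&= -\inf_{\phi} \Big[-\mathbb{E}_{p(s,a,s')} \log \hat p(s' \mid s, a; \{\mathcal{G}_z, \mathcal{E}_z\}, \phi)\Big] - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z|\big] && (18)\\
&= -\inf_{\phi} D_{KL}\big(p \,\|\, \hat p_{\{\mathcal{G}_z, \mathcal{E}_z\}, \phi}\big) + \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \cdot \mathbb{E}\big[|\mathcal{G}_z|\big]. && (19)
\end{aligned}
$$
The last equality holds by Assumption 4. Subtracting Eq. (19) from Lemma 1, we obtain:
$$
\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K) - \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K) = \inf_{\phi} D_{KL}\big(p \,\|\, \hat p_{\{\mathcal{G}_z, \mathcal{E}_z\}, \phi}\big) + \lambda \sum_z p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big).
$$
Note that $\big|\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K)\big| < \infty$ by Corollary 1, and thus this score difference is well defined.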

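Lemma 3 (Modified from Brouillard et al. (2020), Lemma 16). If $p \notin \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$, then
$$
\inf_{\phi_z} \int p_z(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, \hat p_{\mathcal{G}_z, \phi_z}(\cdot \mid s, a)\big) > 0, \tag{20}
$$
where $p_z(s, a) := p(s, a \mid z) = p(s, a)/p(\mathcal{E}_z)$ for all $(s, a) \in \mathcal{E}_z$, i.e., the density function of the distribution $P_{\mathcal{S} \times \mathcal{A} \mid \mathcal{E}_z}$.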

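Proof. First, since $\hat p_{\mathcal{G}_z, \phi_z} \in \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$ for all $\phi_z$,
$$
\inf_{\phi_z} \int p_z(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, \hat p_{\mathcal{G}_z, \phi_z}(\cdot \mid s, a)\big) \ge \inf_{f \in \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)} \int p_z(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, f(\cdot \mid s, a)\big). \tag{21}
$$
Now, let $\hat f(s' \mid s, a) := \prod_j p_z(s'_j \mid Pa_{\mathcal{G}_z}(j))$ for all $(s, a) \in \mathcal{E}_z$. Then, for any $f \in \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$,
$$
\begin{aligned}
\int p_z(s, a) \int p(s' \mid s, a) \log \frac{\hat f(s' \mid s, a)}{f(s' \mid s, a)}
&= \int p_z(s, a, s') \sum_j \log \frac{p_z(s'_j \mid Pa_{\mathcal{G}_z}(j))}{f_j(s'_j \mid Pa_{\mathcal{G}_z}(j))} \\
&= \sum_j \int p_z(s, a, s') \log \frac{p_z(s'_j \mid Pa_{\mathcal{G}_z}(j))}{f_j(s'_j \mid Pa_{\mathcal{G}_z}(j))} \\
&= \sum_j \int p_z(Pa_{\mathcal{G}_z}(j)) \int p_z(s'_j \mid Pa_{\mathcal{G}_z}(j)) \log \frac{p_z(s'_j \mid Pa_{\mathcal{G}_z}(j))}{f_j(s'_j \mid Pa_{\mathcal{G}_z}(j))} \ge 0. && (22)
\end{aligned}
$$
Therefore, for any $f \in \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$,
$$
\begin{aligned}
\int p_z(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, f(\cdot \mid s, a)\big)
&= \int p_z(s, a) \int p(s' \mid s, a) \log \left(\frac{p(s' \mid s, a)}{\hat f(s' \mid s, a)} \cdot \frac{\hat f(s' \mid s, a)}{f(s' \mid s, a)}\right) \\
&= \int p_z(s, a)\, D_{KL}(p \,\|\, \hat f) + \int p_z(s, a) \int p(s' \mid s, a) \log \frac{\hat f(s' \mid s, a)}{f(s' \mid s, a)} \\
&\ge \int p_z(s, a)\, D_{KL}(p \,\|\, \hat f).
\end{aligned}
$$
Here, the last inequality holds by Eq. (22). Therefore,
$$
\inf_{f \in \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)} \int p_z(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, f(\cdot \mid s, a)\big) = \int p_z(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, \hat f(\cdot \mid s, a)\big) > 0. \tag{23}
$$
Here, the last inequality holds because $\hat f \in \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$ and $p \notin \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$, and thus $p \neq \hat f$. By Eqs. (21) and (23), the proof is complete.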

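Theorem 1 (Identifiability of LCGs). With Assumptions 1 to 4, let $\{\hat{\mathcal{G}}_z\} \in \operatorname{argmax} \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ for $\lambda > 0$ small enough. Then, each $\hat{\mathcal{G}}_z$ is the true LCG on $\mathcal{E}_z$, i.e., $\hat{\mathcal{G}}_z = \mathcal{G}_{\mathcal{E}_z}$.

Proof. To simplify notation, let $\mathcal{G}_z^* := \mathcal{G}_{\mathcal{E}_z}$ denote the true LCG on $\mathcal{E}_z$ for all $z$. It is enough to show that $\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K) > \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ if $\mathcal{G}_z \neq \mathcal{G}_z^*$ for some $z$.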


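Now, by Lemma 2,
$$
\begin{aligned}
&\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K) - \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K) && (24)\\
&= \inf_{\phi} D_{KL}\big(p \,\|\, \hat p_{\{\mathcal{G}_z, \mathcal{E}_z\}, \phi}\big) + \lambda \sum_z p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (25)\\
&= \inf_{\phi} \int p(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, \hat p_{\{\mathcal{G}_z, \mathcal{E}_z\}, \phi}(\cdot \mid s, a)\big) + \lambda \sum_z p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (26)\\
&= \inf_{\phi} \sum_z \int_{(s,a) \in \mathcal{E}_z} p(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, \hat p_{\mathcal{G}_z, \phi_z}(\cdot \mid s, a)\big) + \lambda \sum_z p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (27)\\
&= \sum_z p(\mathcal{E}_z) \inf_{\phi_z} \int p_z(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, \hat p_{\mathcal{G}_z, \phi_z}(\cdot \mid s, a)\big) + \lambda \sum_z p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (28)\\
&= \sum_z p(\mathcal{E}_z) \inf_{\phi_z} D_{KL}\big(p_z \,\|\, \hat p_{\mathcal{G}_z, \phi_z}\big) + \lambda \sum_z p(\mathcal{E}_z)\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) && (29)\\
&= \sum_z p(\mathcal{E}_z) \Big[\inf_{\phi_z} D_{KL}\big(p_z \,\|\, \hat p_{\mathcal{G}_z, \phi_z}\big) + \lambda\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big)\Big] = \sum_z p(\mathcal{E}_z) \cdot A_z. && (30)
\end{aligned}
$$
For brevity, we denote $D_{KL}(p_z \,\|\, \hat p_{\mathcal{G}_z, \phi_z}) := \int p_z(s, a)\, D_{KL}\big(p(\cdot \mid s, a) \,\|\, \hat p_{\mathcal{G}_z, \phi_z}(\cdot \mid s, a)\big)$ and $A_z := \inf_{\phi_z} D_{KL}(p_z \,\|\, \hat p_{\mathcal{G}_z, \phi_z}) + \lambda\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big)$. Now, we will show that for all $z \in [K]$, $A_z > 0$ if and only if $\mathcal{G}_z^* \neq \mathcal{G}_z$.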
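Case 0: $\mathcal{G}_z^* = \mathcal{G}_z$. Clearly, $A_z = 0$ in this case.
Case 1: $\mathcal{G}_z^* \subsetneq \mathcal{G}_z$. Then, $|\mathcal{G}_z| > |\mathcal{G}_z^*|$, and thus $A_z > 0$ since $\lambda\big(|\mathcal{G}_z| - |\mathcal{G}_z^*|\big) > 0$.
Case 2: $\mathcal{G}_z^* \not\subseteq \mathcal{G}_z$. In this case, there exists $(i \to j) \in \mathcal{G}_z^*$ such that $(i \to j) \notin \mathcal{G}_z$. Thus, $S'_j \perp\!\!\!\perp_{\mathcal{G}_z} X_i \mid X \setminus \{X_i\}$ and $S'_j \not\perp\!\!\!\perp_{\mathcal{G}_z^*} X_i \mid X \setminus \{X_i\}$. Therefore, $S'_j \not\perp\!\!\!\perp_p X_i \mid X \setminus \{X_i\}, \mathcal{E}_z$ by Assumption 2. Thus, $p \notin \mathcal{F}_{\mathcal{E}_z}(\mathcal{G}_z)$, and we have $\inf_{\phi_z} D_{KL}(p_z \,\|\, \hat p_{\mathcal{G}_z, \phi_z}) > 0$ by Lemma 3.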
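Now, we consider two subcases: (i) $\mathcal{G}_z \in \mathcal{G}_z^+ := \{\mathcal{G}' \mid \mathcal{G}_z^* \not\subseteq \mathcal{G}',\ |\mathcal{G}'| \ge |\mathcal{G}_z^*|\}$, and (ii) $\mathcal{G}_z \in \mathcal{G}_z^- := \{\mathcal{G}' \mid \mathcal{G}_z^* \not\subseteq \mathcal{G}',\ |\mathcal{G}'| < |\mathcal{G}_z^*|\}$. Clearly, if $\mathcal{G}_z \in \mathcal{G}_z^+$, then $A_z > 0$. Suppose $\mathcal{G}_z \in \mathcal{G}_z^-$. Then,
$$
\begin{aligned}
& \lambda \le \eta_z := \frac{1}{N(N+M)+1} \min_{\mathcal{G}' \in \mathcal{G}_z^-} \inf_{\phi_z} D_{KL}\big(p_z \,\|\, \hat p_{\mathcal{G}', \phi_z}\big) && (31)\\
\implies\ & \lambda \le \frac{\inf_{\phi_z} D_{KL}(p_z \,\|\, \hat p_{\mathcal{G}', \phi_z})}{N(N+M)+1} < \frac{\inf_{\phi_z} D_{KL}(p_z \,\|\, \hat p_{\mathcal{G}', \phi_z})}{|\mathcal{G}_z^*| - |\mathcal{G}'|} \quad \text{for all } \mathcal{G}' \in \mathcal{G}_z^- && (32)\\
\implies\ & \inf_{\phi_z} D_{KL}\big(p_z \,\|\, \hat p_{\mathcal{G}', \phi_z}\big) + \lambda\big(|\mathcal{G}'| - |\mathcal{G}_z^*|\big) > 0 \quad \text{for all } \mathcal{G}' \in \mathcal{G}_z^-. && (33)
\end{aligned}
$$
Here, we use the fact that $0 < |\mathcal{G}_z^*| - |\mathcal{G}'| \le |\mathcal{G}_z^*| < N(N+M)+1$. Therefore, for all $0 < \lambda \le \eta_z$, we have $A_z > 0$ if $\mathcal{G}_z^* \neq \mathcal{G}_z$. We note that $\eta_z > 0$ for all $z$, since $\mathcal{G}_z^-$ is finite and $\inf_{\phi_z} D_{KL}(p_z \,\|\, \hat p_{\mathcal{G}', \phi_z}) > 0$ for any $\mathcal{G}' \in \mathcal{G}_z^-$ by Lemma 3.

Consequently, for $0 < \lambda \le \eta(\{\mathcal{E}_z\}) := \min_z \eta_z$, we have $\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z\}_{z=1}^K) - \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K) > 0$ if $\mathcal{G}_z \neq \mathcal{G}_z^*$ for some $z$. We also note that $\eta(\{\mathcal{E}_z\}) > 0$ since $\eta_z > 0$ for all $z$.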
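The role of the threshold $\eta_z$ in Eq. (31) can be seen with a few numbers. In the sketch below, graphs are frozensets of edges $(i, j)$ meaning $X_i \to S'_j$, and the infimal KL values are hypothetical placeholders (in practice they are not directly computable); supersets of $\mathcal{G}_z^*$ get zero KL because $p$ remains representable. The code computes $\eta_z$, checks that $A_z > 0$ for every wrong candidate when $\lambda = \eta_z$, and shows that a larger $\lambda$ can make a sparser wrong graph win.

```python
def eta_z(true_graph, inf_kl, n_state, n_action):
    """η_z in Eq. (31): smallest inf-KL over G^-_z, divided by N(N+M)+1.

    true_graph: frozenset of edges (i, j), meaning X_i -> S'_j, i.e. G*_z
    inf_kl:     dict {candidate graph: inf_φ D_KL(p_z || p̂_{G', φ})}
    """
    minus = [g for g in inf_kl
             if not true_graph <= g and len(g) < len(true_graph)]  # G^-_z
    if not minus:
        return float("inf")  # empty G^-_z imposes no upper bound on λ
    return min(inf_kl[g] for g in minus) / (n_state * (n_state + n_action) + 1)

def a_z(graph, true_graph, inf_kl, lam):
    """A_z = inf KL + λ(|G_z| - |G*_z|), as defined after Eq. (30)."""
    return inf_kl[graph] + lam * (len(graph) - len(true_graph))

true_g = frozenset({(0, 0), (2, 1)})   # N = 2 state variables, M = 1 action
inf_kl = {
    true_g: 0.0,                       # Case 0
    true_g | {(1, 0)}: 0.0,            # Case 1: superset, p still representable
    frozenset({(0, 0)}): 0.05,         # Case 2, in G^-_z (hypothetical KL values)
    frozenset({(0, 0), (1, 1)}): 0.3,  # Case 2, in G^+_z
    frozenset(): 0.4,                  # Case 2, in G^-_z
}
eta = eta_z(true_g, inf_kl, n_state=2, n_action=1)  # 0.05 / 7
for g in inf_kl:
    assert g == true_g or a_z(g, true_g, inf_kl, eta) > 0
# With λ = 0.1 > η_z, the sparser wrong graph {(0, 0)} would beat G*_z:
print(a_z(frozenset({(0, 0)}), true_g, inf_kl, lam=0.1) < 0)  # True
```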

B.3. Proof of Prop. 1


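Definition 7. Let $\mathcal{T} := \{\{\mathcal{E}_z\}_{z=1}^K\}$, i.e., the set of all decompositions of size $K$.
Definition 8. Let $\mathcal{T}_\lambda := \{\{\mathcal{E}_z\}_{z=1}^K \mid \eta(\{\mathcal{E}_z\}) \ge \lambda\}$.
Remark 1. $\mathcal{T}_\lambda \to \mathcal{T}\ (= \mathcal{T}_0)$ as $\lambda \to 0$.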

Recall that Thm. 1 holds for $0 < \lambda \le \eta(\{\mathcal{E}_z\})$. Here, $\eta(\{\mathcal{E}_z\})$ is the value corresponding to the specific decomposition $\{\mathcal{E}_z\}$. For the arguments henceforth, we consider arbitrary decompositions and thus introduce the following assumption.
Assumption 5. $\inf_{\{\mathcal{E}_z\} \in \mathcal{T}} \eta(\{\mathcal{E}_z\}) > 0$.

Note that $\eta(\{\mathcal{E}_z\}) > 0$ for any $\{\mathcal{E}_z\}$, and thus $\inf_{\{\mathcal{E}_z\} \in \mathcal{T}} \eta(\{\mathcal{E}_z\}) \ge 0$. We now take $0 < \lambda \le \inf_{\{\mathcal{E}_z\} \in \mathcal{T}} \eta(\{\mathcal{E}_z\})$ under Assumption 5, which allows Thm. 1 to hold on any decomposition. It is worth noting that this assumption is purely technical because, for a small fixed $\lambda > 0$, the arguments henceforth hold for all $\{\mathcal{E}_z\} \in \mathcal{T}_\lambda$, where $\mathcal{T}_\lambda \to \mathcal{T}$ as $\lambda \to 0$.


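Proposition 1. Let $\{\mathcal{G}_z^*, \mathcal{E}_z^*\} \in \operatorname{argmax} \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ for $\lambda > 0$ small enough, with Assumptions 1 to 5. Then, (i) each $\mathcal{G}_z^*$ is the true LCG on $\mathcal{E}_z^*$, and (ii) $\mathbb{E}\big[|\mathcal{G}_z^*|\big] \le \mathbb{E}\big[|\mathcal{G}_z|\big]$, where $\{\mathcal{G}_z\}$ are the LCGs on an arbitrary decomposition $\{\mathcal{E}_z\}_{z=1}^K$.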

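Proof. Let $0 < \lambda \le \inf_{\{\mathcal{E}_z\} \in \mathcal{T}} \eta(\{\mathcal{E}_z\})$. (i) First, $\{\mathcal{G}_z^*, \mathcal{E}_z^*\}_{z=1}^K \in \operatorname{argmax}_{\{\mathcal{G}_z, \mathcal{E}_z\}} \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ implies that $\{\mathcal{G}_z^*\}_{z=1}^K$ also maximizes the score on the fixed $\{\mathcal{E}_z^*\}_{z=1}^K$, i.e., $\{\mathcal{G}_z^*\}_{z=1}^K \in \operatorname{argmax}_{\{\mathcal{G}_z\}} \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z^*\}_{z=1}^K)$. Thus, each $\mathcal{G}_z^*$ is the true LCG on $\mathcal{E}_z^*$ by Thm. 1, i.e., $\mathcal{G}_z^* = \mathcal{G}_{\mathcal{E}_z^*}$.
(ii) Since $\{\mathcal{G}_z^*, \mathcal{E}_z^*\}$ maximizes the score and $\{\mathcal{E}_z\}_{z=1}^K$ is an arbitrary decomposition, $\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z^*\}) \ge \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\})$ holds. Since $\{\mathcal{G}_z\}$ are the true LCGs on each $\mathcal{E}_z$, i.e., $\mathcal{G}_z = \mathcal{G}_{\mathcal{E}_z}$, by Lemma 1,
$$
\mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \sum_z p(\mathcal{E}_z) \cdot |\mathcal{G}_z|. \tag{34}
$$
Similarly, since $\{\mathcal{G}_z^*\}$ are the true LCGs on each $\mathcal{E}_z^*$,
$$
\mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z^*\}_{z=1}^K) = \mathbb{E}_{p(s,a,s')} \log p(s' \mid s, a) - \lambda \sum_z p(\mathcal{E}_z^*) \cdot |\mathcal{G}_z^*|. \tag{35}
$$
Therefore, $0 \le \mathcal{S}(\{\mathcal{G}_z^*, \mathcal{E}_z^*\}) - \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}) = \lambda\big(\mathbb{E}[|\mathcal{G}_z|] - \mathbb{E}[|\mathcal{G}_z^*|]\big)$ holds, and since $\lambda > 0$, $\mathbb{E}\big[|\mathcal{G}_z^*|\big] \le \mathbb{E}\big[|\mathcal{G}_z|\big]$.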

B.4. Proof of Thm. 2


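We first provide some useful lemmas.
Lemma 4 (Hwang et al. (2023), Prop. 4). $S'_j \perp\!\!\!\perp X \setminus Pa(j; \mathcal{E}) \mid Pa(j; \mathcal{E}), \mathcal{F}$ holds for any $\mathcal{F} \subseteq \mathcal{E}$.
Lemma 5 (Monotonicity). Let $\mathcal{F} \subseteq \mathcal{E}$. Then, $\mathcal{G}_{\mathcal{F}} \subseteq \mathcal{G}_{\mathcal{E}}$.

Proof. Since $S'_j \perp\!\!\!\perp X \setminus Pa(j; \mathcal{E}) \mid Pa(j; \mathcal{E}), \mathcal{F}$ holds by Lemma 4, $Pa(j; \mathcal{F}) \subseteq Pa(j; \mathcal{E})$ holds by definition; otherwise, $Pa(j; \mathcal{F}) \setminus Pa(j; \mathcal{E}) \neq \emptyset$, which leads to a contradiction. Therefore, $Pa(j; \mathcal{F}) \subseteq Pa(j; \mathcal{E})$ for all $j$, and thus $\mathcal{G}_{\mathcal{F}} \subseteq \mathcal{G}_{\mathcal{E}}$.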

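Now, we provide a proof of Thm. 2.

Definition 9. The context $\mathcal{D} \subset \mathcal{X}$ is canonical if $\mathcal{G}_{\mathcal{F}} = \mathcal{G}_{\mathcal{D}}$ for any $\mathcal{F} \subset \mathcal{D}$.
Theorem 2 (Identifiability of contexts). Let $\{\mathcal{G}_z^*, \mathcal{E}_z^*\} \in \operatorname{argmax} \mathcal{S}(\{\mathcal{G}_z, \mathcal{E}_z\}_{z=1}^K)$ for $\lambda > 0$ small enough, with Assumptions 1 to 5. Suppose $\mathcal{X} = \bigcup_{m \in [H]} \mathcal{D}_m$, where $\mathcal{G}_{\mathcal{D}_m}$ is distinct for all $m \in [H]$, and $\mathcal{D}_1, \dots, \mathcal{D}_H$ are disjoint and canonical. Suppose $K \ge H$. Then, for all $m \in [H]$, there exists $I_m \subset [K]$ such that $\mathcal{D}_m = \bigcup_{z \in I_m} \mathcal{E}_z^*$ almost surely.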

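Proof. Let $\{\mathcal{F}_z\}_{z=1}^K$ be a decomposition such that for all $m \in [H]$, $\bigcup_{z \in J_m} \mathcal{F}_z = \mathcal{D}_m$ for some $J_m \subset [K]$. Note that such a decomposition exists since $K \ge H$. Let $\{\mathcal{G}_z\}_{z=1}^K$ be the true LCGs corresponding to each $\mathcal{F}_z$, i.e., $\mathcal{G}_z = \mathcal{G}_{\mathcal{F}_z}$. Recalling that $\mathbb{E}\big[|\mathcal{G}_z^*|\big] \le \mathbb{E}\big[|\mathcal{G}_z|\big]$ holds by Prop. 1, we have
$$
\begin{aligned}
0 \le \mathbb{E}\big[|\mathcal{G}_z|\big] - \mathbb{E}\big[|\mathcal{G}_z^*|\big] &= \sum_i p(\mathcal{F}_i)\, |\mathcal{G}_i| - \sum_j p(\mathcal{E}_j^*)\, |\mathcal{G}_j^*| \\
&= \sum_{i,j} p(\mathcal{F}_i \cap \mathcal{E}_j^*)\big(|\mathcal{G}_i| - |\mathcal{G}_j^*|\big). && (36)
\end{aligned}
$$
Suppose $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$ for some $i, j$. Let $\mathcal{C}_{ij} := \mathcal{F}_i \cap \mathcal{E}_j^*$. Since $\mathcal{F}_i \subset \mathcal{D}_m$ for some $m$ and $\mathcal{D}_m$ is canonical, $\mathcal{F}_i$ is also canonical. Therefore, $\mathcal{G}_i = \mathcal{G}_{\mathcal{C}_{ij}}$ since $\mathcal{C}_{ij} \subset \mathcal{F}_i$. Since $\mathcal{C}_{ij} \subset \mathcal{E}_j^*$, we have $\mathcal{G}_{\mathcal{C}_{ij}} \subseteq \mathcal{G}_j^*$ by Lemma 5. Therefore, $\mathcal{G}_i \subseteq \mathcal{G}_j^*$, and so $|\mathcal{G}_i| - |\mathcal{G}_j^*| \le 0$ for any $i, j$ such that $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$. Thus, by Eq. (36), $|\mathcal{G}_i| = |\mathcal{G}_j^*|$ if $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$. Since $\mathcal{G}_i \subseteq \mathcal{G}_j^*$ if $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$, we conclude that
$$
\mathcal{G}_i = \mathcal{G}_j^* \quad \text{if } p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0. \tag{37}
$$
Now, for an arbitrary $\mathcal{E}_j^*$, suppose there exist $s \neq t$ such that $p(\mathcal{D}_s \cap \mathcal{E}_j^*) > 0$ and $p(\mathcal{D}_t \cap \mathcal{E}_j^*) > 0$. Then, there exist some $\mathcal{F}_i \subset \mathcal{D}_s$ and $\mathcal{F}_k \subset \mathcal{D}_t$ such that $p(\mathcal{F}_i \cap \mathcal{E}_j^*) > 0$ and $p(\mathcal{F}_k \cap \mathcal{E}_j^*) > 0$. By Eq. (37), we have $\mathcal{G}_i = \mathcal{G}_j^* = \mathcal{G}_k$. Also, $\mathcal{G}_{\mathcal{D}_s} = \mathcal{G}_i$ and $\mathcal{G}_{\mathcal{D}_t} = \mathcal{G}_k$ since $\mathcal{D}_s$ and $\mathcal{D}_t$ are canonical. Therefore, we have $\mathcal{G}_{\mathcal{D}_s} = \mathcal{G}_{\mathcal{D}_t}$, which contradicts the assumption that the graphs $\mathcal{G}_{\mathcal{D}_m}$ are distinct for all $m$. Therefore, for any $\mathcal{E}_j^*$, there exists a unique $\mathcal{D}_m$ such that $p(\mathcal{D}_m \cap \mathcal{E}_j^*) > 0$, which implies $p(\mathcal{E}_j^* \setminus \mathcal{D}_m) = 0$ since $\{\mathcal{D}_m\}_{m \in [H]}$ is a decomposition of $\mathcal{X}$. Let $I_m = \{j \in [K] \mid p(\mathcal{D}_m \cap \mathcal{E}_j^*) > 0\}$. Here, we have
$$
p\Big(\bigcup_{z \in I_m} \mathcal{E}_z^* \setminus \mathcal{D}_m\Big) = \sum_{z \in I_m} p\big(\mathcal{E}_z^* \setminus \mathcal{D}_m\big) = 0. \tag{38}
$$
Also, by the definition of $I_m$ and because $\{\mathcal{E}_z^*\}_{z \in [K]}$ is a decomposition of $\mathcal{X}$, we have
$$
p\Big(\mathcal{D}_m \setminus \bigcup_{z \in I_m} \mathcal{E}_z^*\Big) = 0. \tag{39}
$$
Therefore, by Eqs. (38) and (39), we have $\mathcal{D}_m = \bigcup_{z \in I_m} \mathcal{E}_z^*$ almost surely for all $m \in [H]$.

[Figure 8: Full and Fork graphs over the root node and the other nodes; legend: node color, action, noisy node (?).]
Figure 8. Illustration of the Chemical (full-fork) environment with 4 nodes. (Left) The color of the root node determines the activation of the local causal graph fork. (Right) The noisy nodes are redundant for predicting the colors of other nodes under the local causal graph.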

C. Appendix for Experiments


C.1. Environment Details

Table 4. Environment configurations.

| Parameters | Chemical (full-fork) | Chemical (full-chain) | Magnetic |
|---|---|---|---|
| Training step | 1.5 × 10^5 | 1.5 × 10^5 | 2 × 10^5 |
| Optimizer | Adam | Adam | Adam |
| Learning rate | 1e-4 | 1e-4 | 1e-4 |
| Batch size | 256 | 256 | 256 |
| Initial step | 1000 | 1000 | 2000 |
| Max episode length | 25 | 25 | 25 |
| Action type | Discrete | Discrete | Continuous |

Table 5. CEM parameters.

| CEM parameters | Chemical (full-fork) | Chemical (full-chain) | Magnetic |
|---|---|---|---|
| Planning length | 3 | 3 | 1 |
| Number of candidates | 64 | 64 | 64 |
| Number of top candidates | 32 | 32 | 32 |
| Number of iterations | 5 | 5 | 5 |
| Exploration noise | N/A | N/A | 1e-4 |
| Exploration probability | 0.05 | 0.05 | N/A |

C.1.1. CHEMICAL
Here, we describe two settings, namely full-fork and full-chain, modified from Ke et al. (2021). In both settings, there are 10 state variables representing the colors of the corresponding nodes, with each color represented as a one-hot encoding. The action variable is a 50-dimensional categorical variable that changes the color of a specific node to a new color (e.g., changing the color of the third node to blue). According to the underlying causal graph and pre-defined conditional probability distributions, implemented with randomly initialized neural networks, an action changes the colors of the intervened object's descendants, as depicted in Fig. 8. As shown in Fig. 3(a), the (global) causal graph is full in both settings, and the LCG is fork and chain, respectively. For example, in full-fork, the LCG fork is activated according to the particular color of the root node, as shown in Fig. 8.

Figure 9. (a) Causal graph of the Magnetic environment. Red boxes indicate redundant edges under the non-magnetic context. (b) LCG under the magnetic context, which is the same as the global CG. (c) LCG under the non-magnetic context.
In both settings, the task is to match the color of each node to the given target. The reward function is defined as:
$$
r = \frac{1}{|O|} \sum_{i \in O} \mathbb{1}\left[s_i = g_i\right], \tag{40}
$$
where $O$ is the set of indices of observable nodes, $s_i$ is the current color of the $i$-th node, and $g_i$ is the target color of the $i$-th node in this episode. An episode is deemed successful if the colors of all observable nodes match the target. During training, all 10 nodes are observable, i.e., $O = \{0, \dots, 9\}$. In downstream tasks, the root color is set to induce the LCG, and the agent receives noisy observations for a subset of nodes, aiming to match the colors of the remaining observable nodes. As shown in Fig. 8, noisy nodes are spurious for predicting the colors of other nodes under the LCG. Thus, an agent capable of reasoning about fine-grained causal relationships would generalize well in downstream tasks. Note that the transition dynamics of the environment are the same in training and downstream tasks. To create noisy observations, we use noise sampled from $\mathcal{N}(0, \sigma^2)$, similar to Wang et al. (2022), and multiply the one-hot encoding representing the color by this noise during the test. In our experiments, we use $\sigma = 100$.
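A minimal sketch of the Chemical reward in Eq. (40), its success criterion, and the noisy test-time observation. Assumptions, not taken from the released code: colors are integer indices, the noise is applied element-wise to the one-hot vector, and the 5-color example is hypothetical; all function names are illustrative.

```python
import numpy as np

def chemical_reward(colors, goal, observable):
    """Eq. (40): fraction of observable nodes whose color matches the goal."""
    obs = np.asarray(list(observable))
    matches = np.asarray(colors)[obs] == np.asarray(goal)[obs]
    return float(matches.mean()), bool(matches.all())  # (reward, success)

def noisy_one_hot(color, n_colors, sigma, rng):
    """One-hot color multiplied element-wise by N(0, σ²) noise (assumed form)."""
    one_hot = np.eye(n_colors)[color]
    return one_hot * rng.normal(0.0, sigma, size=n_colors)

colors = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
goal = [0, 1, 2, 3, 4, 0, 1, 2, 3, 0]               # node 9 does not match
print(chemical_reward(colors, goal, range(10)))      # (0.9, False)
print(noisy_one_hot(2, 5, 100.0, np.random.default_rng(0)))
```

Under this form, the noisy node still reveals which entry is active but with a random scale and sign, so a model that relies on it as an input receives inputs far outside the training distribution.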
As the root color determines the LCG in both settings, the root node is always observable to the agent during the test. The root colors of the initial state and the goal state are the same, inducing the LCG. Since the root color can be changed by an action during the test, this may pose a challenge in evaluating the agent's reasoning about local causal relationships. This can be addressed by modifying the initial distribution of CEM to exclude actions on the root node and only act on the other nodes during the test. Nevertheless, we observe that restricting actions on the root during the test has little impact on the behavior of any model; we find that this is because the agent rarely changes the root color, as it already matches the goal color in the initial state.

C.1.2. MAGNETIC
In this environment, there are two objects on a table, a moving ball and a box, each colored either red or black, as shown in Fig. 3(b). The red color indicates that the object is magnetic. In other words, when both are colored red, a magnetic force is applied and the ball moves toward the box. If either object is colored black, the ball does not move, since the box has no influence on the ball.
The state consists of the color and the x, y position of each object, and the x, y, z position of the end-effector of the robot arm, where the color is given as a 2-dimensional one-hot encoding. The action is a 3-dimensional vector that moves the robot arm. The causal graph of the Magnetic environment is shown in Fig. 9(a). LCGs under the magnetic and non-magnetic contexts are shown in Figs. 9(b) and 9(c), respectively. The table in our setup has a width of 0.9 and a length of 0.6, with the y-axis defined by the width and the x-axis defined by the length. For each episode, the initial positions of the moving ball and the box are randomly sampled within the range of the table.
The task is to move the robot arm to reach the moving ball. Thus, accurately predicting the trajectory of the ball is crucial. The reward function is defined as:
$$
r = 1 - \tanh\big(5 \cdot \|eef - g\|_1\big), \tag{41}
$$
where $eef \in \mathbb{R}^3$ is the current position of the end-effector, $g = (b_x, b_y, 0.8) \in \mathbb{R}^3$, and $(b_x, b_y)$ is the current position of the moving ball. An episode is deemed successful if the distance is smaller than 0.05. During the test, one of the objects is colored black, and the box is located at a position unseen during training. Specifically, the box position is sampled from $\mathcal{N}(0, \sigma^2)$ during the test. Note that the box can be located outside of the table, which never happens during training. In our experiments, we use $\sigma = 100$.
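Eq. (41) translates directly into a few lines of NumPy. This is a sketch; the function name is hypothetical, and using the same L1 distance for the success test is an assumption, since the text only says "the distance".

```python
import numpy as np

def magnetic_reward(eef, ball_xy, target_z=0.8, success_dist=0.05):
    """Eq. (41): r = 1 - tanh(5 * ||eef - g||_1) with g = (b_x, b_y, 0.8)."""
    g = np.array([ball_xy[0], ball_xy[1], target_z])
    dist = np.abs(np.asarray(eef, dtype=float) - g).sum()  # L1 distance
    # Assumption: success is judged on the same L1 distance as the reward.
    return float(1.0 - np.tanh(5.0 * dist)), bool(dist < success_dist)

print(magnetic_reward([0.1, 0.2, 0.8], (0.1, 0.2)))  # (1.0, True)
print(magnetic_reward([0.5, 0.2, 0.8], (0.1, 0.2)))  # reward 1 - tanh(2), not successful
```

The factor 5 inside tanh makes the reward decay quickly: at an L1 distance of 0.4 the reward has already dropped below 0.04.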

C.2. Experimental Details


To assess the performance of the dynamics models of the baselines and our method, we use model predictive control (MPC) (Camacho & Alba, 2013), which selects actions based on the predictions of the learned dynamics model, following prior works (Ding et al., 2022; Wang et al., 2022). Specifically, we use the cross-entropy method (CEM) (Rubinstein & Kroese, 2004), which iteratively generates and refines action sequences by sampling from a probability distribution that is updated based on the performance of the sampled sequences under a known reward function. We use a random policy for the initial data collection. Environment configurations and CEM parameters are shown in Tables 4 and 5, respectively. Most of the experiments were run on a single NVIDIA RTX 3090. For Fig. 7, we use the structural Hamming distance (SHD) for evaluation, a metric that quantifies the dissimilarity between two graphs as the number of edge additions or deletions needed to make the graphs identical (Acid & de Campos, 2003; Ramsey et al., 2006).
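The CEM planning loop can be sketched as follows for continuous actions. The defaults mirror the Magnetic column of Table 5 (planning length 1, 64 candidates, 32 top candidates, 5 iterations); the Gaussian initialization `init_std`, the toy additive `model`, and the quadratic `reward_fn` are hypothetical stand-ins. The exploration noise and exploration probability from Table 5, as well as the categorical variant needed for the discrete Chemical actions, are not modeled here.

```python
import numpy as np

def cem_plan(state, model, reward_fn, action_dim, horizon=1, n_candidates=64,
             n_elites=32, n_iters=5, init_std=1.0, seed=0):
    """Cross-entropy method over continuous action sequences (one MPC step).

    model(states, actions) -> next states and reward_fn(states) -> rewards are
    batched over candidates. Returns the first action of the refined mean plan.
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, action_dim))
    std = np.full((horizon, action_dim), init_std)
    for _ in range(n_iters):
        # Sample candidate action sequences from the current Gaussian.
        plans = mean + std * rng.standard_normal((n_candidates, horizon, action_dim))
        states = np.repeat(np.asarray(state, dtype=float)[None], n_candidates, axis=0)
        returns = np.zeros(n_candidates)
        for t in range(horizon):  # imagined rollout with the learned model
            states = model(states, plans[:, t])
            returns += reward_fn(states)
        # Refit the sampling distribution to the top-scoring sequences.
        elites = plans[np.argsort(returns)[-n_elites:]]
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mean[0]  # execute only the first action, then replan

# Usage with a toy additive model standing in for the learned dynamics.
target = np.array([0.3, -0.2])
model = lambda s, a: s + a
reward_fn = lambda s: -np.sum((s - target) ** 2, axis=1)
action = cem_plan(np.zeros(2), model, reward_fn, action_dim=2)
```

Because only the first action is executed before replanning, errors in long imagined rollouts matter less, but a dynamics model that reacts to spurious inputs (such as the noisy nodes or the out-of-distribution box position) can still steer the elites in the wrong direction.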

C.3. Implementation of Baselines


For all methods, the dynamics model outputs the parameters of a categorical distribution for discrete variables, and the mean and standard deviation of a normal distribution for continuous variables. All methods have a similar number of model parameters for a fair comparison. Detailed parameters of each model are shown in Table 6.
MLP and Modular. MLP models the transition dynamics as $p(s' \mid s, a)$. Modular has a separate network for each state variable, i.e., $\prod_j p(s'_j \mid s, a)$, where each network is implemented as an MLP.
GNN, NPS, and CDL. We employ publicly available source code.⁴ For NPS (Goyal et al., 2021a), we search the number of rules $N \in \{4, 15, 20\}$. CDL (Wang et al., 2022) infers the causal structure by estimating conditional mutual information (CMI) and models the dynamics as $\prod_j p(s'_j \mid Pa(j))$. For CDL, we search the initial CMI threshold $\epsilon \in \{0.001, 0.002, 0.005, 0.01, 0.02\}$ and the exponential moving average (EMA) coefficient $\tau \in \{0.9, 0.95, 0.99, 0.999\}$. As CDL is a two-stage method, we only report its final performance.
GRADER. We implement GRADER (Ding et al., 2022) based on the code provided by the authors.⁵ GRADER relies on a conditional independence test (CIT) to discover the causal structure. In Chemical, we ran the CIT every 10 episodes, following their default setting. We only report its performance in Chemical due to the poor scalability of the CIT in the Magnetic environment, where each test took about 30 minutes.
Oracle and NCD. For a fair comparison, we employ the same architecture for the dynamics models of Oracle, NCD, and our method, as their main difference lies in the inference of local causal graphs (LCGs). As illustrated in Fig. 10, the key difference is that NCD (Hwang et al., 2023) directly infers the LCG from each individual sample (referred to as sample-specific inference), while our method decomposes the data domain and infers the LCG for each subgroup through quantization. We provide implementation details of our method in the next subsection.

C.4. Implementation of FCDL


For our method, we use MLPs to implement $g_{enc}$, $g_{dec}$, and $\hat p$, with configurations provided in Table 6. The quantization encoder $g_{enc}$ of our method and the auxiliary network of NCD share the initial feature extraction layer with the dynamics model $\hat p$, as we found that this yields better performance than fully decoupling them.
⁴ https://github.com/wangzizhao/CausalDynamicsLearning
⁵ https://github.com/GilgameshD/GRADER

[Figure 10 panels: (a) Local Causal Graph Inference; (b) Masked Prediction with LCG. Top row, sample-specific inference (NCD): auxiliary network, sample-specific LCGs, dynamics model. Bottom row, event-specific inference (ours): encoder, quantization codebook, decoder, event-specific LCGs, dynamics model. Legend: state or action variables; masked variables.]
Figure 10. Comparison of the sample-specific inference of NCD (top) and the quantization-based inference of our method (bottom).

C.4.1. DYNAMICS MODEL


Recall our dynamics modeling in Eq. (8): $\hat{p}(s'_j \mid \mathrm{Pa}_{G_z}(j); \phi_z^{(j)})$ if $(s, a) \in E_z$, which corresponds to $p(s'_j \mid s, a) = p(s'_j \mid \mathrm{Pa}(j; E_z), z)$ in Eq. (3). Here, each $\phi_z^{(j)}$ is a neural network that takes $\mathrm{Pa}_{G_z}(j)$ as input and predicts $s'_j$ under $E_z$. In general, a separate network for each subgroup would allow the model to adapt to environments with complex dynamics by learning the transition function of each subgroup separately. However, this requires a total of $K \times N$ separate networks, which could incur a computational burden. Instead, we employ an efficient parameter-sharing mechanism to simplify the implementation: the dynamics model consists of a separate network for each state variable, i.e., $\phi = \{\phi^{(j)}\}$, and each $\phi^{(j)}$ takes $(\mathrm{Pa}_{G_z}(j), z)$ as input, instead of using separate networks $\phi_z^{(j)}$ for each $E_z$. This is analogous to $p(s'_j \mid \mathrm{Pa}(j; E_z), z)$ and requires a total of $N$ separate networks, one for each state variable. There are different design choices for $z$ in $(\mathrm{Pa}_{G_z}(j), z)$. We consider two: (i) concatenating $\mathrm{Pa}_{G_z}(j)$ and $e_z$ (i.e., the code), and (ii) concatenating $\mathrm{Pa}_{G_z}(j)$ and a one-hot encoding of $z$ (of dimension $K$). We opt for the simpler latter choice. This allows us to model (possibly) different transition functions for each subgroup with a single dynamics model per state variable. Note that if subgroups with the same LCG share the same transition function, this labeling by $z$ could be omitted altogether.
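To illustrate option (ii), the sketch below conditions a single per-variable predictor on a one-hot encoding of the code index z. It is a minimal sketch only: it assumes a linear model as a stand-in for the MLP actually used, and the function names and weights are hypothetical. The weights on the one-hot part act as a subgroup-specific offset, so one network per state variable (N networks instead of K × N) can still express a different transition function for each subgroup.

```python
from typing import List, Sequence


def one_hot(z: int, K: int) -> List[float]:
    """One-hot encoding of the code index z (dimension K), i.e. option (ii)."""
    if not 0 <= z < K:
        raise ValueError(f"code index {z} is outside [0, {K})")
    return [1.0 if k == z else 0.0 for k in range(K)]


def shared_linear_predictor(parent_feats: Sequence[float], z: int, K: int,
                            weights: Sequence[float],
                            code_weights: Sequence[float]) -> float:
    """Hypothetical linear stand-in for phi^(j), shared by all K subgroups.

    The input is the concatenation (Pa_{G_z}(j), one_hot(z)). Only the
    code_weights entry of the active code contributes, so the same
    parameters yield a different function for each subgroup.
    """
    if len(weights) != len(parent_feats) or len(code_weights) != K:
        raise ValueError("weight dimensions do not match the input")
    x = list(parent_feats) + one_hot(z, K)
    w = list(weights) + list(code_weights)
    return sum(xi * wi for xi, wi in zip(x, w))


# Same network, two subgroups: only the code-specific offset differs.
print(shared_linear_predictor([1.0, 1.0], 0, 2, [1.0, 2.0], [0.0, 10.0]))  # 3.0
print(shared_linear_predictor([1.0, 1.0], 1, 2, [1.0, 2.0], [0.0, 10.0]))  # 13.0
```

Option (i) would replace `one_hot(z, K)` with the learned code vector e_z; the one-hot variant avoids coupling the dynamics network to codebook updates.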
To feed $\mathrm{Pa}_{G_z}(j)$ as input to $\hat{p}(s'_j \mid \mathrm{Pa}_{G_z}(j); \phi_z^{(j)})$, we simply mask out the features of unused variables; other design choices, such as a Gated Recurrent Unit (Chung et al., 2014; Ding et al., 2022), are also possible. As architectural design is not the primary focus of this work, we leave the exploration of different architectures to future work.
Note that all baselines except MLP (e.g., GNN and causal dynamics models) use separate networks for each state variable, and we ensured that all methods have a similar number of model parameters for a fair comparison.
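Masking the features of unused variables can be sketched as follows. This is a minimal illustration, assuming that each input variable is represented by a fixed-length feature vector and that the LCG is given as a binary adjacency matrix with `adjacency[i][j] = 1` when variable i is a parent of s′_j; the function name and data are hypothetical. Zeroing the non-parent features, rather than dropping them, keeps the input length of ϕ^(j) identical across all LCGs.

```python
from typing import List, Sequence


def mask_non_parents(features: Sequence[Sequence[float]],
                     adjacency: Sequence[Sequence[int]],
                     j: int) -> List[float]:
    """Build the masked input of the network that predicts s'_j.

    features[i]     : feature vector of input variable i (state or action)
    adjacency[i][j] : 1 if the edge i -> s'_j is present in the LCG, else 0
    Features of non-parents are replaced by zeros, so the input length is
    the same for every LCG.
    """
    if len(features) != len(adjacency):
        raise ValueError("one adjacency row is needed per input variable")
    masked: List[float] = []
    for i, feat in enumerate(features):
        keep = adjacency[i][j] == 1
        masked.extend(float(v) if keep else 0.0 for v in feat)
    return masked


feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
lcg = [[1, 0], [0, 1], [1, 1]]  # 3 input variables, 2 next-state variables
print(mask_non_parents(feats, lcg, 0))  # [1.0, 2.0, 0.0, 0.0, 5.0, 6.0]
```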

C.4.2. BACKPROPAGATION
We now describe how each component of our method is updated by the training objective in Eq. (6). First, Lpred updates the encoder genc(s, a), the decoder gdec(e), and the dynamics model p̂. Recall that A ∼ gdec(e); backpropagation from A in Lpred updates the quantization decoder gdec through e. During the backward pass in Eq. (5), gradients are copied from e (the input of gdec) to h (the output of genc), following VQ-VAE (Van Den Oord et al., 2017). In this way, Lpred also updates the quantization encoder genc and hence h. Second, Lquant updates genc and the codebook C. Note that Lpred also affects the learning of the codebook C, since h is updated with Lpred. The rationale behind this VQ-VAE trick is that the gradient ∇e Lpred guides the encoder genc to change its output h = genc(s, a) so as to lower the prediction loss Lpred, which can alter the quantization (i.e., the cluster assignment) in the next forward pass. A larger prediction loss (indicating that the sample (s, a) is assigned to the wrong cluster) induces a bigger change in h and therefore makes a re-assignment of the cluster more likely.

Figure 11. (a,b) Codebook histogram on (a) ID states during training and (b) OOD states during testing in Chemical (full-fork). (c) True causal graph of the fork structure. (d) Learned LCG corresponding to the most used code in (b).

Figure 12. Analysis of LCGs learned by our method with a quantization degree of 4 in the Chemical (full-fork) environment. (a-c) Codebook histogram on (a) ID states, (b) ID states on the local structure fork, and (c) OOD states on the local structure. (d-g) Learned LCGs. The same histogram descriptions apply to Figs. 13 to 15, 17, and 18.
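The gradient routing described in C.4.2 can be sketched with explicit gradients. This is a minimal sketch, assuming the standard VQ-VAE quantization loss Lquant = ||sg[h] − e||² + β||h − sg[e]||² (sg = stop-gradient) and nearest-neighbor assignment by squared Euclidean distance; the function names are hypothetical. It shows that the prediction gradient arriving at e is copied unchanged to h (straight-through), while the selected code e receives only the codebook term.

```python
from typing import List, Sequence, Tuple


def sq_dist(u: Sequence[float], v: Sequence[float]) -> float:
    return sum((a - b) ** 2 for a, b in zip(u, v))


def quantize(h: Sequence[float],
             codebook: Sequence[Sequence[float]]) -> Tuple[int, List[float]]:
    """Forward pass: assign h = g_enc(s, a) to its nearest code e_z."""
    z = min(range(len(codebook)), key=lambda k: sq_dist(h, codebook[k]))
    return z, list(codebook[z])


def vq_backward(h: Sequence[float], e: Sequence[float],
                grad_e_pred: Sequence[float],
                beta: float = 0.25) -> Tuple[List[float], List[float]]:
    """Gradients reaching h and the selected code e.

    grad_e_pred is dL_pred/de arriving from the decoder g_dec. The
    straight-through estimator copies it unchanged to h, so L_pred moves h
    (and thus future assignments) but never the code e directly.
    """
    grad_h = [g + 2.0 * beta * (hi - ei)  # copied L_pred grad + commitment term
              for g, hi, ei in zip(grad_e_pred, h, e)]
    grad_code = [2.0 * (ei - hi) for hi, ei in zip(h, e)]  # codebook term only
    return grad_h, grad_code


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 0.0]]
print(quantize([0.9, 0.8], codebook))                    # (1, [1.0, 1.0])
print(vq_backward([1.0, 0.0], [0.0, 0.0], [0.5, -0.5]))  # ([1.0, -0.5], [-2.0, 0.0])
```

In an autodiff framework, the same routing is usually obtained by writing the quantized output as h + sg(e − h), which equals e in the forward pass but passes the gradient straight to h.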

C.4.3. HYPERPARAMETERS
For all experiments, we fix the codebook size K = 16, the regularization coefficient λ = 0.001, and the commitment coefficient β = 0.25, as we found that performance did not vary much for any K > 2, $\lambda \in \{10^{-4}, 10^{-3}, 10^{-2}\}$, and $\beta \in \{0.1, 0.25\}$.
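For reference, the fixed hyperparameters above and the settings listed for our method in Table 6 can be collected into a single configuration object. This is a hypothetical container written for illustration: the class name, field names, and environment keys are assumptions, not part of any released code.

```python
from dataclasses import dataclass, field, replace
from typing import List


@dataclass(frozen=True)
class QuantizedDynamicsConfig:
    """Hypothetical container for the values in C.4.3 and Table 6 (Ours)."""
    codebook_size: int = 16           # K
    code_dim: int = 16
    reg_coef: float = 1e-3            # lambda
    commitment_coef: float = 0.25     # beta
    vq_encoder_hidden: List[int] = field(default_factory=lambda: [128, 64])
    vq_decoder_hidden: List[int] = field(default_factory=lambda: [32])
    dynamics_hidden_dim: int = 128
    dynamics_hidden_layers: int = 4

    def __post_init__(self) -> None:
        # At least two codes are needed to represent more than one LCG.
        if self.codebook_size < 2:
            raise ValueError("codebook_size must be at least 2")


def config_for(env: str) -> QuantizedDynamicsConfig:
    """Per-environment override of the dynamics depth (keys are hypothetical)."""
    hidden_layers = {"chemical-full-fork": 4, "chemical-full-chain": 4, "magnetic": 5}
    return replace(QuantizedDynamicsConfig(),
                   dynamics_hidden_layers=hidden_layers[env])


print(config_for("magnetic").dynamics_hidden_layers)  # 5
```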

C.5. Additional Experimental Results


C.5.1. DETAILED ANALYSIS OF LEARNED LCGS
LCGs learned by our method with a quantization degree of 4 in Chemical are shown in Figs. 12 and 13. Among the 4 codes, one (Fig. 12(b)) or two (Fig. 13(b)) represent the local causal structure fork. Our method successfully infers the proper code for most of the OOD samples (Figs. 12(c) and 13(c)). Two sample runs of our method with a quantization degree of 4 in Magnetic are shown in Figs. 14 and 15. Our method successfully learns LCGs corresponding to the non-magnetic context (Figs. 14(d), 14(g), 15(d) and 15(f)) and the magnetic context (Figs. 14(e), 14(f), 15(e) and 15(g)).
We also observe that our method discovers more fine-grained relationships. Recall that the context is non-magnetic when one of the objects is black: when the ball is black, the box has no influence on the ball regardless of the box's color, and vice versa. As shown in Fig. 16, our method discovers both the context where the ball is black (Fig. 16(b)) and the context where the box is black (Fig. 16(a)).
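The codebook histograms in Figs. 11 to 15 summarize how often each code is selected on a set of states. A minimal sketch of that summary is given below, assuming the code index of each sample has already been obtained from the quantizer; the function names and the sample indices are hypothetical.

```python
from collections import Counter
from typing import Dict, Sequence


def code_histogram(assignments: Sequence[int], K: int) -> Dict[int, float]:
    """Fraction of samples assigned to each of the K codes."""
    if not assignments:
        raise ValueError("no code assignments given")
    counts = Counter(assignments)
    return {k: counts[k] / len(assignments) for k in range(K)}


def most_used_code(assignments: Sequence[int]) -> int:
    """Index of the most frequently chosen code (ties: first one encountered)."""
    return Counter(assignments).most_common(1)[0][0]


ood_codes = [2, 0, 2, 3]  # hypothetical code indices of four OOD samples
print(code_histogram(ood_codes, 4))  # {0: 0.25, 1: 0.0, 2: 0.5, 3: 0.25}
print(most_used_code(ood_codes))     # 2
```

The LCG shown for "the most used code" (e.g., Fig. 11(d)) would then be the graph decoded from the code returned by `most_used_code`.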

Figure 13. Another sample run of our method with a quantization degree of 4 in Chemical (full-fork).

Figure 14. Analysis of LCGs learned by our method with a quantization degree of 4 in Magnetic.

We observe that training the latent codebook with vector quantization is often unstable when K = 2. Figs. 17 and 18 show a success case and a failure case, respectively, of our method with a quantization degree of 2. In the failure case, the embeddings frequently fluctuate between the two codes, so that both codes correspond to the global causal graph and the method fails to capture the LCG (Fig. 18).

C.5.2. LEARNING CURVES ON ALL DOWNSTREAM TASKS


Fig. 19 shows the learning curves during training in all environments. Figs. 4, 20, and 21 show the learning curves on all downstream tasks.⁶

⁶ As CDL is a two-stage method that requires searching for the best threshold after the first-stage training, we report only its final performance.

Figure 15. Another sample run of our method with a quantization degree of 4 in Magnetic.

Figure 16. More fine-grained LCGs learned by our method with a quantization degree of 16 in Magnetic.

Figure 17. Analysis of LCGs learned by our method with a quantization degree of 2 in Chemical (full-fork).

D. Additional Discussions
D.1. Difference from Sample-based Inference
Sample-based inference methods, e.g., NCD (Hwang et al., 2023) for LCG or ACD (Löwe et al., 2022) for CG, can be seen as learning causal graphs with gated edges. They learn a function that maps each sample to an adjacency matrix whose entries are binary variables indicating whether the corresponding edge is on or off under the current state. The critical difference from our method is that the LCGs learned by sample-based inference methods are unbounded and black-box. Specifically, it is hard to understand which local structures and contexts have been identified, since they can only be examined by observing the inference outcomes over all samples (i.e., black-box). Moreover, there is no (practical or theoretical) guarantee that such a function outputs the same graph for states within the same context, since its output is unbounded. In contrast, our method learns a finite set of LCGs whose contexts are explicitly identified by latent clustering. In other words, the outcome is bounded (it infers one of the K graphs), and the contexts are more interpretable.

Figure 18. Failure case of our method with a quantization degree of 2 in Chemical (full-fork).
Such guarantees and interpretability are crucial for model robustness and for a principled understanding of fine-grained structures, and we demonstrate the improved robustness of our method compared to prior sample-based inference methods. On the other hand, sample-based inference and local edge-switch methods have the strengths of simple design and efficiency, and signals from such local edge switches are known to enhance exploration in RL (Seitzer et al., 2021; Wang et al., 2023). For practitioners, the choice would depend on their purpose, e.g., whether their primary interest is in robustness and a principled understanding of fine-grained structures.
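The boundedness argument in D.1 can be illustrated with a toy sketch. For illustration only, it assumes that the decoder output for each code is fixed, so each code index maps to one stored adjacency matrix; the names and values are hypothetical. Whatever the input embedding, the inferred graph is one of the K stored graphs, so the number of distinct LCGs never exceeds K, whereas a per-sample function could in principle return a different graph for every state.

```python
from typing import List, Sequence

Graph = List[List[int]]


def nearest_code(h: Sequence[float], codebook: Sequence[Sequence[float]]) -> int:
    """Index of the code closest to h in squared Euclidean distance."""
    return min(range(len(codebook)),
               key=lambda k: sum((a - b) ** 2 for a, b in zip(h, codebook[k])))


def infer_lcg(h: Sequence[float], codebook: Sequence[Sequence[float]],
              code_graphs: Sequence[Graph]) -> Graph:
    """Return the LCG attached to the code nearest to the embedding h."""
    return code_graphs[nearest_code(h, codebook)]


codebook = [[0.0], [1.0]]                           # K = 2 codes
code_graphs = [[[1, 0], [0, 1]], [[1, 1], [0, 1]]]  # one LCG per code
embeddings = [[x / 10] for x in range(-20, 21)]     # 41 different inputs
distinct = {tuple(tuple(row) for row in infer_lcg(h, codebook, code_graphs))
            for h in embeddings}
print(len(distinct))  # 2, never more than K
```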

D.2. Limitations and Future Work


Insufficient or biased data may lead to inaccurate learning of causal relationships, including both CGs and LCGs. Our work explored the potential of utilizing LCGs to deal with (locally) spurious correlations arising from insufficient or biased data in the context of MBRL. While we assumed causal sufficiency, unobserved variables may also influence the causal relationships. These assumptions are commonly adopted in the field, yet we consider relaxing them a promising future direction. Another promising direction is to explore an inherent structure in the quantization that can efficiently handle a large number of contexts.


Table 6. Parameters of each model.

Model | Parameter | Chemical (full-fork) | Chemical (full-chain) | Magnetic
MLP | Hidden dim | 1024 | 1024 | 512
MLP | Hidden layers | 3 | 3 | 4
Modular | Hidden dim | 128 | 128 | 128
Modular | Hidden layers | 4 | 4 | 4
GNN | Node attribute dim | 256 | 256 | 256
GNN | Node network hidden dim | 512 | 512 | 512
GNN | Node network hidden layers | 3 | 3 | 3
GNN | Edge attribute dim | 256 | 256 | 256
GNN | Edge network hidden dim | 512 | 512 | 512
GNN | Edge network hidden layers | 3 | 3 | 3
NPS | Number of rules | 20 | 20 | 15
NPS | Cond selector dim | 128 | 128 | 128
NPS | Rule embedding dim | 128 | 128 | 128
NPS | Rule selector dim | 128 | 128 | 128
NPS | Feature encoder hidden dim | 128 | 128 | 128
NPS | Feature encoder hidden layers | 2 | 2 | 2
NPS | Rule network hidden dim | 128 | 128 | 128
NPS | Rule network hidden layers | 3 | 3 | 3
CDL | Hidden dim | 128 | 128 | 128
CDL | Hidden layers | 4 | 4 | 4
CDL | CMI threshold | 0.001 | 0.001 | 0.001
CDL | CMI optimization frequency | 10 | 10 | 10
CDL | CMI evaluation frequency | 10 | 10 | 10
CDL | CMI evaluation step size | 1 | 1 | 1
CDL | CMI evaluation batch size | 256 | 256 | 256
CDL | EMA discount | 0.9 | 0.9 | 0.99
GRADER | Feature embedding dim | 128 | 128 | N/A
GRADER | GRU hidden dim | 128 | 128 | N/A
GRADER | Causal discovery frequency | 10 | 10 | N/A
Oracle | Hidden dim | 128 | 128 | 128
Oracle | Hidden layers | 4 | 4 | 5
NCD | Hidden dim | 128 | 128 | 128
NCD | Hidden layers | 4 | 4 | 5
NCD | Auxiliary network hidden dim | 128 | 128 | 128
NCD | Auxiliary network hidden layers | 2 | 2 | 2
Ours | Hidden dim | 128 | 128 | 128
Ours | Hidden layers | 4 | 4 | 5
Ours | VQ encoder | [128, 64] | [128, 64] | [128, 64]
Ours | VQ decoder | [32] | [32] | [32]
Ours | Codebook size | 16 | 16 | 16
Ours | Code dimension | 16 | 16 | 16
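Appendix C.4.1 states that all methods were given a similar number of model parameters. For fully connected networks, a rough check is to count weights and biases layer by layer. The sketch below does so, assuming plain MLPs with a bias in every layer; the input and output sizes in the usage line are hypothetical, since they depend on each environment's state and action dimensions.

```python
from typing import Sequence


def mlp_param_count(in_dim: int, hidden: Sequence[int], out_dim: int) -> int:
    """Number of weights plus biases of a fully connected MLP."""
    sizes = [in_dim, *hidden, out_dim]
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))


# Hidden dim 128 with 4 hidden layers (Table 6); in/out sizes are hypothetical.
print(mlp_param_count(10, [128] * 4, 2))  # 51202
```

For per-variable architectures (e.g., N separate networks), the total is the sum of this count over the N networks.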


[Figure 19: legend MLP, GNN, GRADER, NCD, Modular, NPS, Oracle, Ours; panels Chemical (full-fork), Chemical (full-chain), Magnetic; axes: episode reward vs. environment step (×10^5).]

Figure 19. Learning curves during training as measured by the episode reward.

[Figure 20: legend MLP, GNN, GRADER, NCD, Modular, NPS, Oracle, Ours; panels Chemical (full-fork) with n = 2, 4, 6; axes: episode reward (top row) and success ratio (bottom row) vs. environment step (×10^5).]

Figure 20. Learning curves on downstream tasks in Chemical (full-fork) as measured by the episode reward (top) and success rate (bottom).

[Figure 21: legend MLP, GNN, GRADER, NCD, Modular, NPS, Oracle, Ours; panels Chemical (full-chain) with n = 2, 4, 6; axes: episode reward (top row) and success ratio (bottom row) vs. environment step (×10^5).]

Figure 21. Learning curves on downstream tasks in Chemical (full-chain) as measured by the episode reward (top) and success rate (bottom).
