2 Causal inference overview and course goals
Lecture notes prepared by Kaushal Paneiri
2.1 Course thesis
The goal of this course is to learn techniques of causal inference in a way that builds on students’ existing intuition and experience with generative machine learning. Moreover, we will do so using frameworks from generative machine learning, including tools for building deep neural networks. Further, when reasoning about causal inference problems, we will bias the case studies toward those seen in professional environments where data scientists and machine learning engineers build and manage in-product machine learning models.
2.1.1 Causal modeling as generative ML
More specifically, this course focuses on machine learning in the following two ways.
- We will place causal inference firmly on a foundation of model-based generative machine learning. Our goal is to build machine learning systems that think in causal terms, such as confounding, interventions, and counterfactuals.
- If you peruse the causal inference literature, you will see examples similar to the A/B test example from epidemiology, econometrics, and clinical trials. This course focuses on the kinds of cases data scientists experience in professional settings, particularly in the tech industry. The focus of the tech industry is shifting towards problems where A/B tests become more complicated or infeasible. We will cover some advanced techniques, such as how to deal with confounding and how to build online and offline learning and policy evaluation for Markov decision processes that automate testing. We will also cover a little of the relevant literature from game theory (auction models) and reinforcement learning (policy evaluation and improvement) at the end.
2.1.2 What is left out
Causal inference spans many other concepts, and we won’t be able to cover all of them. Though the concepts below are essential, they are out of scope for this course.
- Causal discovery
- Causal inference with regression models and various canonical SCM models
- Doubly-robust estimation
- Interference due to network effects (important in social network tech companies like Facebook or Twitter)
- Heterogeneous treatment effects
- Deep architectures for causal effect inference
- Causal time series models
- Algorithmic information theory approaches to causal inference
2.1.3 Examples of problems in causal inference
To properly contextualize our motivation, we start by understanding how causal inference developed as a field across domains, including economics, biology, social science, computer science, anthropology, epidemiology, and statistics.
2.1.3.1 Estimation of causal effects
The problem of finding causal effects is the primary motivation of researchers in these domains. For example, in the late 80s and 90s, doctors used to prescribe hormone replacement therapy to older women. Experts observed that younger women have a lower risk of heart disease than men do, and that this advantage fades as estrogen levels decline after menopause, so the thinking went that replacing the estrogen should lower the risk. However, a large randomized trial, in which women were randomly assigned either a placebo or estrogen, showed that taking estrogen increases the chance of getting heart disease. Causal inference techniques are essential because the stakes are quite high.
2.1.3.2 Counterfactual reasoning with statistics
Counterfactual reasoning means observing reality, and then imagining how reality would have unfolded differently had some causal factor been different. For example, “had I broken up with my girlfriend sooner, I would be much happier today” or “had I studied harder for my SATs, I would be in a much better school today.” An example of a question from an experimental context would be “This subject took the drug, and their condition improved. What is the difference between this amount of improvement and the improvement they would have seen had they taken a placebo?”
Counterfactual reasoning is fundamental to how we as humans reason. However, statistical methods are generally not equipped to enable this type of logic. Your counterfactual reasoning process works with data both from actual and hypothetical realities, while your statistical procedure only has access to data from actual reality.
The same is true of cutting-edge machine learning. Intuition tells us that if we trained the most powerful deep learning methods to provide us with relationship advice based on our romantic successes and failures, something would be lacking in that advice since those counterfactual outcomes are missing from the training data.
2.1.3.3 The challenge of running experiments
In traditional statistics, randomized experiments are the gold standard for discovering a causal effect. An example of a randomized experiment is an A/B test on a new feature in an app. We randomly assign users to two groups and let one group use the feature while the other is presented with a control comparison. We then observe some key outcome, such as conversions. As we will learn, the randomization enables us to conclude that the difference between the two groups is the causal effect of the feature on conversions, because it isolates that effect from other unknown factors that also affect conversions.
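To make the role of randomization concrete, here is a minimal simulation sketch. Everything in it is an assumption made for illustration: the hidden "engaged" user trait, the conversion probabilities, and the true lift of 0.05 are invented numbers, not data from a real product. It compares the naive difference in conversion rates when the feature is assigned by coin flip with the same difference when engaged users are more likely to end up with the feature.

```python
import random

def simulate(n, randomized, seed=0):
    """Return the naive difference in conversion rate: feature users minus control users."""
    rng = random.Random(seed)
    true_effect = 0.05  # assumed causal lift in conversion probability
    treated, control = [], []
    for _ in range(n):
        # Hidden user trait that also drives conversions (hypothetical confounder).
        engaged = rng.random() < 0.5
        if randomized:
            # A/B test: a coin flip decides who gets the feature.
            gets_feature = rng.random() < 0.5
        else:
            # No randomization: engaged users are more likely to get the feature.
            gets_feature = rng.random() < (0.8 if engaged else 0.2)
        p_convert = 0.10 + (0.20 if engaged else 0.0) + (true_effect if gets_feature else 0.0)
        converted = rng.random() < p_convert
        (treated if gets_feature else control).append(converted)
    return sum(treated) / len(treated) - sum(control) / len(control)

ab_estimate = simulate(200_000, randomized=True)
observational_estimate = simulate(200_000, randomized=False)
print(f"randomized:    {ab_estimate:.3f}")
print(f"observational: {observational_estimate:.3f}")
```

Under these assumptions the randomized estimate lands near the true lift, while the non-randomized comparison is inflated, because the engaged users are over-represented among feature users.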
However, in many instances, setting up this randomization might be complicated. What if users object to not getting a feature that other users are enjoying? What if the experience of the feature and probability of conversion both depend on user-related factors, such that it is unclear how to do proper randomization? What if some users object to being the subjects of an experiment? What if it is unethical to do the experiment?
2.2 Causal modeling as an extension of generative modeling
2.2.1 Generative vs. discriminative Models
Let’s focus on supervised learning for a moment. Given a target variable Y and predictor(s) X, a discriminative model learns as much about \(P(Y| X)\) as it needs to make an optimal prediction.
In contrast, generative models try to fully learn the joint distribution \(P(X, Y)\) underlying the data. We will discuss this more in later lectures. In simple words, these models can generate data that looks like real data.
We focus on generative models because they allow us to build our theories about the data-generating process into the model itself. We will see that we naturally think of this process in causal terms.
2.2.2 Model-based ML and learning to think about the data-generating process
The following is the typical checklist in training a statistical machine learning model.
- Split the data into training and test sets.
- Choose a few models from literally thousands of algorithm choices. Typically this choice is limited to algorithms that you are familiar with, that are in vogue, or that happen to be implemented in the software you have available.
- Manipulate the data until it fits your algorithm inputs and outputs.
- Evaluate the model on test data and compare it to other models.
- (optional) If the data doesn’t fit the algorithm’s modeling assumptions, manipulate the data until it does.
- (optional) If using a deep learning algorithm, search for hyperparameter settings that further optimize prediction.
This process works well. However, in this workflow, the data scientist’s time is devoted to manipulating data, hyperparameters, and often, the problem definition itself until things work.
An alternative approach is to think hard about the process that generated the data, and then explicitly build your assumptions about that process into a bespoke solution tailored to each new problem. This approach is model-based machine learning. Proponents like this approach because with an excellent model-based machine learning framework, you can create a bespoke solution to pretty much any problem, and don’t need to learn a vast number of machine learning algorithms and techniques.
Most interestingly, with the model-based machine learning approach the data scientist shifts her time from transforming her problem to fit some standard algorithm, to thinking hard about the process that generated the problem data, and then building those assumptions into the design of the algorithm.
We’ll see in this class that when we think about the data-generating process, we think about it causally, meaning it has some ordering of causes and effects. In this class, we formalize this intuition by applying causal inference theory to model-based machine learning.
2.2.3 Note on reinforcement learning
As reinforcement learning gains in popularity amongst machine learning researchers and practitioners, many may have encountered the term “model-based” for the first time in a reinforcement learning (RL) context. Model-based RL is indeed an example of model-based machine learning.
- Model-free RL. The agent has no model of the generating process of the data it perceives in the environment; i.e., how states and actions lead to new states and rewards. It can only learn in a Pavlovian sense, relying solely upon experience.
- Model-based RL. The agent has a model of the generating process of the data it perceives in the environment. This model enables the agent to make use not only of experience but also of model-based predictions of the consequences of particular actions it has less experience performing.
2.3 Case studies
2.3.1 From linear regression to model-based machine learning
The standard Gaussian linear regression model is represented as follows:
\[\begin{align} \epsilon &\sim \mathcal{N}(0,1)\\ Y &= \beta X + \alpha + \epsilon \end{align}\]
When we read this model specification, it is natural to think of it as predictors \(X\) generating target variable \(Y\). Indeed, the term “generates” feels a lot like “causes” here. Usually, we moderate this feeling by remembering that linear regression models only correlation, and that we could just as easily have regressed \(X\) on \(Y\). In this course, we learn how to formalize that feeling.
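The point that regression only models correlation can be checked numerically. The sketch below is illustrative only: the values \(\beta = 2\) and \(\alpha = 1\), and the choice to draw \(X\) from a standard normal so there is data to fit, are assumptions. It fits ordinary least squares in both directions and shows that the fit is equally happy regressing \(X\) on \(Y\).

```python
import random

def ols_slope(inputs, targets):
    """Least-squares slope for regressing targets on inputs."""
    mean_in = sum(inputs) / len(inputs)
    mean_out = sum(targets) / len(targets)
    cov = sum((a - mean_in) * (b - mean_out) for a, b in zip(inputs, targets))
    var = sum((a - mean_in) ** 2 for a in inputs)
    return cov / var

rng = random.Random(0)
beta, alpha = 2.0, 1.0  # assumed values for illustration
xs = [rng.gauss(0, 1) for _ in range(5000)]
ys = [beta * x + alpha + rng.gauss(0, 1) for x in xs]

slope_y_on_x = ols_slope(xs, ys)  # close to beta
slope_x_on_y = ols_slope(ys, xs)  # also clearly non-zero
# The product of the two slopes equals the squared correlation,
# which does not depend on which variable we call the "target".
r_squared = slope_y_on_x * slope_x_on_y
print(slope_y_on_x, slope_x_on_y, r_squared)
```

Nothing in the fitting procedure distinguishes the direction that matches the data-generating story from the one that does not. That distinction has to come from the model.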
We can turn this model into a generative model by placing a marginal distribution on X.
\[\begin{align} \epsilon &\sim \mathcal{N}(0,1)\nonumber\\ X &\sim P_X\nonumber \\ Y &= \beta X + \alpha + \epsilon \nonumber \end{align}\]
At this point, we are already telling a data generating story where \(Y\) comes from \(X\) and \(\epsilon\). Now let’s expand on that story. Suppose we observe that \(Y\) is measured by some instrument, and we suppose that this instrument adds technical noise \(Z\) to \(Y\). Now the regression model becomes a noise model.
\[\begin{align} \epsilon &\sim \mathcal{N}(0,1)\nonumber\\ X &\sim P_X\nonumber \\ Z &\sim P_Z \nonumber\\ Y &= \beta X + \alpha + \epsilon + Z \nonumber \end{align}\]
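The noise model above reads directly as a program. The following is a minimal sketch, not a prescribed implementation: the notes leave \(P_X\) and \(P_Z\) unspecified, so the uniform \(P_X\), the zero-mean Gaussian \(P_Z\), and the parameter values are all assumptions chosen for illustration.

```python
import random

def generate(n, beta=2.0, alpha=1.0, noise_sd=0.5, seed=0):
    """Sample (x, y) pairs by running the data-generating story forward."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = rng.uniform(0, 10)      # X ~ P_X (assumed uniform on [0, 10])
        eps = rng.gauss(0, 1)       # epsilon ~ N(0, 1)
        z = rng.gauss(0, noise_sd)  # Z ~ P_Z (assumed zero-mean Gaussian instrument noise)
        y = beta * x + alpha + eps + z
        rows.append((x, y))
    return rows

data = generate(10_000)

# With independent Gaussian terms, the spread of Y around the regression line
# is the sum of the two noise variances: 1 + noise_sd**2.
residuals = [y - (2.0 * x + 1.0) for x, y in data]
residual_var = sum(r ** 2 for r in residuals) / len(residuals)
print(residual_var)
```

Each line of the function corresponds to one line of the model specification, which is what makes it easy to extend the story with further assumptions about the instrument.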
2.3.2 Binary classifier
The logistic regression model has the form:
\[\mathbb{E}[Y \mid X] = \texttt{logit}^{-1}(\beta X + \alpha)\]
Read literally, this formula says that Y comes from X. Of course, that is not true. This model doesn’t care whether Y comes from X or vice versa; in other words, it doesn’t care how Y is generated. It merely wants to model \(P(Y=1|X)\).
In contrast, a naive Bayes classifier models P(X, Y) as P(X|Y)P(Y). P(X|Y) and P(Y) are estimated from the data, and then we use Bayes rule to find P(Y|X) and predict Y given X. P(X|Y)P(Y) is a representation of the data generating process that reads as “there is some unobserved Y, and then we observe X which came from Y.” Nothing forces us to apply naive Bayes only in problems where the generation of the prediction target precedes the generation of the features. Yet, this is precisely the kind of problem where this approach tends to get applied, such as spam detection. I argue that P(X|Y)P(Y) aligns with a causal intuition that X comes from Y, and we avoid the inner cringe that comes from using naive Bayes when we suspect that Y comes from X. Causal modeling gives us a language to formalize this intuition.
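The factorization P(X|Y)P(Y) can be sketched in a few lines. The example below is a toy: the six-message corpus is hypothetical, the features are word-presence indicators, and add-one smoothing is an assumption added to avoid zero probabilities. It estimates P(Y) and P(X|Y) from counts, applies Bayes rule to get P(Y|X), and also runs the same model forward as a generator, where Y is drawn first and X then comes from Y.

```python
import math
import random
from collections import Counter

# Hypothetical toy corpus: (set of words in the message, label).
train = [
    ({"win", "money", "now"}, "spam"),
    ({"free", "money"}, "spam"),
    ({"win", "prize", "free"}, "spam"),
    ({"meeting", "tomorrow"}, "ham"),
    ({"lunch", "tomorrow", "now"}, "ham"),
    ({"project", "meeting", "notes"}, "ham"),
]
vocab = sorted(set().union(*(words for words, _ in train)))
labels = sorted({label for _, label in train})

# P(Y): class frequencies.
label_counts = Counter(label for _, label in train)
prior = {y: label_counts[y] / len(train) for y in labels}

def word_prob(word, y):
    """P(word present | Y = y), estimated with add-one smoothing."""
    present = sum(1 for words, label in train if label == y and word in words)
    return (present + 1) / (label_counts[y] + 2)

def posterior(words):
    """Bayes rule: P(Y | X) is proportional to P(X | Y) P(Y)."""
    log_joint = {}
    for y in labels:
        log_p = math.log(prior[y])
        for w in vocab:
            p = word_prob(w, y)
            # "Naive" assumption: words are independent given the label.
            log_p += math.log(p if w in words else 1 - p)
        log_joint[y] = log_p
    norm = math.log(sum(math.exp(v) for v in log_joint.values()))
    return {y: math.exp(v - norm) for y, v in log_joint.items()}

def generate_message(rng):
    """Run the story forward: draw Y from P(Y), then X from P(X | Y)."""
    y = rng.choices(labels, weights=[prior[l] for l in labels])[0]
    words = {w for w in vocab if rng.random() < word_prob(w, y)}
    return words, y

post = posterior({"free", "money", "now"})
sampled_words, sampled_label = generate_message(random.Random(0))
print(post)
```

The same two estimated distributions serve both purposes, prediction and generation, which is what distinguishes this from the logistic regression model.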
2.3.3 Gaussian mixture model
The naive Bayes classifier is an example of a latent variable model. Latent variable models come with a pre-packaged data-generating story. Another example is the Gaussian mixture model.
Let’s recall the intuition with a simple GMM with two Gaussians. We observe the data and realize that it is coming from two Gaussians with means \(\mu_1\) and \(\mu_2\). Let \(Z_i\) be a binary variable that indicates which of these two distributions \(X_i\) belongs to.
The data generating process is:
1. Some probabilistic process generated \(\mu\).
2. Some Dirichlet distribution generated \(\theta\).
3. Then for each \(i\) in some range
    1. a discrete distribution with parameter \(\theta\) generated \(Z_i\).
    2. \(Z_i\) picks a \(\mu\) that generates \(X_i\) from a Gaussian with mean \(\mu\).
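We can represent this data generating process in code quite easily. The following Python sketch generalizes from two to k mixture components. The n and seed arguments are additions so that the function runs as written, and the Dirichlet draw is done by normalizing Gamma draws because the Python standard library has no Dirichlet sampler.
```python
import random

def generate_gmm(alphas, sigma, scale, n, seed=0):
    rng = random.Random(seed)
    k = len(alphas)
    # theta ~ Dirichlet(alphas), sampled by normalizing independent Gamma draws
    gammas = [rng.gammavariate(a, 1.0) for a in alphas]
    theta = [g / sum(gammas) for g in gammas]
    # one mean per mixture component
    mu = [rng.gauss(0, sigma) for _ in range(k)]
    # for each data point, pick a component, then draw from that component's Gaussian
    z = [rng.choices(range(k), weights=theta)[0] for _ in range(n)]
    x = [rng.gauss(mu[z_i], scale) for z_i in z]
    return theta, mu, z, x
```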
Now inferring the mixture components given data using this code requires an inference algorithm. Model-based machine learning frameworks generally let you code up the model just as above and then provide implementations of algorithms that can provide inference on the model. In the next lecture, we will cover the basics of two frameworks for model-based machine learning that implement inference algorithms.
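To give a feel for what "inference on the model" means, here is a deliberately simplified sketch. It assumes that the mixture weights, the component means, and the scale are already known, which is not the case in practice, where a framework's inference algorithm has to estimate them as well. Under that assumption, the posterior over the component of a single data point follows from Bayes rule. The function names and parameter values are hypothetical.

```python
import math

def normal_pdf(x, mean, sd):
    """Density of a Gaussian with the given mean and standard deviation."""
    return math.exp(-0.5 * ((x - mean) / sd) ** 2) / (sd * math.sqrt(2 * math.pi))

def component_posterior(x, theta, mu, scale):
    """P(Z = k | X = x), proportional to theta[k] * Normal(x; mu[k], scale)."""
    weights = [t * normal_pdf(x, m, scale) for t, m in zip(theta, mu)]
    total = sum(weights)  # can underflow to zero for points very far from every mean
    return [w / total for w in weights]

# Assumed, known parameters for a two-component mixture.
theta = [0.5, 0.5]
mu = [-2.0, 2.0]

posterior_near_first = component_posterior(-1.5, theta, mu, scale=1.0)
posterior_midpoint = component_posterior(0.0, theta, mu, scale=1.0)
print(posterior_near_first, posterior_midpoint)
```

A point close to the first mean is assigned to the first component with high probability, while a point halfway between the means is split evenly. Full inference would wrap a step like this inside a procedure that also updates the parameters.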
The GMM and other latent variable models, such as the hidden Markov model, mixed membership models like LDA, and linear factor models, provide an off-the-shelf data-generating story that is straightforward to cast into code. However, just as we turned regression into a noise model, we can adjust the model and code to create a bespoke solution to a unique problem.
2.3.4 Deep generative models
Deep generative models are generative models that use deep neural network architectures. Examples include variational autoencoders and generative adversarial networks. Rather than make the data generation story explicit, their basic implementation compresses the entire generative process into a latent encoding. But nothing is forcing them to do so. In this course, we will see examples of deep generative models where we model the critical components of the data generating process explicitly, and let the latent encoding handle nuisance variables that we don’t care about.
2.4 Don’t worry about being wrong
Deep learning methods work well because they essentially infer the optimal circuit between a given input signal and output channels. In contrast, when you reason about the data generating process and represent it as code, you are inferring a program given program inputs and outputs.
In machine learning, the task of automatically inferring a program is called program induction, and it is much harder than inferring a circuit. Indeed, it is an ill-specified problem, because there are numerous programs we could write to generate the same data. Algorithmic information theory tells us that finding the shortest program that produces a given output is, in general, not even computable.
So if program induction is hard for computers, it should not surprise us that it is challenging for humans too. In practice, we make use of domain knowledge and other constraints. For example, an economist building a price model might incorporate their understanding of supply and demand. A computational biologist may use extensive databases of verified relationships between genes in modeling.
Finally, we can still validate the model on the data using goodness-of-fit or predictive performance statistics. We can also use standard techniques for handling uncertainty in our models, such as ensemble methods.
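As one hedged illustration of validating a model by predictive performance, the sketch below scores two candidate data-generating stories by their average log-likelihood on held-out data. The simulated data, the two candidates, and the choice of held-out log-likelihood as the statistic are all assumptions for the example. Other goodness-of-fit measures would serve the same purpose.

```python
import math
import random

def mean_gaussian_loglik(residuals, sd):
    """Average Gaussian log-density of the residuals under standard deviation sd."""
    return sum(
        -0.5 * math.log(2 * math.pi * sd ** 2) - r ** 2 / (2 * sd ** 2)
        for r in residuals
    ) / len(residuals)

# Simulated data in which Y really does depend on X (assumed for illustration).
rng = random.Random(1)
xs = [rng.uniform(0, 10) for _ in range(2000)]
ys = [2.0 * x + 1.0 + rng.gauss(0, 1) for x in xs]
train_x, test_x = xs[:1000], xs[1000:]
train_y, test_y = ys[:1000], ys[1000:]

# Candidate A: a story in which Y is generated without reference to X.
mean_y = sum(train_y) / len(train_y)
sd_a = math.sqrt(sum((y - mean_y) ** 2 for y in train_y) / len(train_y))
score_a = mean_gaussian_loglik([y - mean_y for y in test_y], sd_a)

# Candidate B: a story in which Y is generated from X through a linear relationship.
mean_x = sum(train_x) / len(train_x)
slope = (
    sum((x - mean_x) * (y - mean_y) for x, y in zip(train_x, train_y))
    / sum((x - mean_x) ** 2 for x in train_x)
)
intercept = mean_y - slope * mean_x
train_resid = [y - (slope * x + intercept) for x, y in zip(train_x, train_y)]
sd_b = math.sqrt(sum(r ** 2 for r in train_resid) / len(train_resid))
test_resid = [y - (slope * x + intercept) for x, y in zip(test_x, test_y)]
score_b = mean_gaussian_loglik(test_resid, sd_b)

print(f"held-out log-likelihood, A: {score_a:.2f}, B: {score_b:.2f}")
```

A higher held-out score favors one story over the other, but it does not prove the winning story is causally correct. It only tells us which of the candidates we wrote down predicts new data better.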