```mermaid
flowchart TD
    X((X)) --> Y((Y))
    W((W)) --> Y((Y))
    W((W)) --> T((T))
    X((X)) -->|"S"| T((T))
    T((T)) -->|"τ(X)"| Y((Y))
    linkStyle 0,1,2,3,4 stroke:black,stroke-width:2px
```
CamlSyntheticDataGenerator
CamlSyntheticDataGenerator(self,
n_obs=10000,
n_cont_outcomes=1,
n_binary_outcomes=0,
n_cont_treatments=0,
n_binary_treatments=1,
n_discrete_treatments=0,
n_cont_confounders=2,
n_binary_confounders=0,
n_discrete_confounders=0,
n_cont_modifiers=2,
n_binary_modifiers=0,
n_discrete_modifiers=0,
n_confounding_modifiers=0,
stddev_outcome_noise=1.0,
stddev_treatment_noise=1.0,
causal_model_functional_form='linear',
n_nonlinear_transformations=None,
n_nonlinear_interactions=None,
seed=None)
Generate highly flexible synthetic data for use in causal inference and CaML testing.
The general form of the data generating process is:
\[ \mathbf{Y_i} = \tau (\mathbf{X_i}) \mathbf{T_i} + g(\mathbf{W_i}, \mathbf{X_i}) + \mathbf{\epsilon_i} \]
\[ \mathbf{T_i} = f(\mathbf{W_i}, \mathbf{X_{i,\mathcal{S}}}) + \mathbf{\eta_i} \]
where \(\mathbf{Y_i}\) are the outcome(s), \(\mathbf{T_i}\) are the treatment(s), \(\mathbf{X_i}\) are the effect modifiers (which drive treatment effect heterogeneity), of which an optional random subset \(\mathcal{S}\) also acts as confounders, \(\mathbf{W_i}\) are the confounders, \(\mathbf{\epsilon_i}\) and \(\mathbf{\eta_i}\) are error terms drawn from normal distributions with configurable standard deviations, \(\tau\) is the CATE function, \(g\) is the linearly separable (nuisance) component of the outcome function, and \(f\) is the treatment function. Note that with no modifier variables the model reduces to a purely partially linear model, in which \(\tau\) is a constant.
For the linear data generating process, \(f\) and \(g\) consist strictly of linear terms in untransformed variables, and \(\tau\) consists of linear terms in the modifiers, which enter the outcome as treatment-by-modifier interaction terms.
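To make the two equations concrete, here is a minimal NumPy sketch of the linear case. It is not CaML's implementation: the function name simulate_linear_dgp, standard-normal covariates and coefficients, and a single continuous treatment and outcome are all illustrative assumptions. It does follow the structure described above: \(f\) depends on \(\mathbf{W}\) and a random subset \(\mathcal{S}\) of the modifiers, \(\tau\) is linear in \(\mathbf{X}\), and \(g\) is linear in untransformed \(\mathbf{W}\) and \(\mathbf{X}\).

```python
import numpy as np


def simulate_linear_dgp(n_obs=1000, n_confounders=2, n_modifiers=2,
                        n_confounding_modifiers=0, stddev_outcome_noise=1.0,
                        stddev_treatment_noise=1.0, seed=None):
    """Hypothetical sketch: Y = tau(X) T + g(W, X) + eps, T = f(W, X_S) + eta (all linear)."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(n_obs, n_confounders))  # confounders
    X = rng.normal(size=(n_obs, n_modifiers))    # effect modifiers

    # Random subset S of modifiers that also confound the treatment
    S = rng.permutation(n_modifiers)[:n_confounding_modifiers]

    # f(W, X_S): strictly linear in untransformed variables, plus Gaussian noise eta
    f = W @ rng.normal(size=n_confounders) + X[:, S] @ rng.normal(size=S.size)
    T = f + rng.normal(scale=stddev_treatment_noise, size=n_obs)

    # tau(X): constant plus linear terms in the modifiers (treatment x modifier interactions)
    cate = rng.normal() + X @ rng.normal(size=n_modifiers)

    # g(W, X): linearly separable nuisance component
    g = W @ rng.normal(size=n_confounders) + X @ rng.normal(size=n_modifiers)
    Y = cate * T + g + rng.normal(scale=stddev_outcome_noise, size=n_obs)
    return {"W": W, "X": X, "T": T, "Y": Y, "cate": cate, "g": g}


data = simulate_linear_dgp(n_obs=5000, n_confounding_modifiers=1, seed=0)
print(data["cate"].mean())  # average of the true CATEs in this sample
```

Setting n_modifiers=0 reproduces the partially linear special case: every row then gets the same constant cate.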
For the nonlinear data generating process, \(f\) and \(g\) are generated via Generalized Additive Models (GAMs) with randomly selected transformations and interaction terms, controlled via n_nonlinear_transformations. \(\tau\) contains interaction terms with transformed modifiers, controlled via n_nonlinear_interactions.
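The GAM-style construction can be sketched the same way. The transformation menu TRANSFORMS, the 50/50 choice between a transformed term and a pairwise product, and the helper names random_gam_component and nonlinear_cate are hypothetical; this section does not list which transformations CaML actually samples. The sketch only illustrates how counts like n_nonlinear_transformations and n_nonlinear_interactions can control how many nonlinear terms enter \(f\)/\(g\) and \(\tau\).

```python
import numpy as np

# Hypothetical menu of univariate transformations; CaML's actual menu may differ.
TRANSFORMS = {"sin": np.sin, "square": np.square, "tanh": np.tanh, "abs": np.abs}


def random_transform(z, rng):
    """Apply one randomly chosen transformation from TRANSFORMS to a column."""
    names = list(TRANSFORMS)
    return TRANSFORMS[names[rng.integers(len(names))]](z)


def random_gam_component(Z, n_nonlinear_transformations, rng):
    """Additive model (usable for f or g): linear terms plus n random nonlinear terms."""
    _, n_vars = Z.shape
    out = Z @ rng.normal(size=n_vars)  # linear part
    for _ in range(n_nonlinear_transformations):
        if n_vars >= 2 and rng.random() < 0.5:
            # Pairwise interaction between two distinct variables
            i, k = rng.choice(n_vars, size=2, replace=False)
            term = Z[:, i] * Z[:, k]
        else:
            # Univariate transformation of one randomly chosen variable
            term = random_transform(Z[:, rng.integers(n_vars)], rng)
        out = out + rng.normal() * term
    return out


def nonlinear_cate(X, n_nonlinear_interactions, rng):
    """tau(X): a constant plus n terms in randomly transformed modifiers."""
    cate = np.full(X.shape[0], rng.normal())
    for _ in range(n_nonlinear_interactions):
        cate = cate + rng.normal() * random_transform(X[:, rng.integers(X.shape[1])], rng)
    return cate


rng = np.random.default_rng(0)
W = rng.normal(size=(1000, 3))
X = rng.normal(size=(1000, 2))
T = random_gam_component(W, 3, rng) + rng.normal(size=1000)  # here f uses W only (S empty)
cate = nonlinear_cate(X, 2, rng)
Y = cate * T + random_gam_component(np.hstack([W, X]), 3, rng) + rng.normal(size=1000)
```

With zero nonlinear terms, random_gam_component collapses to the linear model and nonlinear_cate to a constant, mirroring how the linear form is a special case.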
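As a DAG, the data generating process can be roughly represented by the flowchart at the top of this section, where the edge from X to T labeled S marks the random subset \(\mathcal{S}\) of modifiers that also act as confounders, and the edge from T to Y labeled τ(X) marks the heterogeneous treatment effect.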
Parameters
Name | Type | Description | Default |
---|---|---|---|
n_obs | int | Number of observations. | 10000 |
n_cont_outcomes | int | Number of continuous outcomes (\(Y\)). | 1 |
n_binary_outcomes | int | Number of binary outcomes (\(Y\)). | 0 |
n_cont_treatments | int | Number of continuous treatments (\(T\)). | 0 |
n_binary_treatments | int | Number of binary treatments (\(T\)). | 1 |
n_discrete_treatments | int | Number of discrete treatments (\(T\)). | 0 |
n_cont_confounders | int | Number of continuous confounders (\(W\)). | 2 |
n_binary_confounders | int | Number of binary confounders (\(W\)). | 0 |
n_discrete_confounders | int | Number of discrete confounders (\(W\)). | 0 |
n_cont_modifiers | int | Number of continuous treatment effect modifiers (\(X\)). | 2 |
n_binary_modifiers | int | Number of binary treatment effect modifiers (\(X\)). | 0 |
n_discrete_modifiers | int | Number of discrete treatment effect modifiers (\(X\)). | 0 |
n_confounding_modifiers | int | Number of confounding treatment effect modifiers (\(X_{\mathcal{S}}\)). | 0 |
stddev_outcome_noise | float | Standard deviation of the outcome noise (\(\epsilon\)). | 1.0 |
stddev_treatment_noise | float | Standard deviation of the treatment noise (\(\eta\)). | 1.0 |
causal_model_functional_form | str | Functional form of the causal model, can be “linear” or “nonlinear”. | 'linear' |
n_nonlinear_transformations | int \| None | Number of nonlinear transformations, only applies if causal_model_functional_form=“nonlinear”. | None |
n_nonlinear_interactions | int \| None | Number of nonlinear interactions with treatment, introducing heterogeneity, only applies if causal_model_functional_form=“nonlinear”. | None |
seed | int \| None | Random seed to use for generating the data. | None |
Attributes
Name | Type | Description |
---|---|---|
df | pandas.DataFrame | The data generated by the data generation process. |
cates | pandas.DataFrame | The true conditional average treatment effects (CATEs) of the data. |
ates | pandas.DataFrame | The true average treatment effects (ATEs) of the data. |
dgp | dict[str, pandas.DataFrame] | The true data generating processes of the treatments and outcomes. |
Examples
```python
from caml.extensions.synthetic_data import CamlSyntheticDataGenerator

data_generator = CamlSyntheticDataGenerator(seed=10)
data_generator.df
```
| | W1_continuous | W2_continuous | X1_continuous | X2_continuous | T1_binary | Y1_continuous |
---|---|---|---|---|---|---|
0 | 0.354380 | -3.252276 | 2.715662 | -3.578800 | 1 | -11.068941 |
1 | 0.568499 | 2.484069 | -6.402235 | -2.611815 | 0 | 4.742376 |
2 | 0.162715 | 8.842902 | 1.288770 | -3.788545 | 0 | 23.149713 |
3 | 0.362944 | -0.959538 | 1.080988 | -3.542550 | 0 | -3.390266 |
4 | 0.612101 | 1.417536 | 4.143630 | -4.112453 | 1 | -0.990706 |
... | ... | ... | ... | ... | ... | ... |
9995 | 0.340436 | 0.241095 | -6.524222 | -3.188783 | 0 | -3.946413 |
9996 | 0.019523 | 1.338152 | -2.555492 | -3.643733 | 0 | 3.952691 |
9997 | 0.325401 | 1.258659 | -3.340546 | -4.255203 | 1 | 10.588338 |
9998 | 0.586715 | 1.263264 | -2.826709 | -4.149383 | 1 | 9.652253 |
9999 | 0.003002 | 6.723381 | 1.260782 | -3.660600 | 0 | 19.840431 |
10000 rows × 6 columns
```python
data_generator.cates
```
| | CATE_of_T1_binary_on_Y1_continuous |
---|---|
0 | -2.908078 |
1 | 13.601348 |
2 | -0.003848 |
3 | 0.173513 |
4 | -5.154262 |
... | ... |
9995 | 14.346026 |
9996 | 7.190235 |
9997 | 9.228579 |
9998 | 8.155846 |
9999 | -0.064142 |
10000 rows × 1 columns
```python
data_generator.ates
```
| | Treatment | ATE |
---|---|---|
0 | T1_binary_on_Y1_continuous | 3.912029 |
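Because the ATE is the expectation of the CATE, a natural reading of the ates table is that each entry is the sample mean of the matching column of cates. The pandas sketch below encodes that reading. ates_from_cates is a hypothetical helper, we have not confirmed that CaML computes ates exactly this way, and the input uses only the first five CATE rows shown above, so its mean is not expected to equal 3.912029.

```python
import pandas as pd


def ates_from_cates(cates: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: average each CATE_of_<treatment>_on_<outcome> column into one ATE row."""
    rows = [
        # Column names follow the CATE_of_<treatment>_on_<outcome> pattern seen above
        {"Treatment": col.removeprefix("CATE_of_"), "ATE": cates[col].mean()}
        for col in cates.columns
    ]
    return pd.DataFrame(rows, columns=["Treatment", "ATE"])


# Only the first five CATE rows shown above, so this is NOT the full-sample ATE.
cates_head = pd.DataFrame(
    {"CATE_of_T1_binary_on_Y1_continuous": [-2.908078, 13.601348, -0.003848, 0.173513, -5.154262]}
)
print(ates_from_cates(cates_head))
```

With a real generator, comparing ates_from_cates(data_generator.cates) against data_generator.ates is a quick way to test this assumption.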
```python
# Print the true data generating process for each treatment and outcome
for t, df in data_generator.dgp.items():
    print(f"\nDGP for {t}:")
    print(df)
```
```
DGP for T1_binary:
      covariates    params global_transformation
0  W1_continuous  1.281740               Sigmoid
1  W2_continuous -0.575915               Sigmoid

DGP for Y1_continuous:
                    covariates    params global_transformation
0                W1_continuous -1.137974                  None
1                W2_continuous  2.571330                  None
2                X1_continuous  0.254323                  None
3                X2_continuous -0.104112                  None
4                    T1_binary -0.912958                  None
5  int_T1_binary_X1_continuous -1.904831                  None
6  int_T1_binary_X2_continuous -0.887939                  None
```
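The printed DGP tables are enough to recompute the true CATEs by hand. Reading the Y1_continuous rows as coefficients of the linear model gives \(\tau(X) = \beta_T + \beta_{T,X_1} X_1 + \beta_{T,X_2} X_2\), with the remaining rows forming \(g(W, X)\). The sketch below applies this to the first five rows of data_generator.df (values copied from the example output). Treating Sigmoid for T1_binary as a logistic link on a linear index with no intercept is an assumption, and how the binary treatment is then drawn is not stated here, so the propensity is shown for illustration only.

```python
import numpy as np
import pandas as pd

# Coefficients of the Y1_continuous DGP, copied from the printed output above
beta = pd.Series({
    "W1_continuous": -1.137974,
    "W2_continuous": 2.571330,
    "X1_continuous": 0.254323,
    "X2_continuous": -0.104112,
    "T1_binary": -0.912958,
    "int_T1_binary_X1_continuous": -1.904831,
    "int_T1_binary_X2_continuous": -0.887939,
})

# First five rows of data_generator.df, copied from the example output above
df_head = pd.DataFrame({
    "W1_continuous": [0.354380, 0.568499, 0.162715, 0.362944, 0.612101],
    "W2_continuous": [-3.252276, 2.484069, 8.842902, -0.959538, 1.417536],
    "X1_continuous": [2.715662, -6.402235, 1.288770, 1.080988, 4.143630],
    "X2_continuous": [-3.578800, -2.611815, -3.788545, -3.542550, -4.112453],
    "T1_binary": [1, 0, 0, 0, 1],
    "Y1_continuous": [-11.068941, 4.742376, 23.149713, -3.390266, -0.990706],
})

# tau(X): treatment main effect plus treatment-by-modifier interactions
cate = (
    beta["T1_binary"]
    + beta["int_T1_binary_X1_continuous"] * df_head["X1_continuous"]
    + beta["int_T1_binary_X2_continuous"] * df_head["X2_continuous"]
)

# g(W, X): linear, untransformed nuisance terms
nuisance_cols = ["W1_continuous", "W2_continuous", "X1_continuous", "X2_continuous"]
g = sum(beta[c] * df_head[c] for c in nuisance_cols)

# Implied outcome noise: epsilon = Y - (tau(X) * T + g(W, X))
eps = df_head["Y1_continuous"] - (cate * df_head["T1_binary"] + g)

# Assumption: "Sigmoid" is a logistic link on a linear index in W with no intercept
z = 1.281740 * df_head["W1_continuous"] - 0.575915 * df_head["W2_continuous"]
propensity = 1.0 / (1.0 + np.exp(-z))

print(pd.DataFrame({"cate": cate, "eps": eps, "propensity": propensity}).round(4))
```

The recomputed CATEs agree with the first five rows of data_generator.cates to within about 1e-5 (the gap comes from the six-digit rounding of the printed coefficients), and the implied residuals are all of order one, consistent with the default stddev_outcome_noise=1.0.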