CamlSyntheticDataGenerator

CamlSyntheticDataGenerator(
    self,
    n_obs=10000,
    n_cont_outcomes=1,
    n_binary_outcomes=0,
    n_cont_treatments=0,
    n_binary_treatments=1,
    n_discrete_treatments=0,
    n_cont_confounders=2,
    n_binary_confounders=0,
    n_discrete_confounders=0,
    n_cont_modifiers=2,
    n_binary_modifiers=0,
    n_discrete_modifiers=0,
    n_confounding_modifiers=0,
    stddev_outcome_noise=1.0,
    stddev_treatment_noise=1.0,
    causal_model_functional_form='linear',
    n_nonlinear_transformations=None,
    n_nonlinear_interactions=None,
    seed=None,
)

Generate highly flexible synthetic data for use in causal inference and CaML testing.

The general form of the data generating process is:

\[ \mathbf{Y_i} = \tau (\mathbf{X_i}) \mathbf{T_i} + g(\mathbf{W_i}, \mathbf{X_i}) + \mathbf{\epsilon_i} \]

\[ \mathbf{T_i} = f(\mathbf{W_i}, \mathbf{X_{i,\mathcal{S}}}) + \mathbf{\eta_i} \]

where \(\mathbf{Y_i}\) are the outcome(s), \(\mathbf{T_i}\) are the treatment(s), \(\mathbf{X_i}\) are the effect modifiers (which drive treatment effect heterogeneity), of which an optional random subset \(\mathcal{S}\) also acts as confounders, and \(\mathbf{W_i}\) are the confounders. The error terms \(\mathbf{\epsilon_i}\) and \(\mathbf{\eta_i}\) are drawn from normal distributions with optionally specified standard deviations. \(\tau\) is the CATE function, \(g\) is the linearly separable nuisance component of the outcome function, and \(f\) is the treatment function. With no modifier variables, \(\tau\) is a constant and the model is purely partially linear.

For a linear data generating process, \(f\) and \(g\) consist of strictly linear terms of untransformed variables, and \(\tau\) consists of linear interaction terms.
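The linear case can be illustrated with a small standalone sketch. It uses only the Python standard library and does not call CaML. The function name `simulate_linear_dgp`, every coefficient, the standard normal covariates, and the way the binary treatment is drawn (a sigmoid of a noisy linear index followed by a Bernoulli draw) are assumptions made for illustration, not the library's implementation.

```python
import math
import random


def simulate_linear_dgp(n_obs=1000, seed=0):
    """Toy linear DGP: two confounders W, two modifiers X, one binary T, one Y."""
    rng = random.Random(seed)
    # Hypothetical coefficients, chosen only for illustration.
    f_w = [1.0, -0.5]                 # treatment function f(W)
    g_w = [-1.0, 2.5]                 # nuisance g: confounder part
    g_x = [0.25, -0.1]                # nuisance g: modifier part
    tau0, tau_x = -1.0, [-2.0, -1.0]  # tau(X) = tau0 + tau_x . X

    rows, cates = [], []
    for _ in range(n_obs):
        w = [rng.gauss(0, 1) for _ in range(2)]
        x = [rng.gauss(0, 1) for _ in range(2)]
        # Treatment: linear index plus noise eta, squashed to a probability.
        index = sum(a * b for a, b in zip(f_w, w)) + rng.gauss(0, 1.0)
        p_treated = 1.0 / (1.0 + math.exp(-index))
        t = 1 if rng.random() < p_treated else 0
        # CATE: linear in the modifiers.
        tau = tau0 + sum(a * b for a, b in zip(tau_x, x))
        # Outcome: tau(X) * T + g(W, X) + epsilon.
        g = sum(a * b for a, b in zip(g_w, w)) + sum(a * b for a, b in zip(g_x, x))
        y = tau * t + g + rng.gauss(0, 1.0)
        rows.append((w, x, t, y))
        cates.append(tau)
    return rows, cates
```

Because \(\tau\) is known for every row in a simulation like this, estimated effects can be compared directly against the returned `cates` list.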

For a nonlinear data generating process, \(f\) and \(g\) are generated via Generalized Additive Models (GAMs) with randomly selected transformations and interaction terms, the number of which is controlled via n_nonlinear_transformations. \(\tau\) contains interaction terms with transformed modifiers, the number of which is controlled via n_nonlinear_interactions.
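As a rough illustration of the additive structure described above, the sketch below builds a function as a sum of randomly chosen univariate transformations. The transformation pool, the name `build_additive_function`, and the coefficient distribution are hypothetical. The set of transformations CaML actually samples from is not listed in this page.

```python
import math
import random

# Hypothetical pool of univariate transformations (an assumption for this sketch).
TRANSFORMS = {
    "square": lambda v: v * v,
    "sin": math.sin,
    "abs": abs,
    "tanh": math.tanh,
}


def build_additive_function(n_features, n_transformations, seed=0):
    """Return (terms, fn), where fn(x) sums randomly chosen transformed features."""
    rng = random.Random(seed)
    names = sorted(TRANSFORMS)
    terms = []
    for _ in range(n_transformations):
        feature = rng.randrange(n_features)  # which covariate to transform
        name = rng.choice(names)             # which transformation to apply
        coef = rng.gauss(0, 1)               # random coefficient for the term
        terms.append((feature, name, coef))

    def fn(x):
        return sum(coef * TRANSFORMS[name](x[feature]) for feature, name, coef in terms)

    return terms, fn
```

In this sketch, `n_transformations` plays the role that n_nonlinear_transformations plays for \(f\) and \(g\). The same builder applied to the modifiers only would give a nonlinear \(\tau\), loosely mirroring n_nonlinear_interactions.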

As a DAG, the data generating process can be roughly represented as:

flowchart TD;
    X((X))-->Y((Y));
    W((W))-->Y((Y));
    W((W))-->T((T));
    X((X))-->|"S"|T((T));
    T((T))-->|"τ(X)"|Y((Y));

    linkStyle 0,1,2,3,4 stroke:black,stroke-width:2px

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_obs | int | Number of observations. | 10000 |
| n_cont_outcomes | int | Number of continuous outcomes (\(Y\)). | 1 |
| n_binary_outcomes | int | Number of binary outcomes (\(Y\)). | 0 |
| n_cont_treatments | int | Number of continuous treatments (\(T\)). | 0 |
| n_binary_treatments | int | Number of binary treatments (\(T\)). | 1 |
| n_discrete_treatments | int | Number of discrete treatments (\(T\)). | 0 |
| n_cont_confounders | int | Number of continuous confounders (\(W\)). | 2 |
| n_binary_confounders | int | Number of binary confounders (\(W\)). | 0 |
| n_discrete_confounders | int | Number of discrete confounders (\(W\)). | 0 |
| n_cont_modifiers | int | Number of continuous treatment effect modifiers (\(X\)). | 2 |
| n_binary_modifiers | int | Number of binary treatment effect modifiers (\(X\)). | 0 |
| n_discrete_modifiers | int | Number of discrete treatment effect modifiers (\(X\)). | 0 |
| n_confounding_modifiers | int | Number of confounding treatment effect modifiers (\(X_{\mathcal{S}}\)). | 0 |
| stddev_outcome_noise | float | Standard deviation of the outcome noise (\(\epsilon\)). | 1.0 |
| stddev_treatment_noise | float | Standard deviation of the treatment noise (\(\eta\)). | 1.0 |
| causal_model_functional_form | str | Functional form of the causal model, either "linear" or "nonlinear". | 'linear' |
| n_nonlinear_transformations | int \| None | Number of nonlinear transformations. Only applies if causal_model_functional_form="nonlinear". | None |
| n_nonlinear_interactions | int \| None | Number of nonlinear interactions with treatment, which introduce heterogeneity. Only applies if causal_model_functional_form="nonlinear". | None |
| seed | int \| None | Random seed to use for generating the data. | None |

Attributes

| Name | Type | Description |
| --- | --- | --- |
| df | pandas.DataFrame | The data generated by the data generation process. |
| cates | pandas.DataFrame | The true conditional average treatment effects (CATEs) of the data. |
| ates | pandas.DataFrame | The true average treatment effects (ATEs) of the data. |
| dgp | dict[str, pandas.DataFrame] | The true data generating processes of the treatments and outcomes. |
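By definition, the ATE is the expectation of the CATE over the population. The sketch below shows the sample analogue on made-up numbers. It assumes, without confirmation from this page, that each row of `ates` is the column mean of the matching `cates` column. `T2_binary` is a hypothetical second treatment added only to show the one-ATE-per-column pattern.

```python
from statistics import fmean

# Toy CATE columns keyed like the generator's column names (values are made up).
toy_cates = {
    "CATE_of_T1_binary_on_Y1_continuous": [-2.0, 0.5, 1.5, 4.0],
    "CATE_of_T2_binary_on_Y1_continuous": [1.0, 1.0, 3.0, 3.0],
}

# One ATE per treatment/outcome pair: the mean of the per-row CATEs.
toy_ates = {
    name.removeprefix("CATE_of_"): fmean(values)
    for name, values in toy_cates.items()
}
```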

Examples

from caml.extensions.synthetic_data import CamlSyntheticDataGenerator

data_generator = CamlSyntheticDataGenerator(seed=10)
data_generator.df
      W1_continuous  W2_continuous  X1_continuous  X2_continuous  T1_binary  Y1_continuous
0          0.354380      -3.252276       2.715662      -3.578800          1     -11.068941
1          0.568499       2.484069      -6.402235      -2.611815          0       4.742376
2          0.162715       8.842902       1.288770      -3.788545          0      23.149713
3          0.362944      -0.959538       1.080988      -3.542550          0      -3.390266
4          0.612101       1.417536       4.143630      -4.112453          1      -0.990706
...             ...            ...            ...            ...        ...            ...
9995       0.340436       0.241095      -6.524222      -3.188783          0      -3.946413
9996       0.019523       1.338152      -2.555492      -3.643733          0       3.952691
9997       0.325401       1.258659      -3.340546      -4.255203          1      10.588338
9998       0.586715       1.263264      -2.826709      -4.149383          1       9.652253
9999       0.003002       6.723381       1.260782      -3.660600          0      19.840431

10000 rows × 6 columns

data_generator.cates
      CATE_of_T1_binary_on_Y1_continuous
0                              -2.908078
1                              13.601348
2                              -0.003848
3                               0.173513
4                              -5.154262
...                                  ...
9995                           14.346026
9996                            7.190235
9997                            9.228579
9998                            8.155846
9999                           -0.064142

10000 rows × 1 columns

data_generator.ates
                    Treatment       ATE
0  T1_binary_on_Y1_continuous  3.912029

for t, df in data_generator.dgp.items():
    print(f"\nDGP for {t}:")
    print(df)

DGP for T1_binary:
      covariates    params global_transformation
0  W1_continuous  1.281740               Sigmoid
1  W2_continuous -0.575915               Sigmoid

DGP for Y1_continuous:
                    covariates    params global_transformation
0                W1_continuous -1.137974                  None
1                W2_continuous  2.571330                  None
2                X1_continuous  0.254323                  None
3                X2_continuous -0.104112                  None
4                    T1_binary -0.912958                  None
5  int_T1_binary_X1_continuous -1.904831                  None
6  int_T1_binary_X2_continuous -0.887939                  None
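
The printed outcome DGP can be tied back to the CATE values shown earlier. In the linear case, \(\tau(\mathbf{X_i})\) is the coefficient on T1_binary plus the interaction coefficients multiplied by the modifiers. The sketch below copies the coefficients and modifier values from the outputs above and recomputes the first two CATEs by hand. The helper name `cate` is introduced here for illustration only.

```python
# Coefficients copied from the "DGP for Y1_continuous" output above.
params = {
    "T1_binary": -0.912958,
    "int_T1_binary_X1_continuous": -1.904831,
    "int_T1_binary_X2_continuous": -0.887939,
}


def cate(x1, x2):
    """tau(X) = treatment coefficient + interaction coefficients times modifiers."""
    return (
        params["T1_binary"]
        + params["int_T1_binary_X1_continuous"] * x1
        + params["int_T1_binary_X2_continuous"] * x2
    )


# (X1_continuous, X2_continuous) from rows 0 and 1 of data_generator.df.
row0 = cate(2.715662, -3.578800)   # compare with cates row 0: -2.908078
row1 = cate(-6.402235, -2.611815)  # compare with cates row 1: 13.601348
print(row0, row1)
```

The recomputed values agree with the displayed `cates` rows up to the rounding of the printed coefficients.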