from caml.extensions.synthetic_data import SyntheticDataGenerator
import numpy as np
CaML Synthetic Data Generator
Generate Data
# Simulate 10,000 observations: one continuous and one binary outcome, plus one
# continuous, one binary, and one discrete treatment, confounder, and effect
# modifier. Each argument is paired with its own value (the exported cell had
# every value shifted onto the following keyword).
data_generator = SyntheticDataGenerator(
    n_obs=10_000,
    n_cont_outcomes=1,
    n_binary_outcomes=1,
    n_cont_treatments=1,
    n_binary_treatments=1,
    n_discrete_treatments=1,
    n_cont_confounders=1,
    n_binary_confounders=1,
    n_discrete_confounders=1,
    n_cont_modifiers=1,
    n_binary_modifiers=1,
    n_discrete_modifiers=1,
    # One modifier also enters the treatment formulas (X1_continuous in the
    # printed DGPs for T1/T2/T3).
    n_confounding_modifiers=1,
    stddev_outcome_noise=3,
    stddev_treatment_noise=3,
    causal_model_functional_form="linear",
    n_nonlinear_transformations=5,
    seed=10,
)
Simulated Dataframe
# Display the simulated dataset: one column per generated confounder (W*),
# modifier (X*), treatment (T*), and outcome (Y*).
data_generator.df
| | W1_continuous | W2_binary | W3_discrete | X1_continuous | X2_binary | X3_discrete | T1_continuous | T2_binary | T3_discrete | Y1_continuous | Y2_binary |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.354380 | 0 | 1 | 0.432254 | 0 | 1 | -3.726761 | 0 | 0 | -20.927146 | 0 |
1 | 0.568499 | 1 | 0 | 0.812796 | 0 | 1 | -3.438992 | 1 | 4 | -9.747703 | 0 |
2 | 0.162715 | 1 | 1 | 0.874920 | 1 | 0 | -2.806232 | 0 | 0 | -17.370004 | 1 |
3 | 0.362944 | 0 | 0 | 0.411990 | 1 | 3 | -0.950123 | 1 | 2 | -2.274883 | 0 |
4 | 0.612101 | 1 | 0 | 0.444356 | 1 | 0 | -2.888419 | 1 | 4 | -0.420176 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
9995 | 0.340436 | 0 | 1 | 0.657928 | 1 | 0 | 3.182961 | 1 | 2 | 8.651162 | 0 |
9996 | 0.019523 | 1 | 0 | 0.672171 | 1 | 1 | -0.384001 | 1 | 1 | 3.081509 | 0 |
9997 | 0.325401 | 1 | 0 | 0.407528 | 1 | 3 | -7.291501 | 0 | 0 | -105.683437 | 1 |
9998 | 0.586715 | 1 | 0 | 0.542052 | 1 | 0 | -4.196167 | 0 | 4 | 0.847574 | 0 |
9999 | 0.003002 | 1 | 1 | 0.261432 | 1 | 4 | 5.563359 | 0 | 4 | 113.301968 | 0 |
10000 rows × 11 columns
True Conditional Average Treatment Effects (CATEs)
# Display the true per-row CATEs: one column per treatment/outcome pair, with
# discrete treatments split into one column per level (named ..._level_k_v_0).
data_generator.cates
| | CATE_of_T1_continuous_on_Y1_continuous | CATE_of_T2_binary_on_Y1_continuous | CATE_of_T3_discrete_on_Y1_continuous_level_4_v_0 | CATE_of_T3_discrete_on_Y1_continuous_level_2_v_0 | CATE_of_T3_discrete_on_Y1_continuous_level_1_v_0 | CATE_of_T3_discrete_on_Y1_continuous_level_3_v_0 | CATE_of_T1_continuous_on_Y2_binary | CATE_of_T2_binary_on_Y2_binary | CATE_of_T3_discrete_on_Y2_binary_level_4_v_0 | CATE_of_T3_discrete_on_Y2_binary_level_2_v_0 | CATE_of_T3_discrete_on_Y2_binary_level_1_v_0 | CATE_of_T3_discrete_on_Y2_binary_level_3_v_0 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 6.431287 | 2.115750 | 5.518583 | 2.759291 | 1.379646 | 4.138937 | -0.022789 | 0.023626 | -0.722338 | -0.512221 | -0.256669 | -0.661029 |
1 | 7.137688 | 3.047866 | 7.489986 | 3.744993 | 1.872497 | 5.617490 | -0.004433 | 0.009044 | -0.814747 | -0.547820 | -0.241318 | -0.740738 |
2 | 3.589688 | 0.170688 | 14.223319 | 7.111660 | 3.555830 | 10.667489 | -0.033585 | -0.059832 | -0.829362 | -0.328208 | -0.100727 | -0.627206 |
3 | 14.923903 | 5.855224 | 7.506872 | 3.753436 | 1.876718 | 5.630154 | 0.000000 | 0.000000 | -0.687002 | -0.687002 | -0.603805 | -0.687002 |
4 | 2.790430 | -0.883957 | 11.992772 | 5.996386 | 2.998193 | 8.994579 | -0.020381 | -0.090014 | -0.744867 | -0.480709 | -0.223118 | -0.659681 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
9995 | 3.186885 | -0.360823 | 13.099186 | 6.549593 | 3.274797 | 9.824390 | -0.005274 | -0.024819 | -0.124580 | -0.119305 | -0.087765 | -0.124580 |
9996 | 7.277843 | 1.946885 | 11.733566 | 5.866783 | 2.933392 | 8.800175 | -0.055789 | -0.091764 | -0.660311 | -0.621732 | -0.426627 | -0.660311 |
9997 | 14.915621 | 5.844295 | 7.483759 | 3.741879 | 1.870940 | 5.612819 | 0.029784 | 0.113995 | -0.516854 | -0.516854 | -0.479383 | -0.516854 |
9998 | 2.971783 | -0.644656 | 12.498887 | 6.249444 | 3.124722 | 9.374165 | -0.080096 | -0.134494 | -0.733551 | -0.231604 | -0.070256 | -0.489847 |
9999 | 18.708941 | 7.759261 | 5.287498 | 2.643749 | 1.321874 | 3.965623 | 0.000000 | 0.000000 | -0.954142 | -0.948999 | -0.572786 | -0.954142 |
10000 rows × 12 columns
True Average Treatment Effects (ATEs)
# Display the true ATEs: one row per treatment/outcome pair (and per level
# for the discrete treatment).
data_generator.ates
| | Treatment | ATE |
---|---|---|
0 | T1_continuous_on_Y1_continuous | 8.832689 |
1 | T2_binary_on_Y1_continuous | 2.723753 |
2 | T3_discrete_on_Y1_continuous_level_4_v_0 | 9.625175 |
3 | T3_discrete_on_Y1_continuous_level_2_v_0 | 4.812588 |
4 | T3_discrete_on_Y1_continuous_level_1_v_0 | 2.406294 |
5 | T3_discrete_on_Y1_continuous_level_3_v_0 | 7.218881 |
6 | T1_continuous_on_Y2_binary | -0.023382 |
7 | T2_binary_on_Y2_binary | -0.032929 |
8 | T3_discrete_on_Y2_binary_level_4_v_0 | -0.683641 |
9 | T3_discrete_on_Y2_binary_level_2_v_0 | -0.571733 |
10 | T3_discrete_on_Y2_binary_level_1_v_0 | -0.376043 |
11 | T3_discrete_on_Y2_binary_level_3_v_0 | -0.650020 |
True Data Generating Process
# Show the stored data-generating process (formula, params, noise,
# raw_scores, function) for each simulated treatment and outcome.
for variable_name, spec in data_generator.dgp.items():
    print(f"DGP for {variable_name}:")
    print(spec)
DGP for T1_continuous:
{'formula': '1 + W1_continuous + W2_binary + W3_discrete + X1_continuous', 'params': array([ 0.06058085, 1.16021218, -1.96203513, 0.36371696, -2.17056973]), 'noise': array([-3.62397606, -0.4328875 , 0.44179631, ..., -4.88301257,
-1.79886609, 7.66507059]), 'raw_scores': array([-3.72676077, -3.43899217, -2.80623211, ..., -7.29150054,
-4.19616712, 5.56335924]), 'function': <function SyntheticDataGenerator._create_dgp_function.<locals>.f_cont at 0x7fed767b16c0>}
DGP for T2_binary:
{'formula': '1 + W1_continuous + W2_binary + W3_discrete + X1_continuous', 'params': array([ 0.40298227, 0.46731892, -0.04221048, -1.01745299, 0.75777365]), 'noise': array([ 0.61588584, 3.23845316, -4.20232967, ..., -5.30465767,
-4.77383245, -3.11883147]), 'raw_scores': array([0.62118342, 0.98880257, 0.01598516, ..., 0.01117315, 0.0234736 ,
0.02722537]), 'function': <function SyntheticDataGenerator._create_dgp_function.<locals>.f_binary at 0x7fed75d97130>}
DGP for T3_discrete:
{'formula': '1 + W1_continuous + W2_binary + W3_discrete + X1_continuous', 'params': array([[ 0.28408456, -0.41400534, 0.89869683, -0.21225787, -0.51305644],
[-0.19632837, 0.02928109, -0.76527952, -1.08498952, -0.10759088],
[ 0.47243684, 0.23156474, -0.05579596, 0.57863612, 0.2423976 ],
[-0.388172 , -0.10787099, -0.09372564, 0.11089256, -0.00356159],
[-0.57508702, 0.26987442, 0.26917321, 0.07404752, 0.54707381]]), 'noise': array([[-5.62802687, -1.71188738, -1.6500188 , -0.21256387, -2.52126791],
[-1.93444462, 1.17883094, -3.22882543, 5.68613081, -3.1945138 ],
[ 3.71485923, -2.99880648, 5.32356089, 1.85843325, 5.28230021],
...,
[-0.31659162, 7.97101022, -0.80878524, -2.60107951, -1.99808469],
[-2.55599124, 2.33060302, -1.21559451, 1.39676363, -1.08090999],
[-0.93239336, 1.08707538, 3.1621623 , -3.38126668, -1.0282605 ]]), 'raw_scores': array([[0.2, 0.2, 0.2, 0.2, 0.2],
[0.2, 0.2, 0.2, 0.2, 0.2],
[0.2, 0.2, 0.2, 0.2, 0.2],
...,
[0.2, 0.2, 0.2, 0.2, 0.2],
[0.2, 0.2, 0.2, 0.2, 0.2],
[0.2, 0.2, 0.2, 0.2, 0.2]]), 'function': <function SyntheticDataGenerator._create_dgp_function.<locals>.f_discrete at 0x7fed75d97640>}
DGP for Y1_continuous:
{'formula': '1 + W1_continuous + W2_binary + W3_discrete + X1_continuous + X2_binary + X3_discrete + T1_continuous + T2_binary + T3_discrete + T1_continuous*X1_continuous + T1_continuous*X2_binary + T1_continuous*X3_discrete + T2_binary*X1_continuous + T2_binary*X2_binary + T2_binary*X3_discrete + T3_discrete*X1_continuous + T3_discrete*X2_binary + T3_discrete*X3_discrete', 'params': array([-1.24245918, -2.15840213, 1.21168055, 1.88903987, -1.66394546,
-2.8238776 , 2.19040025, 1.56437294, -1.2158549 , 1.1796722 ,
1.85630525, 0.40119701, 4.06451821, 2.44944857, -0.75652854,
2.2728199 , 1.29513067, 1.243022 , -0.35985222]), 'noise': array([ 1.68788239, 4.68077263, -4.52386602, ..., 0.73772796,
5.84175262, -3.42458613]), 'raw_scores': array([ -20.92714575, -9.74770328, -17.37000412, ..., -105.68343735,
0.84757356, 113.30196785]), 'function': <function SyntheticDataGenerator._create_dgp_function.<locals>.f_cont at 0x7fed75db17e0>}
DGP for Y2_binary:
{'formula': '1 + W1_continuous + W2_binary + W3_discrete + X1_continuous + X2_binary + X3_discrete + T1_continuous + T2_binary + T3_discrete + T1_continuous*X1_continuous + T1_continuous*X2_binary + T1_continuous*X3_discrete + T2_binary*X1_continuous + T2_binary*X2_binary + T2_binary*X3_discrete + T3_discrete*X1_continuous + T3_discrete*X2_binary + T3_discrete*X3_discrete', 'params': array([ 3.10830726e-01, -8.76756219e-02, 5.06943774e-01, 3.63265570e-01,
8.10885598e-02, 1.61061427e-01, -2.64294557e-04, -2.93078750e-01,
-6.41686816e-01, -2.78041378e-01, -9.20270247e-02, -1.89629801e-01,
2.13406730e-01, 5.55264734e-01, -7.14195311e-01, 5.33709717e-01,
-3.91935251e-01, -6.16656576e-01, -6.84033929e-01]), 'noise': array([-2.0490385 , 0.51909257, 0.81681174, ..., 5.05902104,
-2.13548014, -1.22267512]), 'raw_scores': array([0.28372465, 0.05238531, 0.97803581, ..., 0.99 , 0.03363826,
0.01 ]), 'function': <function SyntheticDataGenerator._create_dgp_function.<locals>.f_binary at 0x7fed75db1870>}
We can also recreate the raw scores of the treatment and outcome variables by applying each DGP's `function` to its design matrix, `params`, and `noise`:
# Recreate Y1_continuous from its stored DGP components.
df = data_generator.df
dgp = data_generator.dgp["Y1_continuous"]

# Build the design matrix for the DGP's formula from the simulated data.
design_matrix = data_generator.create_design_matrix(df, formula=dgp["formula"])
params = dgp["params"]
noise = dgp["noise"]
f = dgp["function"]

raw_scores = f(design_matrix, params, noise)

# A continuous outcome is stored as its raw score, so both checks must hold.
assert np.allclose(raw_scores, df["Y1_continuous"])
assert np.allclose(raw_scores, dgp["raw_scores"])
raw_scores
0 -20.927146
1 -9.747703
2 -17.370004
3 -2.274883
4 -0.420176
...
9995 8.651162
9996 3.081509
9997 -105.683437
9998 0.847574
9999 113.301968
Length: 10000, dtype: float64
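For binary (and discrete) variables, whether treatments or outcomes, the raw scores are probabilities rather than the sampled values stored in the dataframe: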
# Recreate Y2_binary from its stored DGP components. Rebind df so this cell
# does not silently depend on the previous one.
df = data_generator.df
dgp_bin = data_generator.dgp["Y2_binary"]

design_matrix_bin = data_generator.create_design_matrix(df, formula=dgp_bin["formula"])
params_bin = dgp_bin["params"]
noise_bin = dgp_bin["noise"]
f_bin = dgp_bin["function"]

raw_scores_bin = f_bin(design_matrix_bin, params_bin, noise_bin)

# Only compare against the stored raw scores: df["Y2_binary"] holds sampled
# 0/1 labels, while these raw scores are probabilities.
assert np.allclose(raw_scores_bin, dgp_bin["raw_scores"])
assert np.all((raw_scores_bin >= 0) & (raw_scores_bin <= 1))
raw_scores_bin
0 0.283725
1 0.052385
2 0.978036
3 0.010000
4 0.115042
...
9995 0.609972
9996 0.010000
9997 0.990000
9998 0.033638
9999 0.010000
Length: 10000, dtype: float64