`FastOLS(Y, T, G=None, X=None, W=None, *, discrete_treatment=False, engine='cpu')`
FastOLS is an optimized implementation of the OLS estimator designed specifically with treatment effect estimation in mind.
FastOLS is experimental and may change significantly in future versions.
This class estimates a standard linear regression model for any number of continuous or binary outcomes and a single continuous or binary treatment, and provides estimates of Average Treatment Effects (ATEs) and Group Average Treatment Effects (GATEs) out of the box. Additional methods estimate custom GATEs and the Conditional Average Treatment Effects (CATEs) of individual observations; CATEs can also be used for out-of-sample predictions. Note that this method assumes linear treatment effects and linear heterogeneity, which is typically sufficient when the primary interest is in ATEs and GATEs.
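To make the linear-heterogeneity assumption concrete, here is a minimal NumPy sketch of the underlying idea (not caml's implementation): regress Y on the treatment, the effect modifiers, their treatment interactions, and the controls; read each CATE off the interaction coefficients; then average the CATEs over all rows for the ATE and within a group for a GATE. The simulated data and column layout are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2_000

# Simulated data (assumption): binary treatment T, a binary group G,
# a continuous modifier X, and a control W. True CATE = 2 + 1.5*G + 0.5*X.
G = rng.integers(0, 2, n)
X = rng.normal(size=n)
W = rng.normal(size=n)
T = rng.integers(0, 2, n)
Y = 1.0 + 0.3 * W + 0.2 * X + T * (2.0 + 1.5 * G + 0.5 * X) + rng.normal(scale=0.1, size=n)

# Design columns: intercept, T, modifiers (G, X), T x modifiers, controls (W).
M = np.column_stack([G, X])
D = np.column_stack([np.ones(n), T, M, T[:, None] * M, W])
beta, *_ = np.linalg.lstsq(D, Y, rcond=None)

# Linear heterogeneity: CATE_i = beta_T + beta_{T x M} . M_i
b_T, b_TM = beta[1], beta[4:6]
cate = b_T + M @ b_TM

ate = cate.mean()                                  # ATE: average over all rows
gates = {g: cate[G == g].mean() for g in (0, 1)}   # GATE: average within a group
print(f"ATE = {ate:.2f}", {g: round(float(v), 2) for g, v in gates.items()})
```

Because the model is linear in the interactions, the ATE and every GATE are just averages of the same per-row CATEs over different sets of rows.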
This class uses JAX for fast numerical computation when it is available (install it with `pip install caml[jax]`) and falls back to NumPy otherwise. For GPU acceleration, install JAX with GPU support using `pip install caml[jax-gpu]`.
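A minimal sketch of the JAX-or-NumPy fallback described above. `get_array_module` is a hypothetical helper, not caml's code, and it does not check whether a GPU is actually visible to JAX.

```python
def get_array_module(engine: str = "cpu"):
    """Hypothetical helper: return jax.numpy if installed, else fall back to NumPy."""
    if engine not in ("cpu", "gpu"):
        raise ValueError(f"engine must be 'cpu' or 'gpu', got {engine!r}")
    try:
        import jax.numpy as xp  # JAX is installed: use it for either engine
    except ImportError:
        if engine == "gpu":
            # GPU execution needs JAX; there is no NumPy fallback for it.
            raise ImportError("engine='gpu' requires JAX: pip install caml[jax-gpu]")
        import numpy as xp  # CPU-only fallback
    return xp


xp = get_array_module("cpu")
print(xp.__name__, xp.linalg.solve(xp.eye(2), xp.ones(2)))
```

Because `jax.numpy` mirrors most of the NumPy API, code written against `xp` runs unchanged on either backend.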
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `Y` | `Collection[str]` | A list of outcome variable names. | required |
| `T` | `str` | The treatment variable name. | required |
| `G` | `Collection[str] \| None` | A list of group variable names. These are the groups for which GATEs will be estimated. | `None` |
| `X` | `Collection[str] \| None` | A list of covariate variable names. These are the covariates for which heterogeneity/CATEs can be estimated. | `None` |
| `W` | `Collection[str] \| None` | A list of additional covariate variable names used as controls, i.e., covariates not used for modeling heterogeneity/CATEs. | `None` |
| `discrete_treatment` | `bool` | Whether the treatment is discrete (binary). | `False` |
| `engine` | `str` | The engine to use for computation: `"cpu"` or `"gpu"`. Note that `"gpu"` requires JAX, which can be installed via `pip install caml[jax-gpu]`. | `'cpu'` |
Attributes

| Name | Type | Description |
|---|---|---|
| `Y` | `Collection[str]` | A list of outcome variable names. |
| `T` | `str` | The treatment variable name. |
| `G` | `Collection[str] \| None` | The list of group variable names. These are the groups for which GATEs will be estimated. |
| `X` | `Collection[str] \| None` | The list of confounder/control variable names used for estimating heterogeneity/CATEs, in addition to `G`. |
| `W` | `Collection[str] \| None` | The list of confounder/control variable names not used for estimating heterogeneity/CATEs. |
| `discrete_treatment` | `bool` | Whether the treatment is discrete (binary). |
| `engine` | `str` | The engine to use for computation: `"cpu"` or `"gpu"`. Note that `"gpu"` requires JAX, which can be installed via `pip install caml[jax-gpu]`. |
| `formula` | `str` | The formula used to create the design matrix via Patsy. |
| `results` | `dict` | A dictionary containing the results of the fitted model and the estimated ATEs/GATEs. |
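The `formula` attribute drives Patsy's design-matrix construction. One plausible way to assemble such a right-hand side from `T`, `G`, `X`, and `W` is sketched below. `build_rhs` and the column names are hypothetical, and the formula caml actually builds may differ, so inspect `fo_obj.formula` for the real one.

```python
from collections.abc import Collection
from typing import Optional


def build_rhs(
    T: str,
    G: Optional[Collection[str]] = None,
    X: Optional[Collection[str]] = None,
    W: Optional[Collection[str]] = None,
    discrete_treatment: bool = False,
) -> str:
    """Hypothetical: right-hand side of a Patsy-style formula."""
    treat = f"C({T})" if discrete_treatment else T
    modifiers = list(G or []) + list(X or [])  # heterogeneity enters via G and X
    rhs = f"{treat} * ({' + '.join(modifiers)})" if modifiers else treat
    if W:
        rhs += " + " + " + ".join(W)  # controls enter additively only
    return rhs


print(build_rhs("T", G=["G1"], X=["X1"], W=["W1"]))
```

In Patsy, `a * (b + c)` expands to `a + b + c + a:b + a:c`, so the treatment main effect, the modifier main effects, and the treatment-by-modifier interactions that carry heterogeneity all enter the model, while the controls in `W` enter only as main effects.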
Examples
```python
from caml import FastOLS
from caml.extensions.synthetic_data import SyntheticDataGenerator

data_generator = SyntheticDataGenerator(
    n_cont_outcomes=1,
    n_binary_outcomes=1,
    n_cont_modifiers=1,
    n_binary_modifiers=2,
    seed=10,
)
df = data_generator.df

fo_obj = FastOLS(
    Y=[c for c in df.columns if "Y" in c],
    T="T1_binary",
    G=[c for c in df.columns if "X" in c and ("bin" in c or "dis" in c)],
    X=[c for c in df.columns if "X" in c and "cont" in c],
    W=[c for c in df.columns if "W" in c],
    engine="cpu",
    discrete_treatment=True,
)
print(fo_obj)
```
Fits the regression model on the provided data and, optionally, estimates Average Treatment Effect(s) (ATE) and Group Average Treatment Effect(s) (GATE).
If `estimate_effects` is True, the method estimates Average Treatment Effects (ATEs) and Group Average Treatment Effects (GATEs) for the groups specified in `G`. It uses the `estimate_ate` method under the hood, but reuses the data efficiently and parallelizes the computation of GATEs.
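A minimal sketch of the "compute once, reuse per group, run groups in parallel" pattern, assuming per-row effects (here, CATEs) are computed once after fitting so that each GATE point estimate is just a masked average. `group_mean_effects` is hypothetical and uses threads only for illustration; how caml actually parallelizes, and what `n_jobs` maps to internally, is not shown in this document.

```python
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def group_mean_effects(cate, groups, n_jobs=-1):
    """Hypothetical: average precomputed per-row CATEs within each level of each group column."""
    workers = os.cpu_count() if n_jobs == -1 else n_jobs  # -1 means "all processors"
    tasks = [(name, level) for name, col in groups.items() for level in np.unique(col)]

    def one(task):
        name, level = task
        mask = groups[name] == level  # reuse the same cate array; only the mask changes
        return name, level.item(), float(cate[mask].mean()), int(mask.sum())

    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(one, tasks))  # results come back in task order


cate = np.array([1.0, 2.0, 3.0, 4.0])
print(group_mean_effects(cate, {"g": np.array([0, 0, 1, 1])}, n_jobs=2))
```

This only yields GATE point estimates; standard errors additionally need the coefficient covariance matrix.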
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `df` | `PandasConvertibleDataFrame` | Input dataframe to fit the model on. Supported formats: pandas DataFrame, PySpark DataFrame, Polars DataFrame, or any object with a `toPandas()` or `to_pandas()` method. | required |
| `n_jobs` | `int` | The number of jobs to use for parallel processing when estimating GATEs. The default, -1, uses all available processors. If you run into out-of-memory (OOM) errors, try a lower value. | `-1` |
| `estimate_effects` | `bool` | Whether to estimate Average Treatment Effects (ATEs) and Group Average Treatment Effects (GATEs). | `True` |
| `robust_vcv` | `bool` | Whether to use a heteroskedasticity-robust (White) variance-covariance matrix and standard errors. | |
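For reference, the classical and heteroskedasticity-robust (White) covariance estimators differ only in the middle term of the "sandwich". A minimal NumPy sketch, assuming the HC0 form of White's estimator; caml may apply a small-sample correction such as HC1, which this document does not specify. `ols_with_vcv` is a hypothetical helper.

```python
import numpy as np


def ols_with_vcv(X, y, robust=False):
    """Hypothetical: OLS coefficients plus classical or White (HC0) covariance."""
    XtX_inv = np.linalg.inv(X.T @ X)  # fine for a sketch; solve/pinv is more stable
    beta = XtX_inv @ X.T @ y
    resid = y - X @ beta
    n, k = X.shape
    if robust:
        # White / HC0: (X'X)^-1 X' diag(e^2) X (X'X)^-1
        meat = X.T @ (X * resid[:, None] ** 2)
        vcv = XtX_inv @ meat @ XtX_inv
    else:
        # Classical: sigma^2 (X'X)^-1 with sigma^2 = e'e / (n - k)
        sigma2 = resid @ resid / (n - k)
        vcv = sigma2 * XtX_inv
    return beta, vcv


X1 = np.ones((4, 1))
y1 = np.array([1.0, 2.0, 3.0, 4.0])
print(ols_with_vcv(X1, y1)[1], ols_with_vcv(X1, y1, robust=True)[1])
```

The robust version stays valid when the residual variance differs across observations, at the cost of slightly noisier standard errors in small samples.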
Estimate the Average Treatment Effects (ATEs) of `T` on each `Y` from the fitted model.

If the entire dataframe is provided, the function estimates the ATE of the entire population, where, for a binary treatment, the ATE is formally defined as:

\[
\tau = \mathbb{E}_n[\mathbf{Y}_1 - \mathbf{Y}_0]
\]

If a subset of the dataframe is provided, the function estimates the ATE of that subset (e.g., a GATE), where, for a binary treatment, the GATE is formally defined as:

\[
\tau = \mathbb{E}_n[\mathbf{Y}_1 - \mathbf{Y}_0 \mid \mathbf{G} = G]
\]
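Under a linear model, both definitions reduce to a linear contrast of the OLS coefficients: build each row's design with the treatment set to 1 and to 0, average the difference, and multiply by the coefficient vector; the standard error then follows from the coefficient covariance. A minimal NumPy sketch under that assumption, with hypothetical names (`ate_from_contrast`, `D1`, `D0`); passing only a subset of rows gives the corresponding GATE.

```python
import numpy as np


def ate_from_contrast(D1, D0, beta, vcv):
    """Hypothetical: ATE and SE as a linear contrast of OLS coefficients.

    D1 / D0: design matrices for the same rows with treatment set to 1 / 0.
    """
    d = (D1 - D0).mean(axis=0)  # average per-row difference in design rows
    ate = d @ beta              # E_n[Y_1 - Y_0] under the linear model
    se = np.sqrt(d @ vcv @ d)   # Var(d' beta) = d' V d
    return float(ate), float(se)


# Design columns (assumption): [1, T, x, T*x]
x = np.array([0.0, 1.0, 2.0, 3.0])
ones, zeros = np.ones_like(x), np.zeros_like(x)
D1 = np.column_stack([ones, ones, x, x])
D0 = np.column_stack([ones, zeros, x, zeros])
beta = np.array([0.5, 2.0, 0.1, 1.0])
vcv = np.eye(4)

print(ate_from_contrast(D1, D0, beta, vcv))          # whole sample: ATE
print(ate_from_contrast(D1[2:], D0[2:], beta, vcv))  # rows with x >= 2: a GATE
```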
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `df` | `PandasConvertibleDataFrame` | Dataframe containing the data to estimate the ATEs. Supported formats: pandas DataFrame, PySpark DataFrame, Polars DataFrame, or any object with a `toPandas()` or `to_pandas()` method. | required |
| `return_results_dict` | `bool` | If True, the function returns a dictionary containing the ATEs/GATEs, standard errors, t-statistics, confidence intervals, and p-values. If False, it returns an array containing the ATEs/GATEs alone. | `False` |
| `group` | `str` | Name of the group to estimate the ATEs for. | `'Custom Group'` |
| `membership` | `str \| None` | Name of the membership variable to estimate the ATEs for. | `None` |
| `_diff_matrix` | `jnp.ndarray \| None` | Private argument used by the `fit` method. | `None` |
Returns

| Type | Description |
|---|---|
| `jnp.ndarray \| dict` | Estimated ATEs/GATEs, or a dictionary containing the estimated ATEs/GATEs and their standard errors, t-statistics, confidence intervals, and p-values. |
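With `return_results_dict=True`, the output also carries t-statistics, confidence intervals, and p-values. A minimal sketch of deriving those from an estimate and its standard error, assuming a two-sided normal approximation and a 95% interval; caml may use a different reference distribution, and `summarize_effect` and its keys are hypothetical. As a rough check, plugging in the `ate` and `std_err` of the last row of the `prettify_treatment_effects` example output (0.083364 and 0.028048) gives a t-statistic and p-value close to the ones shown there, which suggests, but does not confirm, a normal approximation.

```python
import math
from statistics import NormalDist


def summarize_effect(ate, std_err, alpha=0.05):
    """Hypothetical summary of one effect estimate (normal approximation)."""
    t_stat = ate / std_err
    pval = math.erfc(abs(t_stat) / math.sqrt(2))  # two-sided p-value, 2 * (1 - Phi(|t|))
    z = NormalDist().inv_cdf(1 - alpha / 2)       # about 1.96 for alpha = 0.05
    return {
        "ate": ate,
        "std_err": std_err,
        "t_stat": t_stat,
        "pval": pval,
        "ci_lower": ate - z * std_err,
        "ci_upper": ate + z * std_err,
    }


print(summarize_effect(0.083364, 0.028048))
```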
Examples
```python
ate = fo_obj.estimate_ate(df, return_results_dict=True, group="Overall")
ate
```
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `df` | `PandasConvertibleDataFrame` | Dataframe containing the data to estimate CATEs for. Supported formats: pandas DataFrame, PySpark DataFrame, Polars DataFrame, or any object with a `toPandas()` or `to_pandas()` method. | required |
| `return_results_dict` | `bool` | If True, the function returns a dictionary containing the CATEs, standard errors, t-statistics, confidence intervals, and p-values. If False, it returns an array containing the CATEs alone. | `False` |

Returns

| Type | Description |
|---|---|
| `jnp.ndarray \| dict` | CATEs, or a dictionary containing the CATEs, standard errors, t-statistics, and p-values. |
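Under the same linear-model assumption, a row's CATE is its own design difference (treatment set to 1 minus treatment set to 0) times the coefficients, and its standard error is the square root of a row-wise quadratic form in the coefficient covariance. Because this needs only the fitted coefficients and their covariance, it applies equally to rows not used for fitting, which is what makes out-of-sample prediction possible. A minimal NumPy sketch with hypothetical names:

```python
import numpy as np


def cate_with_se(Dd, beta, vcv):
    """Hypothetical: per-row CATEs and SEs.

    Dd: one row per observation, holding design(T=1) - design(T=0).
    """
    cate = Dd @ beta
    # Row-wise d_i' V d_i without materializing an n x n matrix.
    se = np.sqrt(np.einsum("ij,jk,ik->i", Dd, vcv, Dd))
    return cate, se


# Design columns (assumption): [1, T, x, T*x], so each row's difference is [0, 1, 0, x_i].
Dd = np.array([[0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 2.0]])
beta = np.array([0.5, 2.0, 0.1, 1.0])
print(cate_with_se(Dd, beta, np.eye(4)))
```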
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `df` | `PandasConvertibleDataFrame` | Dataframe containing the data to estimate CATEs for. Supported formats: pandas DataFrame, PySpark DataFrame, Polars DataFrame, or any object with a `toPandas()` or `to_pandas()` method. | required |
| `return_results_dict` | `bool` | If True, the function returns a dictionary containing the CATEs, standard errors, t-statistics, confidence intervals, and p-values. If False, it returns an array containing the CATEs alone. | `False` |

Returns

| Type | Description |
|---|---|
| `jnp.ndarray \| dict` | CATEs, or a dictionary containing the CATEs, standard errors, t-statistics, and p-values. |
Convert a treatment effects dictionary to a pandas DataFrame.
If no argument is provided, the results are constructed from the internal results dictionary, which is the typical use case. For custom treatment effects, pass the results returned by the `estimate_ate` method with `return_results_dict=True`.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `effects` | `dict` | Dictionary of treatment effects. If None, the results are constructed from the internal results dictionary. | `None` |

Returns

| Type | Description |
|---|---|
| `pd.DataFrame` | DataFrame of treatment effects. |
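As a rough illustration of the conversion, here is a sketch that flattens a nested effects dictionary into one record per group, membership, and outcome, mirroring the `group`, `membership`, and `outcome` columns of the example output. The nested layout `{group: {membership: {outcome: stats}}}` and the helper `flatten_effects` are assumptions; the actual structure of caml's results dictionary is not documented here.

```python
def flatten_effects(effects):
    """Hypothetical: flatten {group: {membership: {outcome: stats}}} into row dicts."""
    rows = []
    for group, by_membership in effects.items():
        for membership, by_outcome in by_membership.items():
            for outcome, stats in by_outcome.items():
                rows.append({"group": group, "membership": membership, "outcome": outcome, **stats})
    return rows


effects = {
    "overall": {None: {"Y1": {"ate": 1.0, "std_err": 0.1}}},
    "G1": {0: {"Y1": {"ate": 0.8, "std_err": 0.2}}, 1: {"Y1": {"ate": 1.2, "std_err": 0.2}}},
}
for row in flatten_effects(effects):
    print(row)
```

Passing the resulting list of dicts to `pandas.DataFrame` would give one table row per record.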
Examples
```python
fo_obj.prettify_treatment_effects()
```

|   | group | membership | outcome | ate | std_err | t_stat | pval | n | n_treated | n_control |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | overall | None | Y1_continuous | 3.759449 | 0.020121 | 186.844025 | 0.000000 | 10000 | 5060 | 4940 |
| 1 | overall | None | Y2_binary | 0.193491 | 0.009711 | 19.924061 | 0.000000 | 10000 | 5060 | 4940 |
| 2 | X2_binary | 0 | Y1_continuous | 3.584580 | 0.034896 | 102.722456 | 0.000000 | 3320 | 1691 | 1629 |
| 3 | X2_binary | 0 | Y2_binary | 0.156381 | 0.017070 | 9.161124 | 0.000000 | 3320 | 1691 | 1629 |
| 4 | X2_binary | 1 | Y1_continuous | 3.846360 | 0.024623 | 156.212722 | 0.000000 | 6680 | 3369 | 3311 |
| 5 | X2_binary | 1 | Y2_binary | 0.211934 | 0.011806 | 17.950933 | 0.000000 | 6680 | 3369 | 3311 |
| 6 | X3_binary | 1 | Y1_continuous | 4.081423 | 0.021434 | 190.419041 | 0.000000 | 8801 | 4451 | 4350 |
| 7 | X3_binary | 1 | Y2_binary | 0.208494 | 0.010351 | 20.142500 | 0.000000 | 8801 | 4451 | 4350 |
| 8 | X3_binary | 0 | Y1_continuous | 1.396072 | 0.058353 | 23.924627 | 0.000000 | 1199 | 609 | 590 |
| 9 | X3_binary | 0 | Y2_binary | 0.083364 | 0.028048 | 2.972195 | 0.002957 | 1199 | 609 | 590 |
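```python
# Using a custom GATE
# Example subset (assumes df is a pandas DataFrame); any row filter of df works here.
df_filtered = df[df["X3_binary"] == 0]
custom_gate = fo_obj.estimate_ate(df_filtered, return_results_dict=True, group="My Custom Group")
fo_obj.prettify_treatment_effects(custom_gate)
```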