FastOLS

FastOLS(
    self,
    Y,
    T,
    G=None,
    X=None,
    W=None,
    *,
    discrete_treatment=False,
    engine='cpu',
)

FastOLS is an optimized implementation of the OLS estimator designed specifically with treatment effect estimation in mind.

FastOLS is experimental and may change significantly in future versions.

This class estimates a standard linear regression model for any number of continuous or binary outcomes and a single continuous or binary treatment, and provides estimates of the Average Treatment Effects (ATEs) and Group Average Treatment Effects (GATEs) out of the box. Additionally, methods are provided for estimating custom GATEs and the Conditional Average Treatment Effects (CATEs) of individual observations, which can also be used for out-of-sample predictions. Note that this method assumes linear treatment effects and linear heterogeneity, which is typically sufficient when the primary interest is in ATEs and GATEs.
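As a rough illustration of what "linear treatment effects and heterogeneity" means, the sketch below fits a toy model of the form Y ~ T * X with plain NumPy least squares. It is not FastOLS's implementation, and the data and variable names are made up. It only shows that, under this kind of specification, each observation's CATE is beta_T + beta_TX * x and the ATE is the average of those CATEs.

```python
import numpy as np

# Hypothetical toy data: binary treatment t and a covariate x that
# linearly modifies the treatment effect (true CATE = 2.0 + 1.5 * x).
rng = np.random.default_rng(0)
n = 5_000
x = rng.normal(size=n)
t = rng.integers(0, 2, size=n).astype(float)
y = 1.0 + 0.5 * x + t * (2.0 + 1.5 * x) + rng.normal(scale=0.1, size=n)

# Design matrix for y ~ t * x: intercept, t, x, and the t:x interaction.
design = np.column_stack([np.ones(n), t, x, t * x])
beta, *_ = np.linalg.lstsq(design, y, rcond=None)

# With linear effects and heterogeneity, each observation's CATE is
# beta_t + beta_tx * x_i, and the ATE is the sample mean of those CATEs.
cate = beta[1] + beta[3] * x
ate = cate.mean()
print("coefficients:", np.round(beta, 2))
print("ATE:", round(float(ate), 3))
```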

This class leverages JAX for fast numerical computations and falls back to NumPy if JAX is not available. JAX can be installed using pip install caml[jax]. For GPU acceleration, install JAX with GPU support using pip install caml[jax-gpu].
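The NumPy fallback can be pictured with the common import pattern below. This is a generic sketch, not caml's actual backend-selection code; xp, BACKEND, and gram_matrix are illustrative names, and the sketch assumes only that jax.numpy and numpy expose compatible asarray, transpose, and matrix-multiplication operations.

```python
# Prefer jax.numpy when JAX is installed; otherwise fall back to NumPy,
# which exposes a largely compatible array API.
try:
    import jax.numpy as xp
    BACKEND = "jax"
except ImportError:
    import numpy as xp
    BACKEND = "numpy"


def gram_matrix(design):
    """Compute X'X with whichever backend was imported."""
    design = xp.asarray(design)
    return design.T @ design


print(BACKEND, gram_matrix([[1.0, 2.0], [3.0, 4.0]]))
```

Code written against xp runs unchanged on either backend, which is what makes a silent fallback practical.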

For outcome/treatment support, see Support Matrix.

For model specification details, see Model Specifications.

For a more detailed working example, see FastOLS Example.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| Y | Collection[str] | A list of outcome variable names. | required |
| T | str | The treatment variable name. | required |
| G | Collection[str] \| None | A list of group variable names. These are the groups for which GATEs will be estimated. | None |
| X | Collection[str] \| None | A list of covariate variable names. These are the covariates for which heterogeneity/CATEs can be estimated. | None |
| W | Collection[str] \| None | A list of additional covariate variable names to be used as controls. These covariates are not used for modeling heterogeneity/CATEs. | None |
| discrete_treatment | bool | Whether the treatment is discrete. | False |
| engine | str | The engine to use for computation: "cpu" or "gpu". Note that "gpu" requires JAX, which can be installed via pip install caml[jax-gpu]. | 'cpu' |

Attributes

| Name | Type | Description |
|---|---|---|
| Y | Collection[str] | A list of outcome variable names. |
| T | str | The treatment variable name. |
| G | Collection[str] \| None | The list of group variable names. These are the groups for which GATEs will be estimated. |
| X | Collection[str] \| None | The list of variable names representing the confounder/control feature set used for estimating heterogeneity/CATEs, in addition to G. |
| W | Collection[str] \| None | The list of variable names representing the confounder/control feature set not used for estimating heterogeneity/CATEs. |
| discrete_treatment | bool | Whether the treatment is discrete. |
| engine | str | The engine to use for computation: "cpu" or "gpu". Note that "gpu" requires JAX, which can be installed via pip install caml[jax-gpu]. |
| formula | str | The formula used to create the design matrix via Patsy. |
| results | dict | A dictionary containing the results of the fitted model and the estimated ATEs/GATEs. |

Examples

from caml import FastOLS
from caml.extensions.synthetic_data import SyntheticDataGenerator

data_generator = SyntheticDataGenerator(
    n_cont_outcomes=1,
    n_binary_outcomes=1,
    n_cont_modifiers=1,
    n_binary_modifiers=2,
    seed=10,
)
df = data_generator.df

fo_obj = FastOLS(
    Y=[c for c in df.columns if "Y" in c],
    T="T1_binary",
    G=[c for c in df.columns if "X" in c and ("bin" in c or "dis" in c)],
    X=[c for c in df.columns if "X" in c and "cont" in c],
    W=[c for c in df.columns if "W" in c],
    engine="cpu",
    discrete_treatment=True,
)

print(fo_obj)
================== FastOLS Object ==================
Engine: cpu
Outcome Variable: ['Y1_continuous', 'Y2_binary']
Treatment Variable: T1_binary
Discrete Treatment: True
Group Variables: ['X2_binary', 'X3_binary']
Features/Confounders for Heterogeneity (X): ['X1_continuous']
Features/Confounders as Controls (W): []
Formula: Y1_continuous + Y2_binary ~ C(T1_binary) + C(X2_binary)*C(T1_binary) + C(X3_binary)*C(T1_binary) + X1_continuous*C(T1_binary)

Methods

| Name | Description |
|---|---|
| fit | Fits the regression model on the provided data and, optionally, estimates Average Treatment Effect(s) (ATE) and Group Average Treatment Effect(s) (GATE). |
| estimate_ate | Estimate Average Treatment Effects (ATEs) of T on each Y from fitted model. |
| estimate_cate | Estimate Conditional Average Treatment Effects (CATEs) for all given observations in the dataset. |
| predict | Alias for estimate_cate. |
| prettify_treatment_effects | Convert treatment effects dictionary to a pandas DataFrame. |

fit

FastOLS.fit(df, *, n_jobs=-1, estimate_effects=True, robust_vcv=False)

Fits the regression model on the provided data and, optionally, estimates Average Treatment Effect(s) (ATE) and Group Average Treatment Effect(s) (GATE).

If estimate_effects is True, the method estimates Average Treatment Effects (ATEs) and Group Average Treatment Effects (GATEs) for the groups specified in G. This uses the estimate_ate method under the hood, but efficiently reuses the data and parallelizes the computation of GATEs.
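One way to picture the parallel GATE step is sketched below. It assumes per-observation CATEs are already available (in a linear model, a GATE equals the average CATE within the group), uses a standard-library thread pool, and maps n_jobs=-1 to os.cpu_count() by analogy with the n_jobs parameter. estimate_gates is a hypothetical helper, not part of caml.

```python
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def estimate_gates(cate, groups, n_jobs=-1):
    """Average per-observation CATEs within every level of every group variable.

    cate:   array of shape (n_obs, n_outcomes)
    groups: dict mapping a group-variable name to an array of shape (n_obs,)
    """
    # Mirror the n_jobs convention: -1 means "use all available processors".
    max_workers = os.cpu_count() if n_jobs == -1 else n_jobs
    tasks = [(name, level) for name, values in groups.items() for level in np.unique(values)]

    def gate(task):
        name, level = task
        mask = groups[name] == level
        return name, level.item(), cate[mask].mean(axis=0), int(mask.sum())

    # One task per (group variable, level) pair, i.e. one GATE per task.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(gate, tasks))


rng = np.random.default_rng(1)
cate = rng.normal(size=(100, 2))
groups = {"X2_binary": rng.integers(0, 2, size=100), "X3_binary": rng.integers(0, 2, size=100)}
for name, level, effect, n in estimate_gates(cate, groups, n_jobs=2):
    print(name, level, np.round(effect, 3), n)
```

Each task only reads shared arrays, so the work parallelizes without copying the data per group.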

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| df | PandasConvertibleDataFrame | Input dataframe to fit the model on. Supported formats: pandas DataFrame, PySpark DataFrame, Polars DataFrame, or any object with a toPandas() or to_pandas() method. | required |
| n_jobs | int | The number of jobs to use for parallel processing in the estimation of GATEs. Defaults to -1, which uses all available processors. If you encounter out-of-memory (OOM) errors, try setting n_jobs to a lower value. | -1 |
| estimate_effects | bool | Whether to estimate Average Treatment Effects (ATEs) and Group Average Treatment Effects (GATEs). | True |
| robust_vcv | bool | Whether to use the heteroskedasticity-robust (White) variance-covariance matrix and standard errors. | False |
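For reference, the sketch below contrasts the classical OLS variance-covariance matrix with White's heteroskedasticity-robust (HC0) sandwich estimator in NumPy. Whether robust_vcv=True uses HC0 or a small-sample variant such as HC1 is an assumption here, and ols_vcv is a hypothetical helper.

```python
import numpy as np


def ols_vcv(design, y, robust=False):
    """Return OLS coefficients and their variance-covariance matrix.

    robust=False: classical vcv, sigma^2 * (X'X)^-1.
    robust=True:  White (HC0) sandwich, (X'X)^-1 X' diag(e^2) X (X'X)^-1.
    """
    n, k = design.shape
    xtx_inv = np.linalg.inv(design.T @ design)
    beta = xtx_inv @ design.T @ y
    resid = y - design @ beta
    if robust:
        # "Meat" of the sandwich: sum_i e_i^2 * x_i x_i'.
        meat = design.T @ (design * resid[:, None] ** 2)
        vcv = xtx_inv @ meat @ xtx_inv
    else:
        sigma2 = resid @ resid / (n - k)
        vcv = sigma2 * xtx_inv
    return beta, vcv


rng = np.random.default_rng(2)
n = 2_000
x = rng.normal(size=n)
# Noise variance grows with |x|, i.e. the errors are heteroskedastic.
y = 1.0 + 2.0 * x + rng.normal(scale=0.5 + np.abs(x), size=n)
design = np.column_stack([np.ones(n), x])

for robust in (False, True):
    beta, vcv = ols_vcv(design, y, robust=robust)
    print("robust" if robust else "classical", np.sqrt(np.diag(vcv)))
```

With error variance rising in |x|, the classical formula understates the slope's uncertainty, while the sandwich estimator does not rely on constant error variance.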

Examples

fo_obj.fit(df, n_jobs=4, estimate_effects=True, robust_vcv=True)

fo_obj.results.keys()
dict_keys(['params', 'vcv', 'std_err', 'treatment_effects'])

estimate_ate

FastOLS.estimate_ate(
    df,
    *,
    return_results_dict=False,
    group='Custom Group',
    membership=None,
    _diff_matrix=None,
)

Estimate Average Treatment Effects (ATEs) of T on each Y from fitted model.

If the entire dataframe is provided, the function will estimate the ATE of the entire population, where the ATE, in the case of binary treatments, is formally defined as: \[ \tau = \mathbb{E}_n[\mathbf{Y}_1 - \mathbf{Y}_0] \]

If a subset of the dataframe is provided, the function will estimate the ATE of the subset (e.g., GATEs), where the GATE, in the case of binary treatments, is formally defined as: \[ \tau = \mathbb{E}_n[\mathbf{Y}_1 - \mathbf{Y}_0|\mathbf{G}=G] \]
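In a linear model, both quantities above can be computed from a single averaged "difference" row: each observation's design row with T set to 1 minus the same row with T set to 0. The sketch below assumes a binary treatment and a toy Y ~ T * X specification. The estimate is g @ beta and its standard error follows from the delta method as sqrt(g' V g). The private _diff_matrix argument suggests that fit reuses something similar, but that is an assumption; design_matrix and ate_from_fit are hypothetical helpers.

```python
import numpy as np


def design_matrix(t, x):
    """Rows of [1, T, X, T*X] for the specification Y ~ T * X."""
    return np.column_stack([np.ones_like(x), t, x, t * x])


def ate_from_fit(beta, vcv, x):
    """ATE (or a GATE, if x is a subset) and its delta-method standard error."""
    ones, zeros = np.ones_like(x), np.zeros_like(x)
    # Per-row difference between the design under T=1 and under T=0.
    diff = design_matrix(ones, x) - design_matrix(zeros, x)
    g = diff.mean(axis=0)  # g @ beta equals E_n[Y_1 - Y_0] over these rows
    return g @ beta, np.sqrt(g @ vcv @ g)


rng = np.random.default_rng(3)
n = 4_000
x = rng.normal(size=n)
t = rng.integers(0, 2, size=n).astype(float)
y = 1.0 + 0.5 * x + t * (2.0 + 1.5 * x) + rng.normal(scale=0.5, size=n)

# Classical OLS fit of the toy model.
design = design_matrix(t, x)
xtx_inv = np.linalg.inv(design.T @ design)
beta = xtx_inv @ design.T @ y
resid = y - design @ beta
vcv = resid @ resid / (n - design.shape[1]) * xtx_inv

print("ATE:", ate_from_fit(beta, vcv, x))
print("GATE (x > 0):", ate_from_fit(beta, vcv, x[x > 0]))
```

Passing a subset of rows is all that distinguishes a GATE from the overall ATE, which mirrors how estimate_ate treats a filtered dataframe.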

For more details on treatment effect estimation, see Model Specifications.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| df | PandasConvertibleDataFrame | Dataframe containing the data to estimate the ATEs on. Supported formats: pandas DataFrame, PySpark DataFrame, Polars DataFrame, or any object with a toPandas() or to_pandas() method. | required |
| return_results_dict | bool | If True, the function returns a dictionary containing the ATEs/GATEs, standard errors, t-statistics, p-values, and sample sizes. If False, the function returns an array containing the ATEs/GATEs alone. | False |
| group | str | Name of the group to estimate the ATEs for. This name is used as the top-level key of the returned results dictionary. | 'Custom Group' |
| membership | str \| None | Name of the membership variable to estimate the ATEs for. | None |
| _diff_matrix | jnp.ndarray \| None | Private argument used by the fit method. | None |

Returns

| Type | Description |
|---|---|
| jnp.ndarray \| dict | Estimated ATEs/GATEs, or a dictionary containing the estimated ATEs/GATEs and their standard errors, t-statistics, and p-values. |
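The t-statistics and p-values in the returned dictionary follow from each estimate and its standard error. The sketch below assumes a two-sided test against a standard normal reference distribution. This is consistent with the prettify_treatment_effects example output (t = 2.972195 with p = 0.002957, and p = 0 for very large t), but the exact distribution caml uses is an assumption. t_stat_and_pval is a hypothetical helper.

```python
import math


def t_stat_and_pval(estimate, std_err):
    """t-statistic and two-sided p-value under a standard normal approximation."""
    t_stat = estimate / std_err
    # Two-sided tail probability: P(|Z| > |t|) = erfc(|t| / sqrt(2)).
    pval = math.erfc(abs(t_stat) / math.sqrt(2))
    return t_stat, pval


# Rounded values from the X3_binary == 0, Y2_binary row of the example output.
t, p = t_stat_and_pval(0.083364, 0.028048)
print(round(t, 3), round(p, 4))
```

For t-statistics in the hundreds, erfc underflows to exactly 0.0 in double precision, which is why the overall ATEs above report a p-value of 0.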

Examples

ate = fo_obj.estimate_ate(df, return_results_dict=True, group="Overall")

ate
{'Overall': {'outcome': ['Y1_continuous', 'Y2_binary'],
  'ate': array([[3.75944943, 0.19349086]]),
  'std_err': array([[0.02012079, 0.00971142]]),
  't_stat': array([[186.84402522,  19.92406093]]),
  'pval': array([[0., 0.]]),
  'n': 10000,
  'n_treated': 5060,
  'n_control': 4940}}
df_filtered = df.query(
    "X3_binary == 0 & X1_continuous < 5"
).copy()

custom_gate = fo_obj.estimate_ate(df_filtered)

custom_gate
array([[1.39607152, 0.08336419]])

estimate_cate

FastOLS.estimate_cate(df, *, return_results_dict=False)

Estimate Conditional Average Treatment Effects (CATEs) for all given observations in the dataset.

The CATE, in the case of binary treatments, is formally defined as: \[ \tau = \mathbb{E}_n[\mathbf{Y}_1 - \mathbf{Y}_0|\mathbf{Q}=Q] \] where \(\mathbf{Q}\) denotes the features used for modeling heterogeneity (G and X).
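Because the model is linear, each observation's CATE is d_i @ beta, where d_i is that observation's design row with T set to 1 minus the row with T set to 0, and its standard error is sqrt(d_i' V d_i). The sketch below computes both for all rows at once with np.einsum. It handles a single outcome, uses made-up toy inputs, and is not caml's implementation; cates_with_se is a hypothetical helper.

```python
import numpy as np


def cates_with_se(diff, beta, vcv):
    """CATEs and standard errors for one outcome in a linear model.

    diff: (n_obs, k) rows of the design under T=1 minus the design under T=0
    beta: (k,) fitted coefficients
    vcv:  (k, k) variance-covariance matrix of beta
    """
    cate = diff @ beta
    # Row-wise quadratic form d_i' V d_i, without forming an n_obs x n_obs matrix.
    var = np.einsum("ik,kl,il->i", diff, vcv, diff)
    return cate, np.sqrt(var)


# Toy inputs for y ~ t * x, where the T=1 minus T=0 row is [0, 1, 0, x_i].
x = np.array([-1.0, 0.0, 2.0])
diff = np.column_stack([np.zeros(3), np.ones(3), np.zeros(3), x])
beta = np.array([1.0, 2.0, 0.5, 1.5])
vcv = np.diag([0.01, 0.04, 0.01, 0.09])
cate, se = cates_with_se(diff, beta, vcv)
print(cate, se)
```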

For more details on treatment effect estimation, see Model Specifications.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| df | PandasConvertibleDataFrame | Dataframe containing the data to estimate CATEs for. Supported formats: pandas DataFrame, PySpark DataFrame, Polars DataFrame, or any object with a toPandas() or to_pandas() method. | required |
| return_results_dict | bool | If True, the function returns a dictionary containing CATEs, standard errors, t-statistics, and p-values. If False, the function returns an array containing CATEs alone. | False |

Returns

| Type | Description |
|---|---|
| jnp.ndarray \| dict | CATEs, or a dictionary containing CATEs, standard errors, t-statistics, and p-values. |

Examples

cates = fo_obj.estimate_cate(df)
cates[:5]
array([[3.89639511, 0.17159235],
       [3.74635426, 0.23789582],
       [4.5283798 , 0.21765926],
       [3.87988946, 0.17201947],
       [3.66232416, 0.24007028]])
res = fo_obj.estimate_cate(df, return_results_dict=True)
res.keys()
dict_keys(['outcome', 'cate', 'std_err', 't_stat', 'pval'])

predict

FastOLS.predict(df, *, return_results_dict=False)

Alias for estimate_cate.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| df | PandasConvertibleDataFrame | Dataframe containing the data to estimate CATEs for. Supported formats: pandas DataFrame, PySpark DataFrame, Polars DataFrame, or any object with a toPandas() or to_pandas() method. | required |
| return_results_dict | bool | If True, the function returns a dictionary containing CATEs, standard errors, t-statistics, and p-values. If False, the function returns an array containing CATEs alone. | False |

Returns

| Type | Description |
|---|---|
| jnp.ndarray \| dict | CATEs, or a dictionary containing CATEs, standard errors, t-statistics, and p-values. |

Examples

cates = fo_obj.predict(df)
cates[:5]
array([[3.89639511, 0.17159235],
       [3.74635426, 0.23789582],
       [4.5283798 , 0.21765926],
       [3.87988946, 0.17201947],
       [3.66232416, 0.24007028]])
res = fo_obj.predict(df, return_results_dict=True)
res.keys()
dict_keys(['outcome', 'cate', 'std_err', 't_stat', 'pval'])

prettify_treatment_effects

FastOLS.prettify_treatment_effects(effects=None)

Convert treatment effects dictionary to a pandas DataFrame.

If no argument is provided, the DataFrame is constructed from the internal results dictionary, which is the typical use case. For custom treatment effects, pass the results dictionary returned by the estimate_ate method with return_results_dict=True.
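As a rough sketch of this conversion, the hypothetical flatten_effects helper below turns a dictionary shaped like the estimate_ate example output into one row per group and outcome. It only covers custom groups (membership set to None); how caml stores GATE memberships internally is not shown in this section, so that case is omitted.

```python
import numpy as np
import pandas as pd


def flatten_effects(effects):
    """Turn {group: {'outcome': [...], 'ate': array([[...]]), ...}} into a long DataFrame."""
    rows = []
    for group, res in effects.items():
        for j, outcome in enumerate(res["outcome"]):
            row = {"group": group, "membership": None, "outcome": outcome}
            # Per-outcome statistics are stored as (1, n_outcomes) arrays.
            for key in ("ate", "std_err", "t_stat", "pval"):
                row[key] = float(np.ravel(res[key])[j])
            # Sample sizes are shared by all outcomes of the group.
            for key in ("n", "n_treated", "n_control"):
                row[key] = res[key]
            rows.append(row)
    return pd.DataFrame(rows)


# Values copied from the estimate_ate example output.
effects = {"Overall": {
    "outcome": ["Y1_continuous", "Y2_binary"],
    "ate": np.array([[3.75944943, 0.19349086]]),
    "std_err": np.array([[0.02012079, 0.00971142]]),
    "t_stat": np.array([[186.84402522, 19.92406093]]),
    "pval": np.array([[0.0, 0.0]]),
    "n": 10000, "n_treated": 5060, "n_control": 4940,
}}
print(flatten_effects(effects))
```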

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| effects | dict \| None | Dictionary of treatment effects. If None, the results are constructed from the internal results dictionary. | None |

Returns

| Type | Description |
|---|---|
| pd.DataFrame | DataFrame of treatment effects. |

Examples

fo_obj.prettify_treatment_effects()
       group  membership        outcome       ate   std_err      t_stat      pval      n  n_treated  n_control
0    overall        None  Y1_continuous  3.759449  0.020121  186.844025  0.000000  10000       5060       4940
1    overall        None      Y2_binary  0.193491  0.009711   19.924061  0.000000  10000       5060       4940
2  X2_binary           0  Y1_continuous  3.584580  0.034896  102.722456  0.000000   3320       1691       1629
3  X2_binary           0      Y2_binary  0.156381  0.017070    9.161124  0.000000   3320       1691       1629
4  X2_binary           1  Y1_continuous  3.846360  0.024623  156.212722  0.000000   6680       3369       3311
5  X2_binary           1      Y2_binary  0.211934  0.011806   17.950933  0.000000   6680       3369       3311
6  X3_binary           1  Y1_continuous  4.081423  0.021434  190.419041  0.000000   8801       4451       4350
7  X3_binary           1      Y2_binary  0.208494  0.010351   20.142500  0.000000   8801       4451       4350
8  X3_binary           0  Y1_continuous  1.396072  0.058353   23.924627  0.000000   1199        609        590
9  X3_binary           0      Y2_binary  0.083364  0.028048    2.972195  0.002957   1199        609        590

## Using a custom GATE
custom_gate = fo_obj.estimate_ate(df_filtered, return_results_dict=True, group="My Custom Group")
fo_obj.prettify_treatment_effects(custom_gate)
             group  membership        outcome       ate   std_err     t_stat      pval     n  n_treated  n_control
0  My Custom Group        None  Y1_continuous  1.396072  0.058353  23.924627  0.000000  1199        609        590
1  My Custom Group        None      Y2_binary  0.083364  0.028048   2.972195  0.002957  1199        609        590