AutoCATE

AutoCATE(
    Y,
    T,
    X=None,
    W=None,
    *,
    discrete_treatment=True,
    discrete_outcome=False,
    model_Y=None,
    model_T=None,
    model_regression=None,
    enable_categorical=False,
    n_jobs=1,
    use_ray=False,
    ray_remote_func_options_kwargs=None,
    use_spark=False,
    verbose=2,
    seed=None,
)

The AutoCATE class is a high-level API facilitating an AutoML framework for CATE estimation, built on top of the EconML library.

AutoCATE is experimental and may change significantly in future versions.

The CATE is defined as \(\mathbb{E}[\tau \mid \mathbf{X}]\) where \(\tau\) is the treatment effect and \(\mathbf{X}\) is the set of features leveraged for treatment effect heterogeneity.
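To make the definition concrete, the following sketch simulates individual treatment effects and averages them within feature values. The effect function (a conditional mean of 1 + 2x) and the data are made up for illustration. In real data the individual effects are never observed directly, which is why estimators are needed.

```python
import random
from statistics import fmean

random.seed(0)

# Hypothetical data: one binary feature x and an individual effect tau
# drawn around an assumed conditional mean of 1 + 2 * x.
rows = []
for _ in range(10_000):
    x = random.randint(0, 1)
    tau = 1 + 2 * x + random.gauss(0, 1)
    rows.append((x, tau))

# CATE: the mean effect among units that share the same feature value.
cate = {
    value: fmean(tau for x, tau in rows if x == value)
    for value in (0, 1)
}

# ATE: the mean effect over the whole population.
ate = fmean(tau for _, tau in rows)

print(cate)  # close to {0: 1.0, 1: 3.0}
print(ate)   # close to 2.0
```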

AutoCATE covers fitting, validating, and making predictions/inference with CATE models, with best practices built directly into the API. The class is designed to be easy to use and understand, while still providing flexibility for advanced users.

Note that first-stage models are estimated using a full AutoML framework via FLAML, whereas second-stage models are currently estimated and selected from a pre-specified set of models (or custom CATE models passed by the user), so their hyperparameters are not tuned. Hyperparameter tuning and other AutoML features are on the roadmap for future versions of AutoCATE.
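The two-stage structure can be illustrated with a toy residual-on-residual regression, the idea behind the DML family of estimators. This is a simplified sketch, not AutoCATE's implementation: plain least-squares lines stand in for the FLAML-tuned first-stage models, there is no cross-fitting, and the data assume a constant treatment effect of 2.0.

```python
import random
from statistics import fmean

random.seed(1)

def fit_line(x, y):
    """Return (intercept, slope) of the least-squares line y ~ x."""
    x_bar, y_bar = fmean(x), fmean(y)
    slope = sum((a - x_bar) * (b - y_bar) for a, b in zip(x, y)) / sum(
        (a - x_bar) ** 2 for a in x
    )
    return y_bar - slope * x_bar, slope

# Hypothetical data: x confounds both the treatment t and the outcome y.
n = 5_000
x = [random.gauss(0, 1) for _ in range(n)]
t = [0.5 * xi + random.gauss(0, 1) for xi in x]
y = [2.0 * ti + 3.0 * xi + random.gauss(0, 1) for ti, xi in zip(t, x)]

# First stage: nuisance models for E[Y | X] and E[T | X].
a_y, b_y = fit_line(x, y)
a_t, b_t = fit_line(x, t)
y_res = [yi - (a_y + b_y * xi) for yi, xi in zip(y, x)]
t_res = [ti - (a_t + b_t * xi) for ti, xi in zip(t, x)]

# Second stage: regress outcome residuals on treatment residuals.
_, theta = fit_line(t_res, y_res)
print(theta)  # close to the assumed effect of 2.0
```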

For technical details on the AutoCATE class, see here.

Note: All the standard assumptions of causal inference apply to this class (e.g., exogeneity/unconfoundedness and overlap/positivity). The class does not check these assumptions and assumes that the user has already thought them through before using the class.
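Because the class does not check these assumptions, a quick manual diagnostic can help. The sketch below checks overlap for a binary treatment by confirming that every stratum of a discrete covariate contains both treated and control units. The data, the helper `overlap_by_stratum`, and the `min_share` threshold are all hypothetical.

```python
from collections import defaultdict

def overlap_by_stratum(strata, treatments, min_share=0.05):
    """Map each stratum to (treated_share, ok), where ok means the
    treated share lies inside [min_share, 1 - min_share]."""
    counts = defaultdict(lambda: [0, 0])  # stratum -> [n_units, n_treated]
    for s, t in zip(strata, treatments):
        counts[s][0] += 1
        counts[s][1] += t
    report = {}
    for s, (n, n_treated) in counts.items():
        share = n_treated / n
        report[s] = (share, min_share <= share <= 1 - min_share)
    return report

# Hypothetical data: segment "c" has no control units, so overlap fails there.
strata = ["a", "a", "a", "a", "b", "b", "b", "b", "c", "c"]
treatments = [1, 0, 0, 1, 0, 1, 1, 0, 1, 1]

report = overlap_by_stratum(strata, treatments)
for stratum, (share, ok) in sorted(report.items()):
    print(stratum, share, "ok" if ok else "overlap violated")
```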

For outcome/treatment support, see matrix.

For a more detailed working example, see AutoCATE Example.

Parameters

Name Type Description Default
Y str The str representing the column name for the outcome variable. required
T str The str representing the column name(s) for the treatment variable(s). required
X Sequence[str] | None The sequence of feature names representing the feature set to be utilized for estimating heterogeneity/CATE. None
W Sequence[str] | None The sequence of feature names representing the confounder/control feature set to be utilized only for nuisance function estimation. When W is passed, only Orthogonal learners will be leveraged. None
discrete_treatment bool A boolean indicating whether the treatment is discrete/categorical or continuous. True
discrete_outcome bool A boolean indicating whether the outcome is binary or continuous. False
model_Y dict | BaseEstimator | None A dictionary of FLAML kwarg overrides or a BaseEstimator instance for the outcome model - \(\mathbb{E}[Y \mid \mathbf{X},\mathbf{W}]\). None
model_T dict | BaseEstimator | None A dictionary of FLAML kwarg overrides or a BaseEstimator instance for the treatment/propensity model - \(\mathbb{E}[T \mid \mathbf{X},\mathbf{W}]\). None
model_regression dict | BaseEstimator | None A dictionary of FLAML kwarg overrides or a BaseEstimator instance for the regression model - \(\mathbb{E}[Y \mid \mathbf{X},\mathbf{W},T]\). None
enable_categorical bool A boolean indicating whether to enable categorical encoding for the models. When set to True, pandas categorical types will be converted to ordinal encodings. For one-hot encoding, please implement it prior to passing the data to the model. False
n_jobs int The number of jobs to run in parallel for model training. 1
use_ray bool A boolean indicating whether to use Ray for distributed computing, utilized in both AutoML for first-stage models and AutoML for CATE models. n_concurrent_tasks defaults to 4 if not specified in FLAML kwarg overrides (e.g., for model_*). False
ray_remote_func_options_kwargs dict | None A dictionary of options for the Ray remote function that fits each individual CATE model. See here for options. None
use_spark bool A boolean indicating whether to use Spark for distributed computing, utilized only in AutoML for first-stage models. n_concurrent_tasks defaults to 4 if not specified in FLAML kwarg overrides (e.g., for model_*). False
verbose int The verbosity level. 0 = NOTSET, 1 = DEBUG, 2 = INFO, 3 = WARNING, 4 = ERROR, 5 = CRITICAL 2
seed int | None The seed to use for the random number generator. None
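The verbose codes listed above use the names of Python's standard logging levels. As a rough sketch, and assuming (the documentation does not confirm this) that each code corresponds to the standard library constant of the same name, the mapping is a multiplication by 10. The helper `verbose_to_logging_level` is hypothetical and not part of caml.

```python
import logging

def verbose_to_logging_level(verbose: int) -> int:
    """Translate a 0-5 verbose code into a standard logging level."""
    if not 0 <= verbose <= 5:
        raise ValueError("verbose must be between 0 and 5")
    return verbose * 10

for code in range(6):
    level = verbose_to_logging_level(code)
    print(code, logging.getLevelName(level))
```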

Attributes

Name Type Description
Y list[str] The list of str representing the column names for the outcome variable \(Y\).
T list[str] The list of str representing the column names for the treatment variable \(T\).
X list[str] The list of str representing the confounder/control feature set \(\mathbf{X}\) to be utilized for estimating heterogeneity/CATE and nuisance function estimation where applicable.
W list[str] | None The list of str representing the confounder/control feature set \(\mathbf{W}\) to be utilized only for nuisance function estimation, where applicable. These will be included by default in Meta-Learners.
available_estimators list[str] A list of the CATE estimators available out of the box. Which estimators are valid at runtime depends on the outcome and treatment types; valid estimators are selected automatically.
model_Y BaseEstimator The selected outcome model - \(\mathbb{E}[Y \mid \mathbf{X},\mathbf{W}]\).
model_T BaseEstimator The selected treatment model - \(\mathbb{E}[T \mid \mathbf{X},\mathbf{W}]\).
model_regression BaseEstimator The selected regression model - \(\mathbb{E}[Y \mid \mathbf{X},\mathbf{W},T]\).
rscores dict[str, float] The dictionary of the Rscores on the validation set for each CATE estimator fitted during model selection.
test_results dict[str, float] | EvaluationResults The final results on the test set for the selected best_estimator: a dictionary if RScorer is used, otherwise the EvaluationResults returned from DRTester.evaluate_all.
best_estimator BaseCateEstimator The best EconML CATE estimator selected.
best_estimator_name str The name of the best EconML CATE estimator selected as passed to the AutoCateEstimator constructor.

Examples

from caml import AutoCATE
from caml.extensions.synthetic_data import SyntheticDataGenerator

data_generator = SyntheticDataGenerator(seed=10, n_cont_modifiers=1, n_cont_confounders=1)
df = data_generator.df

auto_cate = AutoCATE(
    Y="Y1_continuous",
    T="T1_binary",
    X=[c for c in df.columns if "X" in c or "W" in c],
    model_Y={"time_budget": 10},
    model_T={"time_budget": 10},
    model_regression={"time_budget": 10},
    discrete_treatment=True,
    discrete_outcome=False,
)

print(auto_cate)
[11/13/25 16:49:39] WARNING  AutoCATE is experimental and may change in future versions.           decorators.py:47
================== AutoCATE Object ==================
Outcome Variable: ['Y1_continuous']
Discrete Outcome: False
Treatment Variable: ['T1_binary']
Discrete Treatment: True
Features/Confounders for Heterogeneity (X): ['W1_continuous', 'X1_continuous']
Features/Confounders as Controls (W): []
Enable Categorical: False
n Jobs: 1
Use Ray: False
Use Spark: False
Random Seed: None

Methods

Name Description
fit Run end-to-end fitting, validation & model selection, and testing for CATE models.
refit_final Refits the final, best estimator on the provided data.
estimate_ate Calculate the average treatment effect(s) (ATE) \(\mathbb{E}[\tau(\mathbf{X})]\).
estimate_cate Calculate the Conditional Average Treatment Effects (CATE) \(\tau(\mathbf{x})\).
predict Alias for estimate_cate.
effect Alias for estimate_cate.

fit

AutoCATE.fit(
    df,
    *,
    cate_estimators=list(available_estimators.keys()),
    additional_cate_estimators=list(),
    use_cached_models=False,
    ensemble=False,
    refit_final=True,
    bootstrap_inference=False,
    n_bootstrap_samples=100,
    validation_fraction=0.2,
    test_fraction=0.1,
    rscorer_kwargs=dict(),
    n_groups_dr_tester=5,
    n_bootstrap_dr_tester=100,
)

Run end-to-end fitting, validation & model selection, and testing for CATE models.

Parameters

Name Type Description Default
df PandasConvertibleDataFrame The dataset to fit the CATE model on. Accepts a Pandas DataFrame or a compatible object with a to_pandas() or .toPandas() method. required
cate_estimators Sequence[str] The out-of-the-box CATE estimators to use. Accessible via self.available_estimators. list(available_estimators.keys())
additional_cate_estimators Sequence[AutoCateEstimator] Additional CATE estimators to use. list()
use_cached_models bool Whether to use cached first-stage/nuisance models, if previously fitted. False
ensemble bool Whether to use an ensemble of CATE estimators. False
refit_final bool Whether to refit the final CATE estimator on the entire dataset after model selection. True
bootstrap_inference bool Whether to use bootstrap inference for the final CATE estimator, when other inference methods are not available (e.g., metalearners). This can be computationally expensive. False
n_bootstrap_samples int The number of bootstrap samples to use for bootstrap inference, if bootstrap_inference is True. 100
validation_fraction float The fraction of the dataset to use for validation. 0.2
test_fraction float The fraction of the dataset to use for testing. 0.1
rscorer_kwargs dict Additional keyword arguments to pass to RScorer. dict()
n_groups_dr_tester int The number of groups to use for the DRTester. 5
n_bootstrap_dr_tester int The number of bootstrap samples to use for the DRTester. 100
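With the defaults, 20% of rows are held out for validation and 10% for testing, leaving 70% for training. The sketch below shows one way such a three-way split could be produced. The shuffling and rounding are assumptions, the helper `three_way_split` is hypothetical, and AutoCATE's internal splitting logic may differ.

```python
import random

def three_way_split(n_rows, validation_fraction=0.2, test_fraction=0.1, seed=0):
    """Return shuffled row indices split into train, validation, and test."""
    if validation_fraction + test_fraction >= 1:
        raise ValueError("validation and test fractions must sum to less than 1")
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    n_val = round(n_rows * validation_fraction)
    n_test = round(n_rows * test_fraction)
    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]
    return train, val, test

train, val, test = three_way_split(1_000)
print(len(train), len(val), len(test))  # 700 200 100
```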

Examples

from caml import AutoCateEstimator
from econml.dml import LinearDML

my_custom_estimator = AutoCateEstimator(name="MyCustomEstimator", estimator=LinearDML())

auto_cate.fit(
    df = df,
    cate_estimators = auto_cate.available_estimators,
    additional_cate_estimators = [my_custom_estimator],
)
                    INFO                                                                           decorators.py:93
                               ____      __  __ _                                                                  
                              / ___|__ _|  \/  | |                                                                 
                             | |   / _` | |\/| | |                                                                 
                             | |__| (_| | |  | | |___                                                              
                              \____\__,_|_|  |_|_____|                                                             
                                                                                                                   
                                                                                                                   
                    INFO     ================== AutoCATE Object ==================                      cate.py:298
                             Outcome Variable: ['Y1_continuous']                                                   
                             Discrete Outcome: False                                                               
                             Treatment Variable: ['T1_binary']                                                     
                             Discrete Treatment: True                                                              
                             Features/Confounders for Heterogeneity (X): ['W1_continuous',                         
                             'X1_continuous']                                                                      
                             Features/Confounders as Controls (W): []                                              
                             Enable Categorical: False                                                             
                             n Jobs: 1                                                                             
                             Use Ray: False                                                                        
                             Use Spark: False                                                                      
                             Random Seed: None                                                                     
                                                                                                                   
                                                                                                                   
                    INFO                                                                           decorators.py:93
                             ==============================                                                        
                             |🎯 AutoML Nuisance Functions|                                                        
                             ==============================                                                        
                                                                                                                   
                    INFO     Searching for model_Y:                                                     cate.py:678
[11/13/25 16:49:49] INFO     Best estimator: xgboost with loss 32.23692194931248 found on iteration    _base.py:144
                             160 in 8.757890224456787 seconds.                                                     
                                                                                                                   
                    INFO     Searching for model_regression:                                            cate.py:678
[11/13/25 16:49:59] INFO     Best estimator: lgbm with loss 1.1523443116910392 found on iteration 76   _base.py:144
                             in 9.57820987701416 seconds.                                                          
                                                                                                                   
                    INFO     Searching for model_T:                                                     cate.py:678
[11/13/25 16:50:10] INFO     Best estimator: extra_tree with loss 0.6690926447879088 found on          _base.py:144
                             iteration 72 in 5.334404945373535 seconds.                                            
                                                                                                                   
                    INFO     ✅ Completed.                                                         decorators.py:98
                    INFO                                                                           decorators.py:93
                             ==========================                                                            
                             |🎯 AutoML CATE Functions|                                                            
                             ==========================                                                            
                                                                                                                   
[11/13/25 16:50:16] INFO     Best Estimator: 'LinearDRLearner'                                          cate.py:837
                    INFO     Estimator RScores: {'LinearDML': np.float64(0.8645868094473952),           cate.py:838
                             'CausalForestDML': np.float64(0.8442914414351617), 'NonParamDML':                     
                             np.float64(0.8533441998160204), 'SparseLinearDML-2D':                                 
                             np.float64(0.8633935067983776), 'DRLearner':                                          
                             np.float64(0.8613833025729861), 'ForestDRLearner':                                    
                             np.float64(0.8442859340445911), 'LinearDRLearner':                                    
                             np.float64(0.864762465601705), 'SLearner': np.float64(0.8529245831246046),            
                             'TLearner': np.float64(0.8511452959516579), 'XLearner':                               
                             np.float64(0.8537925353299018), 'MyCustomEstimator':                                  
                             np.float64(0.8646918312578792)}                                                       
                    INFO     ✅ Completed.                                                         decorators.py:98
                    INFO                                                                           decorators.py:93
                             ====================                                                                  
                             |🧪 Testing Results|                                                                  
                             ====================                                                                  
                                                                                                                   
                    INFO     Discrete treatment specified. Using DRTester for final testing.            cate.py:893
                    INFO     All validation results suggest that the model has found statistically      cate.py:923
                             significant heterogeneity.                                                            
                                                                                                                   
                    INFO        treatment  blp_est  blp_se  blp_pval  qini_est  qini_se  qini_pval  \   cate.py:927
                             0          1    0.982   0.007       0.0     2.303    0.094        0.0                 
                                                                                                                   
                                autoc_est  autoc_se  autoc_pval  cal_r_squared                                     
                             0      6.194     0.297         0.0          0.977                                     
                    INFO     CALIBRATION CURVE                                                          cate.py:930

[11/13/25 16:50:17] INFO     QINI CURVE                                                                 cate.py:933

                    INFO     TOC CURVE                                                                  cate.py:936

                    INFO     ✅ Completed.                                                         decorators.py:98
                    INFO                                                                           decorators.py:93
                             ==============================                                                        
                             |🔋 Refitting Final Estimator|                                                        
                             ==============================                                                        
                                                                                                                   
                    INFO     ✅ Completed.                                                         decorators.py:98
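The logged results are consistent with picking the estimator that has the highest RScore on the validation set. The snippet below copies the scores from the log above into a plain dictionary and selects the maximum. It is a sketch of that selection rule only, not a call into caml.

```python
# RScores copied from the log output above.
rscores = {
    "LinearDML": 0.8645868094473952,
    "CausalForestDML": 0.8442914414351617,
    "NonParamDML": 0.8533441998160204,
    "SparseLinearDML-2D": 0.8633935067983776,
    "DRLearner": 0.8613833025729861,
    "ForestDRLearner": 0.8442859340445911,
    "LinearDRLearner": 0.864762465601705,
    "SLearner": 0.8529245831246046,
    "TLearner": 0.8511452959516579,
    "XLearner": 0.8537925353299018,
    "MyCustomEstimator": 0.8646918312578792,
}

# Pick the estimator name with the largest validation score.
best_estimator_name = max(rscores, key=rscores.get)
print(best_estimator_name)  # LinearDRLearner
```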

refit_final

AutoCATE.refit_final(df, bootstrap_inference=False, n_bootstrap_samples=100)

Refits the final, best estimator on the provided data.

Parameters

Name Type Description Default
df PandasConvertibleDataFrame The data to fit the estimator on. required
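bootstrap_inference bool Whether to use bootstrap inference for the final CATE estimator, when other inference methods are not available (e.g., metalearners). This can be computationally expensive. False
n_bootstrap_samples int The number of bootstrap samples to use for bootstrap inference, if bootstrap_inference is True. 100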
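Bootstrap inference re-estimates a quantity on resampled data and uses the spread of the results as a measure of uncertainty, which is why the cost grows with the number of bootstrap samples. Below is a minimal standard-library sketch of the idea, using a difference in means on made-up data as a stand-in for a CATE estimator. It does not reproduce how EconML or caml implement bootstrap inference.

```python
import random
from statistics import fmean

rng = random.Random(0)

# Hypothetical outcomes with an assumed true effect of 2.0.
treated = [2.0 + rng.gauss(0, 1) for _ in range(500)]
control = [rng.gauss(0, 1) for _ in range(500)]

def diff_in_means(a, b):
    """Stand-in estimator: mean of treated minus mean of control."""
    return fmean(a) - fmean(b)

n_bootstrap_samples = 100
estimates = []
for _ in range(n_bootstrap_samples):
    # Resample each arm with replacement and re-estimate the effect.
    a = rng.choices(treated, k=len(treated))
    b = rng.choices(control, k=len(control))
    estimates.append(diff_in_means(a, b))

# Approximate 95% percentile interval: trim the same count from each tail.
estimates.sort()
k = int(0.025 * n_bootstrap_samples)
lower, upper = estimates[k], estimates[-k - 1]
print(diff_in_means(treated, control), (lower, upper))
```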

estimate_ate

AutoCATE.estimate_ate(
    df,
    *,
    effect_mode='discrete',
    T0=0,
    T1=1,
    T=None,
    return_inference=False,
    alpha=0.05,
    value=0,
)

Calculate the average treatment effect(s) (ATE) \(\mathbb{E}[\tau(\mathbf{X})]\).

This method can be used for group average treatment effects (GATEs) \(\mathbb{E}[\tau(\mathbf{X}) \mid G]\) by filtering the input df to a specific group.
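Since a GATE is the ATE computed on a subset of rows, the filtering step can be sketched in plain Python. The per-row effects and the `segment` column below are hypothetical stand-ins for CATE estimates and a grouping variable.

```python
from statistics import fmean

# Hypothetical per-row effect estimates with a grouping column.
rows = [
    {"segment": "a", "effect": 1.0},
    {"segment": "a", "effect": 3.0},
    {"segment": "b", "effect": -2.0},
    {"segment": "b", "effect": 0.0},
    {"segment": "b", "effect": 5.0},
]

def gate(rows, segment):
    """Average the effects over the rows that belong to one group."""
    return fmean(r["effect"] for r in rows if r["segment"] == segment)

ate = fmean(r["effect"] for r in rows)
print(ate)              # 1.4
print(gate(rows, "a"))  # 2.0
print(gate(rows, "b"))  # 1.0
```

With a fitted AutoCATE object, the analogous step would be to pass only the rows of one group as df to estimate_ate.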

Two effect modes are supported: discrete and marginal.

In discrete mode, the effect is calculated between two specific treatment levels T0 and T1 - \(\mathbb{E}[\tau(\mathbf{X}, T0, T1)]\).

In marginal mode, the effect is calculated for each observation as a gradient around their treatment levels T - \(\mathbb{E}[\partial \tau(T, \mathbf{X})]\).
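The difference between the two modes can be seen on a toy outcome function. The sketch assumes a made-up dose-response y(t, x) = x * t**2, for which the discrete effect between T0 and T1 is x * (T1**2 - T0**2) and the marginal effect at treatment level t is the derivative 2 * x * t. The function names are hypothetical.

```python
def outcome(t, x):
    """Hypothetical dose-response surface used only for illustration."""
    return x * t ** 2

def discrete_effect(x, t0, t1):
    """Effect of moving the treatment from t0 to t1."""
    return outcome(t1, x) - outcome(t0, x)

def marginal_effect(x, t, eps=1e-6):
    """Gradient of the outcome with respect to t (central difference)."""
    return (outcome(t + eps, x) - outcome(t - eps, x)) / (2 * eps)

x = 3.0
print(discrete_effect(x, t0=0, t1=1))  # 3.0
print(marginal_effect(x, t=2.0))       # approximately 12.0
```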

See EconML for more details.

Parameters

Name Type Description Default
df PandasConvertibleDataFrame The dataframe containing all of the data passed to the fit method. required
effect_mode str The mode of effect calculation. Can be “marginal” or “discrete”. 'discrete'
T0 int The base treatment level when effect_mode is “discrete”. 0
T1 int The target treatment level when effect_mode is “discrete”. 1
T np.ndarray | pd.DataFrame | None The base treatment levels for each observation when effect_mode is “marginal”. None
return_inference bool Whether to return EconML inference results. False
alpha float The significance level of the reported interval; the interval has confidence level 1 - alpha. 0.05
value int The mean value to test under the null hypothesis. 0

Returns

Name Type Description
float | PopulationSummaryResults The average treatment effect estimate if return_inference is False. Otherwise, an instance of PopulationSummaryResults is returned.
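The roles of alpha and value can be illustrated with a normal-approximation test on a point estimate and its standard error. The input numbers are made up and the helper `normal_summary` is hypothetical. The sketch mirrors only the general form of such a summary, not EconML's internal computation.

```python
from statistics import NormalDist

def normal_summary(mean_point, stderr_mean, alpha=0.05, value=0):
    """Return the z statistic, two-sided p-value, and (1 - alpha) interval."""
    z = (mean_point - value) / stderr_mean
    pvalue = 2 * (1 - NormalDist().cdf(abs(z)))
    half_width = NormalDist().inv_cdf(1 - alpha / 2) * stderr_mean
    return z, pvalue, (mean_point - half_width, mean_point + half_width)

# Test the null hypothesis that the mean effect equals 0 at alpha = 0.05.
z, pvalue, (lower, upper) = normal_summary(mean_point=1.5, stderr_mean=0.5)
print(round(z, 3), round(pvalue, 4), round(lower, 3), round(upper, 3))
```

A smaller alpha widens the interval, and a nonzero value shifts the null hypothesis being tested without changing the interval.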

Examples

# Return scalar ATE
auto_cate.estimate_ate(
    df = df,
    effect_mode = "discrete",
    T0 = 0,
    T1 = 1,
)
np.float64(-7.667648525594034)
# Return ATE with Inference
auto_cate.estimate_ate(
    df = df,
    effect_mode = "discrete",
    T0 = 0,
    T1 = 1,
    return_inference = True,
    alpha = 0.05,
    value = 0,
)
Uncertainty of Mean Point Estimate
mean_point stderr_mean zstat pvalue ci_mean_lower ci_mean_upper
-7.668 0.021 -361.115 0.0 -7.709 -7.626
Distribution of Point Estimate
std_point pct_point_lower pct_point_upper
8.311 -25.081 10.154
Total Variance of Point Estimate
stderr_point ci_point_lower ci_point_upper
8.311 -24.952 10.133

estimate_cate

AutoCATE.estimate_cate(
    df,
    *,
    effect_mode='discrete',
    T0=0,
    T1=1,
    T=None,
    return_inference=False,
)

Calculate the Conditional Average Treatment Effects (CATE) \(\tau(\mathbf{x})\).

Two effect modes are supported: discrete and marginal.

In discrete mode, the effect is calculated between two specific treatment levels T0 and T1 - \(\tau(\mathbf{X}, T0, T1)\).

In marginal mode, the effect is calculated for each observation as a gradient around their treatment levels T - \(\partial \tau(T, \mathbf{X})\).

See EconML for more details.

Parameters

Name Type Description Default
df PandasConvertibleDataFrame The dataframe containing all of the data passed to the fit method. required
effect_mode str The mode of effect calculation. Can be “marginal” or “discrete”. 'discrete'
T0 int The base treatment level when effect_mode is “discrete”. 0
T1 int The target treatment level when effect_mode is “discrete”. 1
T np.ndarray | pd.DataFrame | None The base treatment levels for each observation when effect_mode is “marginal”. None
return_inference bool Whether to return EconML inference results. False

Returns

Name Type Description
np.ndarray | InferenceResults The CATE estimates as an array if return_inference is False. Otherwise, the return value is an instance of NormalInferenceResults if asymptotic inference is available for the best estimator or EmpiricalInferenceResults if not.

Examples

# Return CATEs array
auto_cate.estimate_cate(
    df = df,
    effect_mode = "discrete",
    T0 = 0,
    T1 = 1,
)[:5]
array([[ 11.59211524],
       [ -9.18703755],
       [-32.19966435],
       [  3.28953785],
       [ -5.32643696]])
# Return CATEs with Inference
inference = auto_cate.estimate_cate(
    df = df,
    effect_mode = "discrete",
    T0 = 0,
    T1 = 1,
    return_inference = True,
)

print(inference)
<econml.inference._inference.NormalInferenceResults object at 0x7fa6e8eaa5f0>

predict

AutoCATE.predict(df, **kwargs)

Alias for estimate_cate.

effect

AutoCATE.effect(df, **kwargs)

Alias for estimate_cate.
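Because predict and effect are aliases, all three method names return the same estimates for the same arguments. Below is a toy sketch of how such aliasing behaves, using a hypothetical `ToyEstimator` class that is unrelated to caml's actual implementation.

```python
class ToyEstimator:
    """Hypothetical stand-in that returns a fixed effect for every row."""

    def estimate_cate(self, rows, **kwargs):
        scale = kwargs.get("scale", 1.0)
        return [scale * 2.0 for _ in rows]

    def predict(self, rows, **kwargs):
        # Forward all arguments to estimate_cate unchanged.
        return self.estimate_cate(rows, **kwargs)

    def effect(self, rows, **kwargs):
        # Same forwarding as predict.
        return self.estimate_cate(rows, **kwargs)

est = ToyEstimator()
rows = [{"x": 1}, {"x": 2}]
assert est.predict(rows) == est.effect(rows) == est.estimate_cate(rows)
print(est.predict(rows, scale=0.5))  # [1.0, 1.0]
```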
