| Title: | Functional Machine Learning Framework |
|---|---|
| Description: | A compact and explicit machine learning framework for supervised learning, resampling-based evaluation, hyperparameter tuning, learner comparison, interpretation, and plug-in g-computation. The package uses standard formulas for model specification and provides stable S3 interfaces for fitting, evaluation, tuning, interpretation, and causal estimation across a learner registry with multiple backend engines. Implemented interpretation methods build on established approaches such as permutation-based variable importance, partial dependence, individual conditional expectation, accumulated local effects, SHAP, and LIME; see Friedman (2001) <doi:10.1214/aos/1013203451>, Goldstein et al. (2015) <doi:10.1080/10618600.2014.907095>, Apley and Zhu (2020) <doi:10.1111/rssb.12377>, Lundberg and Lee (2017) <doi:10.48550/arXiv.1705.07874>, and Ribeiro et al. (2016) <doi:10.48550/arXiv.1602.04938>. The framework is intentionally opinionated: preprocessing is expected to occur outside the modeling step, and the API emphasizes explicit inputs, consistent object contracts, and compact interfaces rather than feature-by-feature competition with larger machine learning ecosystems. |
| Authors: | Imad El Badisy |
| Maintainer: | Imad El Badisy <[email protected]> |
| License: | GPL-3 |
| Version: | 0.7.2 |
| Built: | 2026-05-14 10:19:29 UTC |
| Source: | https://github.com/ielbadisy/funcml |
A classification dataset on arthritis status and related demographic and behavioral covariates.
arthritis
A data frame with 4,856 rows and 12 variables:
Participant identifier.
Arthritis status ("Yes" or "No").
Whether a relative had a heart attack.
Participant gender.
Participant age in years.
Body mass index.
Whether the participant has diabetes.
Whether the participant reports alcohol use.
Whether the participant smokes.
Whether the participant has prehypertension.
Whether the participant follows a vegetarian diet.
Whether the participant has health coverage.
Column names were standardized to snake_case when packaging the data.
Original arthritis survey dataset distributed with the project materials.
str(funcml::arthritis)
table(funcml::arthritis$status)
A classification dataset for maternal health risk level with vital signs, diabetes history, and related clinical indicators.
bangladeshmaternalrisk
A data frame with 1,205 rows and 12 variables:
Maternal age in years.
Systolic blood pressure.
Diastolic blood pressure.
Blood sugar measurement.
Body temperature.
Body mass index.
Indicator for previous pregnancy complications.
Indicator for preexisting diabetes.
Indicator for gestational diabetes.
Indicator for mental health concerns.
Heart rate.
Maternal risk level outcome.
Column names were standardized to snake_case when packaging the data.
Mojumdar MU, Sarker D, Assaduzzaman M, et al. (2025). Maternal health risk factors dataset: Clinical parameters and insights from rural Bangladesh. Data in Brief, 59(Suppl 2), 111363. doi:10.1016/j.dib.2025.111363.
str(funcml::bangladeshmaternalrisk)
table(funcml::bangladeshmaternalrisk$risk_level)
A regression-oriented birth weight dataset with maternal risk factors and a derived low-birth-weight indicator.
birthweight
A data frame with 189 rows and 10 variables:
Maternal age in years.
Maternal weight at the last menstrual period.
Maternal race code.
Smoking status indicator.
Number of previous premature labors.
History of hypertension indicator.
Presence of uterine irritability indicator.
Number of physician visits in the first trimester.
Birth weight in grams.
Low-birth-weight outcome indicator.
Hosmer DW, Lemeshow S (1989). Applied Logistic Regression. Wiley.
The packaged data are a lightly renamed version of the classic
MASS::birthwt dataset.
str(funcml::birthweight)
summary(funcml::birthweight$birth_weight_g)
A binary classification dataset for breast cancer diagnosis using tumor morphology measurements.
breastcancerdiagnostic
A data frame with 569 rows and 31 variables:
Mean radius.
Mean texture.
Mean perimeter.
Mean area.
Mean smoothness.
Mean compactness.
Mean concavity.
Mean number of concave points.
Mean symmetry.
Mean fractal dimension.
Radius standard error.
Texture standard error.
Perimeter standard error.
Area standard error.
Smoothness standard error.
Compactness standard error.
Concavity standard error.
Concave points standard error.
Symmetry standard error.
Fractal dimension standard error.
Worst radius.
Worst texture.
Worst perimeter.
Worst area.
Worst smoothness.
Worst compactness.
Worst concavity.
Worst number of concave points.
Worst symmetry.
Worst fractal dimension.
Diagnosis outcome ("B" = benign, "M" = malignant).
Column names were standardized to snake_case when packaging the data.
Breast Cancer Wisconsin Diagnostic Dataset from the UCI Machine
Learning Repository, packaged in dslabs::brca.
str(funcml::breastcancerdiagnostic)
table(funcml::breastcancerdiagnostic$diagnosis)
A binary classification dataset for breast cancer diagnosis from cytology measurements.
breastcancerwisconsin
A data frame with 699 rows and 10 variables:
Clump thickness score.
Uniformity of cell size score.
Uniformity of cell shape score.
Marginal adhesion score.
Single epithelial cell size score.
Bare nuclei score.
Bland chromatin score.
Normal nucleoli score.
Mitoses score.
Diagnostic class (2 = benign, 4 = malignant).
Wisconsin Breast Cancer Database from University of Wisconsin Hospitals, distributed through the UCI Machine Learning Repository.
str(funcml::breastcancerwisconsin)
table(funcml::breastcancerwisconsin$class)
A binary classification dataset on cancer remission status using leukemia index and treatment group indicators.
cancerremission
A data frame with 27 rows and 3 variables:
Leukemia index measurement.
Treatment group indicator.
Remission outcome indicator (0 = no remission,
1 = remission).
Column names were standardized to snake_case when packaging the data.
Davison AC, Hinkley DV (1997). Bootstrap Methods and Their
Application. Cambridge University Press. The packaged data are from
boot::remission.
str(funcml::cancerremission)
table(funcml::cancerremission$remission)
A regression dataset relating baseline CD4 counts to one-year follow-up CD4 measurements in HIV-positive patients.
cd4counts
A data frame with 20 rows and 2 variables:
Baseline CD4 count.
One-year follow-up CD4 count.
Davison AC, Hinkley DV (1997). Bootstrap Methods and Their
Application. Cambridge University Press. The packaged data are from
boot::cd4.
str(funcml::cd4counts)
summary(funcml::cd4counts$oneyear)
A classification-oriented survey dataset on smoking exposure, tobacco use, and tobacco-related environments among youth respondents.
cigsmoke
A data frame with 3,915 rows and 27 variables:
Survey final weight.
Age group.
Gender.
Personal spending money category.
Parental work status.
Father's education level.
Mother's education level.
Living environment.
Age at first cigarette.
Cigar use indicator.
Non-cigarette tobacco use indicator.
Smokeless tobacco use indicator.
Parental smoking exposure.
Friends' smoking exposure.
Secondhand smoke exposure at home.
Secondhand smoke exposure outside the home.
Indoor smoking ban indicator.
Outdoor smoking ban indicator.
Exposure to antitobacco media.
Exposure to school antitobacco education.
Exposure to tobacco media.
Whether free tobacco was offered.
Ownership of tobacco-branded items.
Knowledge that tobacco is harmful.
Electronic cigarette use indicator.
Survey stratum identifier.
Primary sampling unit identifier.
Column names were standardized to snake_case when packaging the data.
Morocco Global Youth Tobacco Survey public-use survey data.
Kim N, Loh WY, McCarthy DE (2021). Machine learning models of tobacco susceptibility and current use among adolescents from 97 countries in the Global Youth Tobacco Survey, 2013-2017. PLOS Global Public Health, 1(12), e0000060. doi:10.1371/journal.pgph.0000060.
str(funcml::cigsmoke)
table(funcml::cigsmoke$e_cig)
Compare multiple learners with optional tuning.
compare_learners(
  data,
  formula,
  models,
  specs = NULL,
  resampling = cv(5),
  metrics = NULL,
  type = NULL,
  conf_level = 0.95,
  seed = NULL,
  ncores = NULL,
  tune = FALSE,
  grids = NULL,
  metric = NULL,
  ...
)
data |
Data frame. |
formula |
Model formula. |
models |
Character vector of learner ids. |
specs |
Optional named list of fixed specs per learner. |
resampling |
Resampling object from |
metrics |
Character vector of metrics to report. When |
type |
Prediction type override. |
conf_level |
Confidence level for learner summary intervals. |
seed |
Optional seed. |
ncores |
Optional number of CPU cores used to compare learners. |
tune |
Logical; if |
grids |
Optional tuning grids. Supply either a single data frame to reuse across learners or a named list of data frames keyed by learner id. |
metric |
Optimization metric used when |
... |
Additional arguments passed to |
A funcml_compare object.
cmp <- compare_learners(
  data = mtcars,
  formula = mpg ~ wt + hp,
  models = c("glm", "rpart"),
  resampling = cv(3, seed = 1),
  metrics = c("rmse", "mae")
)
cmp$results
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_compare objects.
## S3 method for class 'funcml_compare'
print(x, ...)

## S3 method for class 'funcml_compare'
summary(object, ...)

## S3 method for class 'funcml_compare'
plot(x, ...)
x |
A |
... |
Additional arguments passed to the underlying method. |
object |
A |
print() and summary() return the input object or results table
invisibly. plot() returns a ggplot2 object.
cmp <- compare_learners(
  data = mtcars,
  formula = mpg ~ wt + hp,
  models = c("glm", "rpart"),
  resampling = cv(3, seed = 1),
  metrics = c("rmse", "mae")
)
print(cmp)
summary(cmp)
plot(cmp)
Resampling specification generator.
cv(
  v = 5,
  repeats = 1,
  strata = TRUE,
  seed = NULL,
  method = c("vfold", "holdout", "group_vfold", "time"),
  prop = 0.8,
  group = NULL,
  time = NULL,
  initial = NULL,
  assess = NULL,
  skip = 0,
  cumulative = TRUE
)
v |
Number of folds for cross-validation. |
repeats |
Number of repeats for standard or grouped cross-validation. |
strata |
Logical; stratify classification outcomes when supported. |
seed |
Optional seed for reproducibility. |
method |
Resampling strategy: |
prop |
Training-set proportion for holdout splits. |
group |
Optional grouping variable name or vector for grouped CV. |
time |
Optional ordering variable name or vector for time-aware splits. |
initial |
Initial training window size for time-aware CV. |
assess |
Assessment window size for time-aware CV. |
skip |
Number of observations to skip between successive time splits. |
cumulative |
Logical; use an expanding training window for time-aware CV. |
A funcml_cv object containing fold indices and parameters.
cv(v = 3, repeats = 2, seed = 1)
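To make the idea of fold indices concrete, the following base R sketch shows one way plain v-fold assignment can work. It is an illustration only and not funcml's internal code: `make_folds` and `splits` are hypothetical names, and stratification and repeats are left out.

```r
# Hypothetical helper: assign each of n rows to one of v folds at random.
make_folds <- function(n, v = 5, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  # Balanced fold labels, shuffled so that the assignment is random.
  sample(rep(seq_len(v), length.out = n))
}

folds <- make_folds(nrow(mtcars), v = 3, seed = 1)

# Each fold serves once as the assessment set; the other rows form the training set.
splits <- lapply(seq_len(3), function(k) {
  list(train = which(folds != k), test = which(folds == k))
})

lengths(lapply(splits, `[[`, "test"))
```

Every row appears in exactly one assessment set, which is the property a resampling object has to guarantee before metrics are averaged across folds.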
A regression dataset on annual doctor visit counts and related health, demographic, and insurance covariates.
doctorvisits
A data frame with 5,190 rows and 12 variables:
Number of doctor visits.
Recorded gender.
Age in years scaled to decades.
Income measure scaled by household composition.
Number of illnesses in the previous two weeks.
Number of days with reduced activity.
Self-rated health score.
Private insurance indicator.
Free care indicator for low-income patients.
Free care indicator for pensioners or veterans.
Indicator for no chronic condition.
Indicator for limiting chronic condition.
Cameron AC, Trivedi PK (1998). Regression Analysis of Count Data.
Cambridge University Press. The packaged data are from AER::DoctorVisits.
str(funcml::doctorvisits)
summary(funcml::doctorvisits$visits)
Causal effect estimation via plug-in g-computation.
estimate(
  data,
  formula,
  model = NULL,
  treatment = NULL,
  estimand = c("ATE", "ATT", "CATE", "IATE"),
  newdata = NULL,
  treatment_level = NULL,
  control_level = NULL,
  spec = NULL,
  type = NULL,
  interval = c("normal", "bootstrap"),
  conf_level = 0.95,
  n_boot = 200,
  seed = NULL,
  fit = NULL,
  ...
)
data |
Data frame. |
formula |
Outcome model formula. The first term on the right-hand side
is treated as the treatment variable unless |
model |
Learner id (ignored if |
treatment |
Optional treatment variable name. |
estimand |
One of |
newdata |
Optional target population for |
treatment_level |
Optional treated level for binary treatment. |
control_level |
Optional control level for binary treatment. |
spec |
Hyperparameter list passed to |
type |
Prediction type override for the outcome model. |
interval |
Interval method: |
conf_level |
Confidence level for uncertainty intervals. |
n_boot |
Number of bootstrap resamples used when |
seed |
Optional seed. |
fit |
Optional preconfigured |
... |
Passed to |
A funcml_estimand object.
causal_data <- mtcars
causal_data$am <- factor(causal_data$am, labels = c("auto", "manual"))
ate <- estimate(
  data = causal_data,
  formula = mpg ~ am + wt + hp,
  model = "glm",
  treatment = "am",
  estimand = "ATE"
)
ate$estimate
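For readers unfamiliar with plug-in g-computation, the base R sketch below walks through the general recipe: fit an outcome model, predict every row under each treatment level, and average the difference. It uses `lm()` as a stand-in outcome model and is not a description of how `estimate()` is implemented; `outcome_fit`, `treated`, `control`, and `ate_manual` are hypothetical names.

```r
causal_data <- mtcars
causal_data$am <- factor(causal_data$am, labels = c("auto", "manual"))

# Step 1: outcome model with the treatment and the adjustment covariates.
outcome_fit <- lm(mpg ~ am + wt + hp, data = causal_data)

# Step 2: two counterfactual copies of the data, one per treatment level.
treated <- causal_data
treated$am <- factor("manual", levels = levels(causal_data$am))
control <- causal_data
control$am <- factor("auto", levels = levels(causal_data$am))

# Step 3: average the difference in predicted outcomes over all rows.
ate_manual <- mean(
  predict(outcome_fit, newdata = treated) -
    predict(outcome_fit, newdata = control)
)
ate_manual
```

With a linear model that has no treatment interactions, this average equals the treatment coefficient. The plug-in approach becomes more informative with flexible learners, where the predicted contrast can differ from row to row.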
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_estimand objects.
## S3 method for class 'funcml_estimand'
print(x, ...)

## S3 method for class 'funcml_estimand'
summary(object, ...)

## S3 method for class 'funcml_estimand'
plot(x, ...)
x |
A |
... |
Additional arguments passed to the underlying method. |
object |
A |
print() and summary() return the input object or summary table
invisibly. plot() returns a ggplot2 object.
causal_data <- mtcars
causal_data$am <- factor(causal_data$am, labels = c("auto", "manual"))
ate <- estimate(
  data = causal_data,
  formula = mpg ~ am + wt + hp,
  model = "glm",
  treatment = "am",
  estimand = "ATE"
)
print(ate)
summary(ate)
plot(ate)
Cross-validated evaluation.
evaluate(
  data,
  formula,
  model = NULL,
  spec = NULL,
  resampling = cv(5),
  metrics = NULL,
  type = NULL,
  conf_level = 0.95,
  seed = NULL,
  fit = NULL,
  ncores = NULL,
  ...
)
data |
Data frame. |
formula |
Model formula. |
model |
Learner id (ignored if |
spec |
Hyperparameter list. |
resampling |
Resampling object from |
metrics |
Character vector of metric names. |
type |
Prediction type override. |
conf_level |
Confidence level for fold-based summary intervals. |
seed |
Optional seed. |
fit |
Optional preconfigured |
ncores |
Optional number of CPU cores used to evaluate resampling folds. |
... |
Passed to |
A funcml_eval object.
eval_obj <- evaluate(
  data = mtcars,
  formula = mpg ~ wt + hp,
  model = "glm",
  resampling = cv(3, seed = 1),
  metrics = c("rmse", "mae")
)
eval_obj$summary
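As a rough picture of what cross-validated evaluation computes, the base R sketch below fits a model on each training split, scores it on the held-out fold, and averages the fold-level metrics. It assumes plain 3-fold splitting and uses `lm()` in place of a registered learner; `folds` and `fold_metrics` are hypothetical names, and the layout of a real `funcml_eval` object may differ.

```r
set.seed(1)
v <- 3
# Balanced random fold labels for the 32 rows of mtcars.
folds <- sample(rep(seq_len(v), length.out = nrow(mtcars)))

# One row of metrics per fold: fit on the training rows, score the held-out rows.
fold_metrics <- t(sapply(seq_len(v), function(k) {
  train <- mtcars[folds != k, ]
  test <- mtcars[folds == k, ]
  model <- lm(mpg ~ wt + hp, data = train)
  err <- test$mpg - predict(model, newdata = test)
  c(rmse = sqrt(mean(err^2)), mae = mean(abs(err)))
}))

# Fold-averaged performance.
colMeans(fold_metrics)
```

Keeping the per-fold rows, rather than only the means, is what makes fold-based summary intervals and fold-level diagnostic plots possible.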
These methods provide the standard print() and summary() interfaces for
funcml_eval objects, plus a plot() method for fold-level diagnostics.
## S3 method for class 'funcml_eval'
print(x, ...)

## S3 method for class 'funcml_eval'
summary(object, ...)

## S3 method for class 'funcml_eval'
plot(x, ...)
x |
A |
... |
Additional arguments passed to the underlying method. |
object |
A |
print() and summary() return the input object or summary
table invisibly. plot() returns a ggplot2 object.
eval_obj <- evaluate(
  data = mtcars,
  formula = mpg ~ wt + hp,
  model = "glm",
  resampling = cv(3, seed = 1),
  metrics = c("rmse", "mae")
)
print(eval_obj)
summary(eval_obj)
plot(eval_obj)
Registered learner ids currently include:
regression and classification: glm, rpart, glmnet, ranger, nnet, mlp,
e1071_svm, randomForest, gbm, kknn, ctree, cforest,
lightgbm, xgboost, stacking, superlearner;
regression plus binary classification: gam, bart, earth;
classification only: C50, naivebayes, fda, lda, qda;
binary classification only: adaboost;
regression only: pls.
fit(
  formula,
  data,
  model,
  spec = NULL,
  seed = NULL,
  na_action = stats::na.fail,
  ...
)
formula |
Model formula. |
data |
Data frame. |
model |
Learner id (see |
spec |
Optional list of hyperparameters for the learner. |
seed |
Optional seed for reproducibility. |
na_action |
NA handling passed to |
... |
Additional parameters merged into |
The learner engine packages are installed with funcml, so the advertised
registry is intended to be available after a standard installation.
An object of class funcml_fit.
fit_obj <- fit(mpg ~ wt + hp, data = mtcars, model = "glm")
predict(fit_obj, newdata = mtcars[1:3, , drop = FALSE])
These methods provide the standard print(), summary(), predict(), and
coef() interfaces for funcml_fit objects.
## S3 method for class 'funcml_fit'
print(x, ...)

## S3 method for class 'funcml_fit'
summary(object, ...)

## S3 method for class 'funcml_fit'
predict(
  object,
  newdata,
  type = NULL,
  class_level = NULL,
  pos_level = NULL,
  na_action = object$na_action,
  ...
)

## S3 method for class 'funcml_fit'
coef(object, ...)
x |
A |
... |
Additional arguments passed to the underlying method. |
object |
A |
newdata |
Data frame of new observations. |
type |
Prediction type override. |
class_level |
Target class for multiclass probability predictions. |
pos_level |
Alias for the binary positive class. |
na_action |
NA handling for new data. |
print() and summary() return the input object invisibly.
predict() returns predictions in the requested format. coef()
returns a named numeric coefficient vector when available.
fit_obj <- fit(mpg ~ wt + hp, data = mtcars, model = "glm")
print(fit_obj)
summary(fit_obj)
predict(fit_obj, newdata = mtcars[1:3, , drop = FALSE])
coef(fit_obj)
Grouped cross-validation.
group_cv(v = 5, group, repeats = 1, seed = NULL)
v |
Number of folds. |
group |
Grouping variable name or vector. |
repeats |
Number of repeats. |
seed |
Optional seed. |
A funcml_cv object.
group_cv(v = 3, group = rep(letters[1:3], each = 4), seed = 1)
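The point of grouped cross-validation is that all rows sharing a group value land in the same fold, so no group is split between training and assessment data. The base R sketch below shows one way to get that property by assigning folds at the group level. It is an illustration under that assumption, not funcml's internal code, and `group_fold` and `row_fold` are hypothetical names.

```r
# Six groups of four rows each (for example, repeated measures per subject).
group <- rep(letters[1:6], each = 4)

set.seed(1)
ids <- unique(group)

# Assign folds to groups, not to rows.
group_fold <- setNames(sample(rep(seq_len(3), length.out = length(ids))), ids)

# Every row inherits the fold of its group.
row_fold <- unname(group_fold[group])

table(group, row_fold)
```

In the resulting table each group has all four of its rows in a single fold column, which is the leakage-free layout grouped resampling is meant to produce.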
A binary classification dataset on breast cancer survival after surgery.
haberman
A data frame with 306 rows and 4 variables:
Age of patient at operation time in years.
Year of operation minus 1900.
Number of positive axillary nodes detected.
Survival status (1 = survived 5 years or longer,
2 = died within 5 years).
Haberman's Survival Data from the University of Chicago's Billings Hospital study, distributed through the UCI Machine Learning Repository.
str(funcml::haberman)
table(funcml::haberman$survival_status)
A binary classification dataset on heart disease status using demographic and clinical risk factors.
heartdisease
A data frame with 303 rows and 9 variables:
Age in years.
Recorded sex.
Chest pain type.
Resting blood pressure.
Serum cholesterol measurement.
High fasting blood sugar indicator.
Maximum heart rate achieved.
Exercise-induced angina indicator.
Heart disease outcome ("Yes" or "No").
Column names were standardized to snake_case when packaging the data.
CardioDataSets package dataset
CardioDataSets::heartdisease_tbl_df.
str(funcml::heartdisease)
table(funcml::heartdisease$heart_disease)
A binary classification dataset on heart failure mortality using demographic, laboratory, and clinical covariates.
heartfailure
A data frame with 299 rows and 13 variables:
Age in years.
Anaemia indicator.
Creatinine phosphokinase level.
Diabetes indicator.
Ejection fraction percentage.
High blood pressure indicator.
Platelet count.
Serum creatinine level.
Serum sodium level.
Sex indicator.
Smoking indicator.
Follow-up time.
Death event outcome indicator (0 = no event,
1 = death).
Column names were standardized to snake_case when packaging the data.
CardioDataSets package dataset
CardioDataSets::cardiac_failure_df.
str(funcml::heartfailure)
table(funcml::heartfailure$death_event)
Plain holdout resampling.
holdout(prop = 0.8, strata = TRUE, seed = NULL)
prop |
Training-set proportion. |
strata |
Logical; stratify classification outcomes. |
seed |
Optional seed. |
A funcml_cv object.
holdout(prop = 0.75, seed = 1)
A regression dataset on infant mortality with country-level income, region, and oil-export status covariates.
infantmortality
A data frame with 105 rows and 5 variables:
Country name.
Per-capita income.
Infant mortality rate.
Geographic region.
Oil-exporting country indicator.
Column names were standardized to snake_case when packaging the data.
Fox J, Weisberg S (2019). An R Companion to Applied Regression.
Sage. The packaged data are from carData::Leinhardt.
str(funcml::infantmortality)
summary(funcml::infantmortality$infant)
Implements native permutation VI, PDP/ICE/ALE, SHAP approximations, local surrogate explanations, interaction strength, and global surrogate models.
interpret(
  fit,
  data,
  formula = fit$formula,
  method = c("vip", "permute", "pdp", "ice", "ale", "local", "lime", "shap",
    "local_model", "interaction", "surrogate", "profile", "ceteris_paribus",
    "calibration"),
  features = NULL,
  type = NULL,
  metric = NULL,
  importance_type = c("permute", "model", "auto"),
  compare = c("difference", "ratio"),
  keep = TRUE,
  k = NULL,
  gower_power = NULL,
  class_level = NULL,
  pos_level = NULL,
  newdata = NULL,
  nsim = NULL,
  nsamples = NULL,
  grid = NULL,
  seed = NULL,
  bins = 10,
  strategy = c("quantile", "uniform"),
  ...
)
fit |
A |
data |
Reference data (typically training set). |
formula |
Optional formula (defaults to |
method |
One of "vip","permute","pdp","ice","ale","local","lime", "shap","local_model","interaction","surrogate","profile", "ceteris_paribus", or "calibration". |
features |
Optional subset of features; defaults to all predictors. |
type |
Prediction scale: regression -> "response"; classification -> "prob" or "class". |
metric |
Loss/score for importance (reg: rmse/mae/mse/medae/mape/rsq; cls: accuracy/precision/recall/specificity/f1/balanced_accuracy/logloss/brier/ece/mce/auc/auc_weighted). |
importance_type |
Importance engine for |
compare |
How to compare baseline and perturbed performance for
importance: |
keep |
Keep per-repetition raw importance scores when |
k |
Sparsity target for local surrogate fits ( |
gower_power |
Exponent applied to native similarity weights when constructing the local neighborhood. |
class_level |
Target class for multiclass/local prob explanations. |
pos_level |
Alias for binary positive class (second level default). |
newdata |
Single-row data frame for local/SHAP explanations; defaults to first row of |
nsim |
Number of Monte Carlo simulations (importance/SHAP) or repetitions. |
nsamples |
Row subsample for speed (reference/background set). |
grid |
Optional list of grids per feature for PDP/ICE/ALE. |
seed |
Optional seed for determinism. |
bins |
Number of bins for calibration diagnostics. |
strategy |
Binning strategy for calibration diagnostics. |
... |
Additional method-specific args. |
An interpretation object whose class depends on method.
Returned objects contain computed explanation values and metadata used
for printing, summarizing, and plotting.
fit_obj <- fit(
  mpg ~ wt + hp + disp,
  data = mtcars,
  model = "rpart",
  spec = list(cp = 0.01, minsplit = 5)
)
vi <- interpret(
  fit = fit_obj,
  data = mtcars,
  method = "permute",
  features = c("wt", "hp"),
  nsim = 2,
  metric = "rmse"
)
vi$result$scores
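The idea behind permutation importance can be sketched in a few lines of base R: shuffle one feature, re-score the model, and compare against the baseline score. The version below uses `lm()` as a stand-in model and the "difference" comparison with RMSE. It is a minimal illustration, not the package's implementation, and `rmse`, `permute_importance`, and `importance` are hypothetical names.

```r
set.seed(1)
model <- lm(mpg ~ wt + hp + disp, data = mtcars)

rmse <- function(truth, pred) sqrt(mean((truth - pred)^2))
baseline <- rmse(mtcars$mpg, predict(model, newdata = mtcars))

# Average increase in RMSE after shuffling one feature, over nsim repetitions.
permute_importance <- function(feature, nsim = 2) {
  scores <- vapply(seq_len(nsim), function(i) {
    shuffled <- mtcars
    shuffled[[feature]] <- sample(shuffled[[feature]])
    rmse(mtcars$mpg, predict(model, newdata = shuffled)) - baseline
  }, numeric(1))
  mean(scores)
}

importance <- sapply(c("wt", "hp"), permute_importance)
importance
```

A larger increase in error suggests the model relies more heavily on that feature. With only two repetitions the scores are noisy, so a higher `nsim` is usually preferable outside of quick examples.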
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_ale objects returned by interpret(method = "ale").
## S3 method for class 'funcml_ale'
plot(x, ...)

## S3 method for class 'funcml_ale'
print(x, ...)

## S3 method for class 'funcml_ale'
summary(object, ...)
x |
A |
... |
Additional arguments passed to the underlying method. |
object |
A |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns the ALE curve table invisibly.
fit_obj <- fit(
  mpg ~ wt + hp + disp,
  data = mtcars,
  model = "rpart",
  spec = list(cp = 0.01, minsplit = 5)
)
ale <- interpret(
  fit = fit_obj,
  data = mtcars,
  method = "ale",
  features = c("wt", "hp"),
  nsamples = 20
)
print(ale)
summary(ale)
plot(ale)
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_calibration objects returned by
interpret(method = "calibration").
## S3 method for class 'funcml_calibration'
plot(x, style = c("curve", "histogram"), ...)

## S3 method for class 'funcml_calibration'
print(x, ...)

## S3 method for class 'funcml_calibration'
summary(object, ...)
x |
A |
style |
Plot style: |
... |
Additional arguments passed to the underlying method. |
object |
A |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns a list with calibration curve and
summary diagnostics invisibly.
fit_obj <- fit(mpg > 20 ~ wt + hp + disp, data = mtcars, model = "glm")
cal <- interpret(
  fit = fit_obj,
  data = mtcars,
  method = "calibration"
)
print(cal)
summary(cal)
plot(cal)
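A calibration curve compares predicted probabilities with observed event rates inside probability bins. The base R sketch below builds such a curve with ten equal-width bins and derives an expected calibration error from it. The probabilities and outcomes are simulated stand-ins, the names `cal_curve` and `ece` are hypothetical, and the exact contents of a `funcml_calibration` object may differ.

```r
set.seed(1)
# Simulated stand-ins for predicted probabilities and observed 0/1 outcomes.
prob <- runif(200)
outcome <- rbinom(200, size = 1, prob = prob)

# Ten equal-width bins over [0, 1], i.e. a uniform binning strategy.
bin <- cut(prob, breaks = seq(0, 1, length.out = 11), include.lowest = TRUE)

cal_curve <- data.frame(
  mean_predicted = as.vector(tapply(prob, bin, mean)),
  observed_rate = as.vector(tapply(outcome, bin, mean)),
  n = as.vector(table(bin))
)

# Expected calibration error: bin-size-weighted gap between the two columns.
ece <- with(
  cal_curve,
  sum(n / sum(n) * abs(observed_rate - mean_predicted), na.rm = TRUE)
)
ece
```

Quantile-based binning would replace the equal-width breaks with quantiles of `prob`, which keeps bin counts roughly equal when predictions cluster near 0 or 1.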
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_ice objects returned by interpret(method = "ice")
or interpret(method = "ceteris_paribus").
## S3 method for class 'funcml_ice'
plot(x, ...)

## S3 method for class 'funcml_ice'
print(x, ...)

## S3 method for class 'funcml_ice'
summary(object, ...)
x |
A |
... |
Additional arguments passed to the underlying method. |
object |
A |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns the ICE curve table invisibly.
fit_obj <- fit(
  mpg ~ wt + hp + disp,
  data = mtcars,
  model = "rpart",
  spec = list(cp = 0.01, minsplit = 5)
)
ice <- interpret(
  fit = fit_obj,
  data = mtcars,
  method = "ice",
  features = c("wt", "hp"),
  nsamples = 20
)
print(ice)
summary(ice)
plot(ice)
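Individual conditional expectation curves are obtained by varying one feature over a grid while holding each row's other features fixed. The base R sketch below computes them for `wt` with `lm()` as a stand-in model and then averages them into a partial dependence curve. It illustrates the general technique only; `grid`, `ice_curves`, and `pdp_curve` are hypothetical names and not the structure of a `funcml_ice` object.

```r
model <- lm(mpg ~ wt + hp + disp, data = mtcars)

# Five evenly spaced values across the observed range of wt.
grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)

# One row per observation, one column per grid value.
ice_curves <- sapply(grid, function(value) {
  modified <- mtcars
  modified$wt <- value
  predict(model, newdata = modified)
})

# The partial dependence curve is the column-wise average of the ICE curves.
pdp_curve <- colMeans(ice_curves)
pdp_curve
```

For an additive linear model all ICE curves are parallel, so they add little beyond the partial dependence curve. The curves become informative with learners such as trees, where individual rows can respond differently to the same feature.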
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_interaction objects returned by
interpret(method = "interaction").
## S3 method for class 'funcml_interaction'
plot(x, ...)

## S3 method for class 'funcml_interaction'
print(x, ...)

## S3 method for class 'funcml_interaction'
summary(object, ...)
x |
A |
... |
Additional arguments passed to the underlying method. |
object |
A |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns the interaction summary table
invisibly.
    fit_obj <- fit(mpg ~ wt + hp + disp, data = mtcars, model = "rpart",
                   spec = list(cp = 0.01, minsplit = 5))
    interaction_obj <- interpret(
      fit = fit_obj, data = mtcars, method = "interaction",
      features = c("wt", "hp"), nsamples = 20, grid_size = 5
    )
    print(interaction_obj)
    summary(interaction_obj)
    plot(interaction_obj)
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_local objects returned by interpret(method = "local")
or interpret(method = "lime").
    ## S3 method for class 'funcml_local'
    plot(x, ...)
    ## S3 method for class 'funcml_local'
    print(x, ...)
    ## S3 method for class 'funcml_local'
    summary(object, ...)

| Argument | Description |
|---|---|
| x | A funcml_local object. |
| ... | Additional arguments passed to the underlying method. |
| object | A funcml_local object. |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns the local explanation payload
invisibly.
    fit_obj <- fit(mpg ~ wt + hp + disp, data = mtcars, model = "rpart",
                   spec = list(cp = 0.01, minsplit = 5))
    local_obj <- interpret(
      fit = fit_obj, data = mtcars, method = "local",
      newdata = mtcars[1, , drop = FALSE], k = 2
    )
    print(local_obj)
    summary(local_obj)
    plot(local_obj)
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_iml_local_model objects returned by
interpret(method = "local_model") or interpret(method = "lime").
    ## S3 method for class 'funcml_iml_local_model'
    plot(x, ...)
    ## S3 method for class 'funcml_lime'
    plot(x, ...)
    ## S3 method for class 'funcml_iml_local_model'
    print(x, ...)
    ## S3 method for class 'funcml_iml_local_model'
    summary(object, ...)

| Argument | Description |
|---|---|
| x | A funcml_iml_local_model object. |
| ... | Additional arguments passed to the underlying method. |
| object | A funcml_iml_local_model object. |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns the local model explanation payload
invisibly.
    fit_obj <- fit(mpg ~ wt + hp + disp, data = mtcars, model = "rpart",
                   spec = list(cp = 0.01, minsplit = 5))
    local_model <- interpret(
      fit = fit_obj, data = mtcars, method = "local_model",
      newdata = mtcars[1, , drop = FALSE], k = 2
    )
    print(local_model)
    summary(local_model)
    plot(local_model)
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_pdp objects returned by interpret(method = "pdp").
    ## S3 method for class 'funcml_pdp'
    plot(x, ...)
    ## S3 method for class 'funcml_pdp'
    print(x, ...)
    ## S3 method for class 'funcml_pdp'
    summary(object, ...)

| Argument | Description |
|---|---|
| x | A funcml_pdp object. |
| ... | Additional arguments passed to the underlying method. |
| object | A funcml_pdp object. |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns the partial dependence table
invisibly.
    fit_obj <- fit(mpg ~ wt + hp + disp, data = mtcars, model = "rpart",
                   spec = list(cp = 0.01, minsplit = 5))
    pdp <- interpret(
      fit = fit_obj, data = mtcars, method = "pdp",
      features = c("wt", "hp"), nsamples = 20
    )
    print(pdp)
    summary(pdp)
    plot(pdp)
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_permute objects returned by interpret(method = "permute").
    ## S3 method for class 'funcml_permute'
    plot(x, ...)
    ## S3 method for class 'funcml_permute'
    print(x, ...)
    ## S3 method for class 'funcml_permute'
    summary(object, ...)

| Argument | Description |
|---|---|
| x | A funcml_permute object. |
| ... | Additional arguments passed to the underlying method. |
| object | A funcml_permute object. |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns the permutation importance table
invisibly.
    fit_obj <- fit(mpg ~ wt + hp + disp, data = mtcars, model = "rpart",
                   spec = list(cp = 0.01, minsplit = 5))
    perm <- interpret(
      fit = fit_obj, data = mtcars, method = "permute",
      features = c("wt", "hp"), nsim = 1, metric = "rmse"
    )
    print(perm)
    summary(perm)
    plot(perm)
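For readers who want to see what a permutation importance table measures, here is a minimal base R sketch of the general technique. It assumes an `lm()` model and an RMSE loss; `rmse_manual()` and `permute_importance()` are hypothetical helpers, and nothing here is taken from funcml's source.

```r
# Illustrative base R sketch of permutation importance; not funcml's implementation.
set.seed(1)
model <- lm(mpg ~ wt + hp + disp, data = mtcars)
rmse_manual <- function(truth, pred) sqrt(mean((truth - pred)^2))

# Error of the unmodified data, used as the reference point.
baseline <- rmse_manual(mtcars$mpg, predict(model, newdata = mtcars))

# Hypothetical helper: mean RMSE increase after shuffling one feature.
permute_importance <- function(feature, nsim = 10) {
  increases <- replicate(nsim, {
    shuffled <- mtcars
    shuffled[[feature]] <- sample(shuffled[[feature]])  # break the feature-outcome link
    rmse_manual(mtcars$mpg, predict(model, newdata = shuffled)) - baseline
  })
  mean(increases)
}

importance <- sapply(c("wt", "hp", "disp"), permute_importance)
sort(importance, decreasing = TRUE)
```

A larger increase in error after shuffling suggests the model relies more heavily on that feature. Averaging over several shuffles (`nsim`) reduces the noise of a single permutation.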
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_shap objects returned by interpret(method = "shap").
    ## S3 method for class 'funcml_shap'
    plot(
      x,
      kind = c("auto", "waterfall", "force", "summary", "beeswarm", "importance",
               "bar", "dependence", "dependence2d", "interaction"),
      ...
    )
    ## S3 method for class 'funcml_shap'
    print(x, ...)
    ## S3 method for class 'funcml_shap'
    summary(object, ...)

| Argument | Description |
|---|---|
| x | A funcml_shap object. |
| kind | Plot kind. One of "auto", "waterfall", "force", "summary", "beeswarm", "importance", "bar", "dependence", "dependence2d", or "interaction". |
| ... | Additional arguments passed to the underlying method. |
| object | A funcml_shap object. |
plot() returns a visualization object (typically a ggplot2
object) from shapviz. print() returns the input object invisibly.
summary() returns the SHAP contribution table invisibly.
    fit_obj <- fit(mpg ~ wt + hp + disp, data = mtcars, model = "rpart",
                   spec = list(cp = 0.01, minsplit = 5))
    shap <- interpret(
      fit = fit_obj, data = mtcars, method = "shap",
      newdata = mtcars[1, , drop = FALSE], nsim = 1
    )
    print(shap)
    summary(shap)
    plot(shap)
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_surrogate objects returned by
interpret(method = "surrogate").
    ## S3 method for class 'funcml_surrogate'
    plot(x, ...)
    ## S3 method for class 'funcml_surrogate'
    print(x, ...)
    ## S3 method for class 'funcml_surrogate'
    summary(object, ...)

| Argument | Description |
|---|---|
| x | A funcml_surrogate object. |
| ... | Additional arguments passed to the underlying method. |
| object | A funcml_surrogate object. |
plot() returns a ggplot2 object. print() returns the input
object invisibly. summary() returns the surrogate model summary object
invisibly.
    fit_obj <- fit(mpg ~ wt + hp + disp, data = mtcars, model = "rpart",
                   spec = list(cp = 0.01, minsplit = 5))
    surrogate <- interpret(fit = fit_obj, data = mtcars, method = "surrogate")
    print(surrogate)
    summary(surrogate)
    plot(surrogate)
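As a rough illustration of what a global surrogate is, the sketch below fits a simple linear model to the predictions of a more flexible one. The polynomial `lm()` stands in for a black box and the names `black_box`, `surrogate_fit`, and `fidelity` are illustrative assumptions; which surrogate model funcml fits internally is not stated here.

```r
# Illustrative base R sketch of a global surrogate; not funcml's implementation.
# Stand-in "black box": a polynomial regression with interactions.
black_box <- lm(mpg ~ poly(wt, 2) * poly(hp, 2), data = mtcars)

# Fit an interpretable linear model to the black-box predictions, not to mpg.
surrogate_data <- mtcars
surrogate_data$black_box_pred <- predict(black_box, newdata = mtcars)
surrogate_fit <- lm(black_box_pred ~ wt + hp, data = surrogate_data)

# Fidelity: share of black-box prediction variance the surrogate reproduces.
fidelity <- summary(surrogate_fit)$r.squared
coef(surrogate_fit)
```

A fidelity close to 1 means the surrogate's coefficients are a reasonable summary of the black box. A low value means the simple model misses much of its behavior.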
A regression dataset on ketamine dosing, treatment characteristics, cost, quality-adjusted life years, and administration mode.
    ketapain
A data frame with 184 rows and 11 variables:
Patient identifier.
Recorded sex.
Patient age in years.
Average dose.
Dose level category.
Cumulative dose.
Cumulative treatment days.
Perfusion duration.
Treatment cost.
Quality-adjusted life years.
Administration mode.
Column names were standardized to snake_case when packaging the data.
Original ketamine pain management dataset distributed with the project materials.
    str(funcml::ketapain)
    summary(funcml::ketapain$qaly)
learners() returns the registry keys accepted by fit(). Task support is:
regression and classification: glm, rpart, glmnet, ranger, nnet, mlp,
e1071_svm, randomForest, gbm, kknn, ctree, cforest,
lightgbm, xgboost, stacking, superlearner;
regression plus binary classification: gam, bart, earth;
classification only: C50, naivebayes, fda, lda, qda;
binary classification only: adaboost;
regression only: pls.
    learners()
The learner engine packages are installed with funcml, so the advertised
registry is intended to be available after a standard installation.
Character vector of learner ids.
    learners()
Returns a compact catalog of interpret() entry points and whether each
method has a corresponding plot() method.
    list_interpretability_methods(has_plot = NULL, columns = NULL)

| Argument | Description |
|---|---|
| has_plot | Optional logical filter for methods with plot support. |
| columns | Optional character vector of columns to return. |
Data frame of interpretability methods.
    list_interpretability_methods()
    subset(list_interpretability_methods(), has_plot)
list_learners() returns a compact learner registry in the style of a
catalog table. By default it focuses on the most user-visible columns:
learner id, generic fit/predict/tune entry points, and availability in the
current session.
    list_learners(
      has_fit = NULL,
      has_predict = NULL,
      has_tune = NULL,
      available = NULL,
      columns = NULL,
      regression = NULL,
      classification = NULL,
      prob = NULL,
      multiclass = NULL,
      importance = NULL,
      tune = NULL
    )

| Argument | Description |
|---|---|
| has_fit | Optional logical filter for fit support. |
| has_predict | Optional logical filter for predict support. |
| has_tune | Optional logical filter for tuning support. |
| available | Optional logical filter for engine availability in the current session. |
| columns | Optional character vector of columns to return. |
| regression | Optional logical filter for regression support. |
| classification | Optional logical filter for classification support. |
| prob | Optional logical filter for probability support. |
| multiclass | Optional logical filter for multiclass support. |
| importance | Optional logical filter for feature-importance support. |
| tune | Deprecated alias for |
Additional capability metadata remains available through columns =.
Data frame with learner metadata and capability columns.
    list_learners()
    list_learners(has_tune = TRUE)
    list_tunable_learners()
    list_learners(classification = TRUE, prob = TRUE,
                  columns = c("learner", "has_tune", "supports_prob", "engine_package"))
List available metrics used in scoring and resampling summaries.
    list_metrics(direction = NULL, columns = NULL)

| Argument | Description |
|---|---|
| direction | Optional character filter: |
| columns | Optional character vector of columns to return. |
Data frame of metric metadata.
    list_metrics()
    list_metrics(direction = "minimize")
Shortcut for learners with tuning support.
    list_tunable_learners(...)

| Argument | Description |
|---|---|
| ... | Passed to |
Data frame with the same columns as list_learners().
    list_tunable_learners()
A binary classification dataset for detection of mammographic microcalcifications.
    mammography
A data frame with 11,183 rows and 7 variables:
Numeric imaging-derived predictor 1.
Numeric imaging-derived predictor 2.
Numeric imaging-derived predictor 3.
Numeric imaging-derived predictor 4.
Numeric imaging-derived predictor 5.
Numeric imaging-derived predictor 6.
Calcification class ("-1" or "1").
Woods K, Doss C, Bowyer K, Solka J, Priebe C, Kegelmeyer P (1993). Comparative evaluation of pattern recognition techniques for detection of microcalcifications in mammography.
    str(funcml::mammography)
    table(funcml::mammography$class)
Base R implementations used across evaluation and interpretation utilities.
    rmse(truth, pred)
    mae(truth, pred)
    mse(truth, pred)
    rsq(truth, pred)
    medae(truth, pred)
    mape(truth, pred)
    logloss(truth, prob_matrix)
    brier(truth, prob_matrix)
    accuracy(truth, pred_class)
    precision(truth, pred_class)
    recall(truth, pred_class)
    specificity(truth, pred_class)
    f1(truth, pred_class)
    balanced_accuracy(truth, pred_class)
    auc(truth, prob, average = c("macro", "weighted"))
    auc_weighted(truth, prob)
    calibration_curve(truth, prob, bins = 10, strategy = c("quantile", "uniform"), positive = NULL)
    ece(truth, prob, bins = 10, strategy = c("quantile", "uniform"), positive = NULL)
    mce(truth, prob, bins = 10, strategy = c("quantile", "uniform"), positive = NULL)

| Argument | Description |
|---|---|
| truth | Observed outcomes. |
| pred | Predicted numeric values or class labels. |
| prob_matrix | Matrix or vector of predicted probabilities (classification). |
| pred_class | Predicted class labels (classification). |
| prob | Probability vector (binary) or probability matrix with one column per class (multiclass). |
| average | For multiclass AUC, aggregation mode: "macro" or "weighted". |
| bins | Number of bins for calibration summaries. |
| strategy | Binning strategy: "quantile" or "uniform". |
| positive | Optional positive/event class for binary classification. |
Numeric scalar metric.
    truth_reg <- c(3, 5, 2.5, 7)
    pred_reg <- c(2.8, 4.9, 2.7, 6.8)
    rmse(truth_reg, pred_reg)
    mae(truth_reg, pred_reg)
    mse(truth_reg, pred_reg)
    rsq(truth_reg, pred_reg)
    medae(truth_reg, pred_reg)
    mape(truth_reg, pred_reg)

    truth_cls <- factor(c("no", "yes", "yes", "no"), levels = c("no", "yes"))
    pred_cls <- factor(c("no", "yes", "no", "no"), levels = levels(truth_cls))
    prob_cls <- cbind(
      no = c(0.8, 0.2, 0.6, 0.7),
      yes = c(0.2, 0.8, 0.4, 0.3)
    )
    logloss(truth_cls, prob_cls)
    brier(truth_cls, prob_cls)
    accuracy(truth_cls, pred_cls)
    precision(truth_cls, pred_cls)
    recall(truth_cls, pred_cls)
    specificity(truth_cls, pred_cls)
    f1(truth_cls, pred_cls)
    balanced_accuracy(truth_cls, pred_cls)
    auc(truth_cls, prob_cls[, "yes"])

    truth_multi <- factor(c("a", "b", "c", "a", "b", "c"), levels = c("a", "b", "c"))
    prob_multi <- rbind(
      c(0.90, 0.05, 0.05),
      c(0.05, 0.90, 0.05),
      c(0.05, 0.05, 0.90),
      c(0.85, 0.10, 0.05),
      c(0.10, 0.80, 0.10),
      c(0.05, 0.10, 0.85)
    )
    colnames(prob_multi) <- levels(truth_multi)
    auc(truth_multi, prob_multi)
    auc_weighted(truth_multi, prob_multi)

    calibration_curve(truth_cls, prob_cls[, "yes"])
    ece(truth_cls, prob_cls[, "yes"])
    mce(truth_cls, prob_cls[, "yes"])
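As a cross-check on what a few of these metrics compute, the sketch below applies the textbook definitions in base R to the same toy vectors. The `*_manual` names are illustrative. Treating "yes" as the positive class is an assumption, since the page does not say which level funcml uses for precision, recall, and F1.

```r
# Textbook definitions in base R; assumed, not copied from funcml's source.
truth_reg <- c(3, 5, 2.5, 7)
pred_reg  <- c(2.8, 4.9, 2.7, 6.8)

mse_manual  <- mean((truth_reg - pred_reg)^2)   # mean squared error
rmse_manual <- sqrt(mse_manual)                 # root mean squared error
mae_manual  <- mean(abs(truth_reg - pred_reg))  # mean absolute error

truth_cls <- factor(c("no", "yes", "yes", "no"), levels = c("no", "yes"))
pred_cls  <- factor(c("no", "yes", "no", "no"), levels = levels(truth_cls))

# Assumption: "yes" is the positive class.
tp <- sum(pred_cls == "yes" & truth_cls == "yes")  # true positives
fp <- sum(pred_cls == "yes" & truth_cls == "no")   # false positives
fn <- sum(pred_cls == "no" & truth_cls == "yes")   # false negatives

accuracy_manual  <- mean(pred_cls == truth_cls)
precision_manual <- tp / (tp + fp)
recall_manual    <- tp / (tp + fn)
f1_manual <- 2 * precision_manual * recall_manual /
  (precision_manual + recall_manual)
```

Under these definitions the toy data give an MSE of 0.0325, an MAE of 0.175, an accuracy of 0.75, a precision of 1, a recall of 0.5, and an F1 of 2/3. If the package functions return different values for the class-based metrics, the positive-class convention is the first thing to check.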
A multiclass classification dataset on thyroid functional state using five laboratory test measurements.
    newthyroid
A data frame with 215 rows and 6 variables:
T3-resin uptake percentage.
Total serum thyroxin measurement.
Total serum triiodothyronine measurement.
Basal thyroid-stimulating hormone measurement.
Maximum absolute TSH difference after thyrotropin-releasing hormone injection.
Thyroid class (1 = normal, 2 = hyperthyroid,
3 = hypothyroid).
Thyroid gland data distributed through the UCI Machine Learning Repository.
    str(funcml::newthyroid)
    table(funcml::newthyroid$class)
A diabetes classification dataset with clinical measurements and a predefined train/test split column.
    pimadiabetes
A data frame with 532 rows and 9 variables:
Number of pregnancies.
Plasma glucose concentration.
Diastolic blood pressure.
Triceps skin fold thickness.
Body mass index.
Diabetes pedigree function.
Age in years.
Diabetes outcome ("Yes" or "No").
Suggested split indicator ("train" or "test").
National Institute of Diabetes and Digestive and Kidney Diseases Pima Indians Diabetes Database.
Smith JW, Everhart JE, Dickson WC, Knowler WC, Johannes RS (1988). Using the ADAP learning algorithm to forecast the onset of diabetes mellitus. In Proceedings of the Annual Symposium on Computer Application in Medical Care, 261-265.
    str(funcml::pimadiabetes)
    table(funcml::pimadiabetes$split)
A custom ggplot2 theme used across funcml plots. It keeps a clean
light background, restrained grid lines, and high-contrast labels so
package figures remain consistent and publication-friendly.
    theme_funcml(base_size = 11)

| Argument | Description |
|---|---|
| base_size | Base text size passed to the theme. |
A ggplot2 theme object.
    ggplot2::ggplot(mtcars, ggplot2::aes(wt, mpg)) +
      ggplot2::geom_point() +
      theme_funcml()
Time-aware rolling resampling.
    time_cv(
      initial,
      assess = 1,
      time = NULL,
      skip = 0,
      cumulative = TRUE,
      seed = NULL
    )

| Argument | Description |
|---|---|
| initial | Initial training window size. |
| assess | Assessment window size. |
| time | Ordering variable name or vector. |
| skip | Number of observations to skip between splits. |
| cumulative | Logical; use an expanding training window. |
| seed | Optional seed. |
A funcml_cv object.
    time_cv(initial = 8, assess = 2, skip = 1)
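The window arguments are easier to follow with explicit indices. The sketch below builds rolling-origin splits in base R under one common convention, in which consecutive split origins are `skip + 1` observations apart. This is an assumption, not funcml's documented behavior. `rolling_splits()` and the series length `n = 12` are hypothetical.

```r
# Hypothetical helper showing one common rolling-origin convention;
# funcml's time_cv() may define `skip` differently.
rolling_splits <- function(n, initial, assess = 1, skip = 0, cumulative = TRUE) {
  ends <- seq(initial, n - assess, by = skip + 1)  # last training index of each split
  lapply(ends, function(end) {
    # Expanding window starts at 1; sliding window keeps `initial` observations.
    start <- if (cumulative) 1 else end - initial + 1
    list(train = start:end, test = (end + 1):(end + assess))
  })
}

splits <- rolling_splits(n = 12, initial = 8, assess = 2, skip = 1)
length(splits)     # number of train/test pairs
splits[[1]]$train  # first training window
splits[[2]]$test   # last assessment window
```

With `cumulative = TRUE` every training window starts at observation 1 and grows over time. Setting it to `FALSE` in this sketch slides a fixed-width window instead. In both cases the assessment indices always come after the training indices, which is what makes the scheme time-aware.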
Hyperparameter tuning via grid or random search.
    tune(
      data,
      formula,
      model,
      grid,
      resampling = cv(5),
      metric = NULL,
      type = NULL,
      search = c("grid", "random"),
      n_evals = NULL,
      outer_resampling = NULL,
      seed = NULL,
      ncores = NULL,
      ...
    )

| Argument | Description |
|---|---|
| data | Data frame. |
| formula | Model formula. |
| model | Learner id. |
| grid | Data frame of hyperparameter combinations. |
| resampling | Resampling object. |
| metric | Metric to optimize. |
| type | Prediction type override. |
| search | Search strategy: "grid" or "random". |
| n_evals | Maximum number of configurations to evaluate when |
| outer_resampling | Optional outer resampling object. When supplied, |
| seed | Optional seed. |
| ncores | Optional number of CPU cores used for tuning tasks. |
| ... | Passed to |
A funcml_tune object.
    tune_obj <- tune(
      data = mtcars,
      formula = mpg ~ wt + hp,
      model = "rpart",
      grid = expand.grid(cp = c(0.001, 0.01), minsplit = c(5, 10)),
      resampling = cv(3, seed = 1),
      metric = "rmse"
    )
    tune_obj$best
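Conceptually, grid search scores every row of the grid under resampling and keeps the best one. The base R sketch below does this by hand for a toy problem. It tunes the polynomial degree of an `lm()` fit, not an rpart model, and the names `folds`, `cv_rmse()`, and `best` are illustrative. It does not reproduce funcml's internals.

```r
# Illustrative base R sketch of grid search with 3-fold CV; not funcml's implementation.
set.seed(1)
folds <- sample(rep(1:3, length.out = nrow(mtcars)))  # random fold assignment
grid <- expand.grid(degree = 1:3)                     # candidate configurations

# Hypothetical helper: mean held-out RMSE for one polynomial degree.
cv_rmse <- function(degree) {
  fold_rmse <- sapply(1:3, function(k) {
    train <- mtcars[folds != k, ]
    test  <- mtcars[folds == k, ]
    model <- lm(mpg ~ poly(wt, degree), data = train)
    sqrt(mean((test$mpg - predict(model, newdata = test))^2))
  })
  mean(fold_rmse)
}

grid$rmse <- sapply(grid$degree, cv_rmse)  # grid search: score every row
best <- grid[which.min(grid$rmse), ]       # lowest RMSE wins
best
```

Random search follows the same loop but scores only a sampled subset of rows, for example `grid[sample(nrow(grid), 2), ]`, which matters once the grid is too large to evaluate exhaustively.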
These methods provide the standard print(), summary(), and plot()
interfaces for funcml_tune objects.
    ## S3 method for class 'funcml_tune'
    print(x, ...)
    ## S3 method for class 'funcml_tune'
    summary(object, ...)
    ## S3 method for class 'funcml_tune'
    plot(x, ...)

| Argument | Description |
|---|---|
| x | A funcml_tune object. |
| ... | Additional arguments passed to the underlying method. |
| object | A funcml_tune object. |
print() and summary() return the input object or results table
invisibly. plot() returns a ggplot2 object.
    tune_obj <- tune(
      data = mtcars,
      formula = mpg ~ wt + hp,
      model = "rpart",
      grid = expand.grid(cp = c(0.001, 0.01), minsplit = c(5, 10)),
      resampling = cv(3, seed = 1),
      metric = "rmse"
    )
    print(tune_obj)
    summary(tune_obj)
    plot(tune_obj)