Building glass-box models that are both predictive and interpretable
For many reasons (e.g., scientific inquiry, high-stakes decision making), we need AI systems that are both accurate and intelligible.
We find that interaction effects are often a useful lens through which to view intelligibility. Interaction effects (Lengerich et al., 2020) are effects that can only be understood by considering two or more input components jointly: any single component on its own tells you nothing about the output. Since humans reason by chunking and hierarchical logic, we struggle to understand interactions of multiple variables. If we can instead represent effects as additive (non-interactive) combinations of components, we can understand each component independently and reason about even very complex concepts.
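A toy contrast between the two kinds of structure (a hypothetical Python sketch; the sign-product term is the classic example of a pure interaction):

```python
import numpy as np

# Additive (glass-box) model: each component can be inspected on its own.
def f_additive(x1, x2):
    return np.sin(x1) + 0.5 * x2  # contributions of x1 and x2 are separable

# Pure interaction: knowing x1 alone (or x2 alone) says nothing about the output.
def f_interaction(x1, x2):
    return np.sign(x1) * np.sign(x2)  # XOR-like; each marginal effect is zero

rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=10_000), rng.normal(size=10_000)
# One-variable (marginal) signal: informative for the additive model,
# essentially flat for the pure interaction.
print(np.corrcoef(x1, f_additive(x1, x2))[0, 1])     # clearly nonzero
print(np.corrcoef(x1, f_interaction(x1, x2))[0, 1])  # approximately 0
```

In the additive case each one-dimensional component is informative on its own; in the pure-interaction case no single variable carries any signal.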
Background: Early identification of patients at increased risk for postpartum hemorrhage (PPH) associated with severe maternal morbidity (SMM) is critical for preparation and preventative intervention. However, prediction is challenging in patients without obvious risk factors for postpartum hemorrhage with severe maternal morbidity. Current tools for hemorrhage risk assessment use lists of risk factors rather than predictive models.
Objective: To develop, validate (internally and externally), and compare a machine learning model for predicting PPH associated with SMM against a standard hemorrhage risk assessment tool in a lower-risk laboring obstetric population.
Study Design: This retrospective cross-sectional study included clinical data from singleton, term births (≥37 weeks’ gestation) at 19 US hospitals (2016-2021), using data from 44,509 births at 11 hospitals to train a generalized additive model (GAM) and 21,183 births at 8 held-out hospitals to externally validate the model. The outcome of interest was PPH with severe maternal morbidity (blood transfusion, hysterectomy, vascular embolization, intrauterine balloon tamponade, uterine artery ligation suture, uterine compression suture, or admission to intensive care). Cesarean birth without a trial of vaginal birth and patients with a history of cesarean were excluded. We compared the model performance to that of the California Maternal Quality Care Collaborative (CMQCC) Obstetric Hemorrhage Risk Factor Assessment Screen.
Results: The GAM predicted PPH with an area under the receiver-operating characteristic curve (AUROC) of 0.67 (95% CI 0.64-0.68) on external validation, significantly outperforming the CMQCC risk screen AUROC of 0.52 (95% CI 0.50-0.53). Additionally, the GAM had better sensitivity of 36.9% (95% CI 33.01-41.02) than the CMQCC screen sensitivity of 20.30% (95% CI 17.40-22.52) at the CMQCC screen-positive rate of 16.8%. The GAM identified in-vitro fertilization as a risk factor (adjusted OR 1.5; 95% CI 1.2-1.8) and nulliparous births as the highest PPH risk factor (adjusted OR 1.5; 95% CI 1.4-1.6).
Conclusion: Our model identified almost twice as many cases of PPH as the CMQCC rules-based approach for the same screen-positive rate and identified in-vitro fertilization and first-time births as risk factors for PPH. Adopting predictive models over traditional screens can enhance PPH prediction.
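As a rough illustration of the evaluation protocol only (not the study's code or data), the sketch below uses synthetic stand-in arrays to compare AUROC and sensitivity at a matched screen-positive rate:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
# Synthetic stand-ins for an external-validation cohort (NOT study data):
y = rng.binomial(1, 0.05, size=20_000)                                   # outcome labels
gam_score = y * rng.normal(0.6, 1, 20_000) + rng.normal(0, 1, 20_000)    # model risk scores
screen_flag = rng.binomial(1, np.where(y == 1, 0.20, 0.168))             # rules-based screen (0/1)

def sensitivity_at_flag_rate(y_true, score, flag_rate):
    """Sensitivity when flagging the top `flag_rate` fraction of patients by score."""
    cutoff = np.quantile(score, 1 - flag_rate)
    flagged = score >= cutoff
    return flagged[y_true == 1].mean()

print("AUROC (model): ", roc_auc_score(y, gam_score))
print("AUROC (screen):", roc_auc_score(y, screen_flag))
# Compare sensitivities at the screen's own positive rate, as in the paper's comparison.
print("Model sensitivity at screen's flag rate:",
      sensitivity_at_flag_rate(y, gam_score, flag_rate=screen_flag.mean()))
```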
2022
Automated interpretable discovery of heterogeneous treatment effectiveness: A COVID-19 case study
Testing multiple treatments for heterogeneous (varying) effectiveness with respect to many underlying risk factors requires many pairwise tests; we would like to instead automatically discover and visualize patient archetypes and predictors of treatment effectiveness using multitask machine learning. In this paper, we present a method to estimate these heterogeneous treatment effects with an interpretable hierarchical framework that uses additive models to visualize expected treatment benefits as a function of patient factors (identifying personalized treatment benefits) and concurrent treatments (identifying combinatorial treatment benefits). This method achieves state-of-the-art predictive power for COVID-19 in-hospital mortality and interpretable identification of heterogeneous treatment benefits. We first validate this method on the large public MIMIC-IV dataset of ICU patients to test recovery of heterogeneous treatment effects. Next we apply this method to a proprietary dataset of over 3000 patients hospitalized for COVID-19, and find evidence of heterogeneous treatment effectiveness predicted largely by indicators of inflammation and thrombosis risk: patients with few indicators of thrombosis risk benefit most from treatments against inflammation, while patients with few indicators of inflammation risk benefit most from treatments against thrombosis. This approach provides an automated methodology to discover heterogeneous and individualized effectiveness of treatments.
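A minimal sketch of the hierarchical additive structure, with hypothetical feature names and hand-picked toy shape functions (the real model learns these from data):

```python
# Predicted log-odds = baseline additive risk
#                    + per-treatment additive benefit as a function of patient features
#                    + pairwise terms between concurrent treatments.
# Every term is low-dimensional, so each can be plotted and inspected on its own.

# Toy shape functions (hypothetical):
baseline = {"age": lambda a: 0.03 * (a - 50), "crp": lambda c: 0.01 * c}
benefit = {"anticoagulant": {"d_dimer": lambda d: -0.002 * d},   # more benefit at high thrombosis risk
           "corticosteroid": {"crp": lambda c: -0.005 * c}}      # more benefit at high inflammation
combo = {("anticoagulant", "corticosteroid"): 0.1}               # combinatorial treatment term

def predict_log_odds(x, treatments):
    score = sum(f(x[j]) for j, f in baseline.items())            # baseline risk
    for t in treatments:                                         # personalized treatment benefits
        score += sum(g(x[j]) for j, g in benefit[t].items())
    for (t, u), c in combo.items():                              # concurrent-treatment terms
        if t in treatments and u in treatments:
            score += c
    return score

patient = {"age": 70, "crp": 120.0, "d_dimer": 300.0}
print(predict_log_odds(patient, {"corticosteroid"}))
print(predict_log_odds(patient, {"corticosteroid", "anticoagulant"}))
```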
We examine Dropout through the perspective of interactions: effects that require multiple variables. Given $N$ variables, there are $\binom{N}{k}$ possible sets of $k$ variables ($N$ univariate effects, $\mathcal{O}(N^2)$ pairwise interactions, $\mathcal{O}(N^3)$ 3-way interactions); we can thus imagine that models with large representational capacity could be dominated by high-order interactions. In this paper, we show that Dropout contributes a regularization effect which helps neural networks (NNs) explore functions of lower-order interactions before considering functions of higher-order interactions. Dropout imposes this regularization by reducing the effective learning rate of higher-order interactions. As a result, Dropout encourages models to learn lower-order functions of additive components. This understanding of Dropout has implications for choosing Dropout rates: higher Dropout rates should be used when we need stronger regularization against interactions. This perspective also issues caution against using Dropout to measure term salience because Dropout regularizes against high-order interactions. Finally, this view of Dropout as a regularizer of interactions provides insight into the varying effectiveness of Dropout across architectures and datasets. We also compare Dropout to weight decay and early stopping and find that it is difficult to obtain the same regularization with these alternatives.
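A toy numerical check of the underlying intuition (not the paper's derivation): under independent Bernoulli dropout, a term that needs all $k$ of its inputs present survives only when none of them is dropped, so its expected contribution shrinks geometrically with interaction order.

```python
import numpy as np

# With dropout rate p, each input is kept with probability (1 - p); a term that
# requires all k of its inputs survives with probability (1 - p)^k, so its
# expected contribution (and expected gradient) shrinks with interaction order k.
rng = np.random.default_rng(0)
p = 0.5
for k in (1, 2, 3, 4):
    masks = rng.binomial(1, 1 - p, size=(100_000, k))
    survival = masks.prod(axis=1).mean()      # Monte-Carlo estimate
    print(k, survival, (1 - p) ** k)          # estimate vs. closed form
```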
2021
Neural Additive Models: Interpretable Machine Learning with Neural Nets
Rishabh Agarwal, Levi Melnick, Nicholas Frosst, and 4 more authors
Advances in Neural Information Processing Systems, 2021
Deep neural networks (DNNs) are powerful black-box predictors that have achieved impressive performance on a wide variety of tasks. However, their accuracy comes at the cost of intelligibility: it is usually unclear how they make their decisions. This hinders their applicability to high stakes decision-making domains such as healthcare. We propose Neural Additive Models (NAMs) which combine some of the expressivity of DNNs with the inherent intelligibility of generalized additive models. NAMs learn a linear combination of neural networks that each attend to a single input feature. These networks are trained jointly and can learn arbitrarily complex relationships between their input feature and the output. Our experiments on regression and classification datasets show that NAMs are more accurate than widely used intelligible models such as logistic regression and shallow decision trees. They perform similarly to existing state-of-the-art generalized additive models in accuracy, but are more flexible because they are based on neural nets instead of boosted trees. To demonstrate this, we show how NAMs can be used for multitask learning on synthetic data and on the COMPAS recidivism data due to their composability, and demonstrate that the differentiability of NAMs allows them to train more complex interpretable models for COVID-19.
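A minimal PyTorch sketch of the architecture described above (one small MLP per input feature, summed with a bias); it omits the paper's ExU hidden units and regularizers:

```python
import torch
import torch.nn as nn

class NeuralAdditiveModel(nn.Module):
    """Minimal NAM sketch: one feature net per input, outputs summed with a bias."""
    def __init__(self, n_features, hidden=32):
        super().__init__()
        self.feature_nets = nn.ModuleList([
            nn.Sequential(nn.Linear(1, hidden), nn.ReLU(),
                          nn.Linear(hidden, hidden), nn.ReLU(),
                          nn.Linear(hidden, 1))
            for _ in range(n_features)
        ])
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):                                   # x: (batch, n_features)
        # Each feature net sees only its own column, so its learned 1-D shape
        # function can be plotted directly.
        contribs = [net(x[:, j:j + 1]) for j, net in enumerate(self.feature_nets)]
        return self.bias + torch.stack(contribs, dim=-1).sum(dim=-1)

model = NeuralAdditiveModel(n_features=5)
logits = model(torch.randn(8, 5))                           # (8, 1) additive predictions
```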
How Interpretable and Trustworthy are GAMs?
Chun-Hao Chang, Sarah Tan, Ben Lengerich, and 2 more authors
In Proceedings of the 27th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, 2021
Generalized additive models (GAMs) have become a leading model class for interpretable machine learning. However, there are many algorithms for training GAMs, and these can learn different or even contradictory models, while being equally accurate. Which GAM should we trust? In this paper, we quantitatively and qualitatively investigate a variety of GAM algorithms on real and simulated datasets. We find that GAMs with high feature sparsity (only using a few variables to make predictions) can miss patterns in the data and be unfair to rare subpopulations. Our results suggest that inductive bias plays a crucial role in what interpretable models learn and that tree-based GAMs represent the best balance of sparsity, fidelity and accuracy and thus appear to be the most trustworthy GAM models.
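The core phenomenon, equally accurate GAMs that attribute the signal to features very differently, is easy to reproduce. The sketch below (an illustration, not the paper's experiments) fits two spline-basis GAMs with scikit-learn on correlated synthetic features: an L1-penalized sparse fit and an L2-penalized dense fit.

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import SplineTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
n = 5_000
x1 = rng.normal(size=n)
x2 = 0.95 * x1 + 0.3 * rng.normal(size=n)          # strongly correlated with x1
X = np.column_stack([x1, x2])
y = rng.binomial(1, 1 / (1 + np.exp(-(x1 + x2))))  # both features carry signal

def fit_gam(penalty, solver):
    # Additive model: spline basis per feature, linear model on top.
    return make_pipeline(
        SplineTransformer(n_knots=8, degree=3),
        LogisticRegression(penalty=penalty, C=0.05, solver=solver, max_iter=5_000),
    ).fit(X, y)

sparse_gam = fit_gam("l1", "liblinear")
dense_gam = fit_gam("l2", "lbfgs")
for name, gam in [("sparse", sparse_gam), ("dense", dense_gam)]:
    coefs = gam[-1].coef_.reshape(2, -1)            # spline coefficients per feature
    print(name,
          roc_auc_score(y, gam.predict_proba(X)[:, 1]),  # similar accuracy
          np.abs(coefs).sum(axis=1))                     # very different feature usage
```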
2020
Purifying Interaction Effects with the Functional ANOVA: An Efficient Algorithm for Recovering Identifiable Additive Models
Ben Lengerich, Sarah Tan, Chun-Hao Chang, and 2 more authors
In Proceedings of the Twenty-Third International Conference on Artificial Intelligence and Statistics (AISTATS), 2020
Models which estimate main effects of individual variables alongside interaction effects have an identifiability challenge: effects can be freely moved between main effects and interaction effects without changing the model prediction. This is a critical problem for interpretability because it permits “contradictory” models to represent the same function. To solve this problem, we propose pure interaction effects: variance in the outcome which cannot be represented by any subset of features. This definition has an equivalence with the Functional ANOVA decomposition. To compute this decomposition, we present a fast, exact algorithm that transforms any piecewise-constant function (such as a tree-based model) into a purified, canonical representation. We apply this algorithm to Generalized Additive Models with interactions trained on several datasets and show large disparity, including contradictions, between the apparent and the purified effects. These results underscore the need to specify data distributions and ensure identifiability before interpreting model parameters.
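A minimal numpy sketch of the purification idea for a single pairwise lookup table (a hypothetical 3x4 example with a made-up bin distribution): repeatedly subtract data-weighted row and column means, moving mass into the main effects and intercept without changing the model's predictions.

```python
import numpy as np

rng = np.random.default_rng(0)
T = rng.normal(size=(3, 4))                      # piecewise-constant pairwise effect f12(x1, x2)
T0 = T.copy()
w = rng.dirichlet(np.ones(12)).reshape(3, 4)     # joint density over the (x1, x2) bins

f1, f2, intercept = np.zeros(3), np.zeros(4), 0.0
for _ in range(100):                             # alternate row/column centering until convergence
    row_means = (T * w).sum(axis=1) / w.sum(axis=1)
    T -= row_means[:, None]; f1 += row_means     # move mass into the main effect of x1
    col_means = (T * w).sum(axis=0) / w.sum(axis=0)
    T -= col_means[None, :]; f2 += col_means     # move mass into the main effect of x2

# Center the main effects too, pushing their weighted means into the intercept.
m1 = (f1 * w.sum(axis=1)).sum(); f1 -= m1; intercept += m1
m2 = (f2 * w.sum(axis=0)).sum(); f2 -= m2; intercept += m2

assert np.allclose(T0, intercept + f1[:, None] + f2[None, :] + T)   # predictions unchanged
print(np.abs((T * w).sum(axis=1)).max(),
      np.abs((T * w).sum(axis=0)).max())          # ~0: interaction is now "pure"
```

After convergence every row and column of the interaction table has zero weighted mean, so all additively representable variance lives in the main effects and intercept.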