Open your black box classifier
There is much discussion about opening black boxes, particularly in relation to predictive models that involve machine learning [1]. Some funding schemes go as far as requiring artificial intelligence models to report their possible failure modes, as part of the requirement for technical robustness in high-stakes applications. This has led to frameworks that define testable standards for the interpretability of predictive models [2].
Interpretable models are particularly important in medicine, not least because an understanding of the factors contributing to a diagnosis can be as insightful as the quantification of the diagnostic prediction itself, but also because this level of transparency is essential for building trust in the model [3]. A typical example of good practice in explaining machine learning models in medicine is the application of Shapley values during model validation, showing “how a domain understanding of machine learning models is straightforward to establish” [4].
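As an indication of how little is involved, the sketch below computes Shapley values for a fitted classifier on a validation set with the open-source shap package; the dataset and classifier are placeholders and the exact workflow of [4] may differ.

```python
# Hedged sketch: Shapley values for a validation set with the shap package.
# The dataset and classifier stand in for a real clinical model.
import shap
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

model = GradientBoostingClassifier(random_state=0).fit(X_train, y_train)

# Attribute each validation prediction to the input attributes
explainer = shap.Explainer(model, X_train)
shap_values = explainer(X_val)
shap.plots.beeswarm(shap_values)  # global view of which attributes drive predictions
```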
The need for explanation is equally pronounced when deep learning is applied to the classification of medical images. A rigorous study of COVID-19 detection from lung CT scans showed that high performance metrics could be achieved by predictive models that focused on artifacts such as annotations in the images and even the support structure on which the patient lay during the scan. Explanation methods were central to identifying the bias in the models caused by spurious effects that happened to correlate with class membership in the data set, despite the good practice of splitting the data into three groups for model inference, optimisation and performance estimation [5]. The paper found that “very small increases in validation accuracy can correspond to drastic changes in the concepts learned by the network … it can mean overcoming a bias introduced by the artifacts.”
In many computer-based decision support applications, clinical attributes take the form of tabular data. Being so prevalent, not just in medicine but also in risk models across domains ranging from banking to insurance, this class of data deserves particular focus and is the subject of the rest of this piece. For tabular data specifically, one way around the issue of transparency is to use models that are interpretable by design [6].
Interpretability by design has long been possible with linear-in-the-parameters models and with decision trees, albeit at the expense of classification performance. Although rule-based predictors [7] and risk scores derived from logistic regression models [8] have been effective aids to decision making in clinical practice, with performance that is competitive even against modern approaches such as deep learning [9], they have significant shortcomings. To cope with non-linear dependence on clinical attributes, linear models frequently require the input variables to be discretised, for example by grouping age into bands. If the age bands are by decade, however, a 39-year-old is treated as more similar to a 30-year-old than to a 40-year-old. Discretisation masks variation within each group and can lead to considerable loss of statistical power and residual confounding [10].
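A tiny illustration of the binning problem, using pandas with illustrative decade-wide age bands:

```python
import pandas as pd

ages = pd.Series([30, 39, 40])
bands = pd.cut(ages, bins=[29, 39, 49])  # decade bands (29, 39] and (39, 49]
print(bands.tolist())
# The 39-year-old shares a band with the 30-year-old but not with the
# 40-year-old, and all variation within each band is discarded.
```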
One way to manage non-linearities with interpretable models is to fit a Generalised Additive Model (GAM), estimating the dependence on individual variables with splines [11]. This class of flexible models is in fact a gold standard for interpretability [12]. GAMs are self-explaining [13], and new formulations are emerging that do not require careful tuning of spline parameters but replace the splines with machine learning modules. In Explainable Boosting Machines [14] the modules are random forests and gradient boosted trees, whereas Neural Additive Models [15] have the structure of a self-explaining neural network. Both are bespoke models that estimate the component functions of the GAM in tandem with inferring an optimal sparse model structure. Along with linear and logistic regression, GAMs lend themselves to practical implementation in the form of nomograms, which are already familiar to clinicians for the visualisation of risk scores [16,17].
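To give a sense of how readily such models can be fitted, here is a minimal sketch of an Explainable Boosting Machine using the interpret package; the dataset and settings are illustrative rather than a recommendation, and each learned term is a component function that can be plotted and inspected on its own.

```python
# Minimal sketch: an Explainable Boosting Machine on a placeholder dataset.
from interpret import show
from interpret.glassbox import ExplainableBoostingClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

ebm = ExplainableBoostingClassifier(interactions=5)  # allow a few pairwise terms
ebm.fit(X_train, y_train)

print("held-out accuracy:", ebm.score(X_test, y_test))
show(ebm.explain_global())  # interactive plots of every learned component function
```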
But what about existing machine learning models?
A key to opening probabilistic black box classifiers without sacrificing predictive performance is an old statistical tool, Analysis of Variance (ANOVA). It is well known that ANOVA decompositions can express any function as an exact sum of functions of fewer variables, comprising main effects for individual variables together with interaction terms [18]. This is a natural way to derive additive functions of gradually increasing complexity. The derived functions are non-linear and mutually orthogonal, ensuring that the terms involving several variables do not overlap with the information contained in the simpler component functions.
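In symbols, the decomposition of a model response f of d variables takes the standard functional-ANOVA form

f(x_1, \dots, x_d) = f_0 + \sum_{i} f_i(x_i) + \sum_{i<j} f_{ij}(x_i, x_j) + \dots + f_{1 \dots d}(x_1, \dots, x_d),

where f_0 is the overall mean, the f_i are the main effects and the f_{ij} are the pairwise interactions; the mutual orthogonality comes from requiring each component to integrate to zero over each of its own variables under the chosen measure.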
All black box models generate multivariate response functions and hence can be expressed in the form of GAMs using ANOVA; for probabilistic models, the decomposition is applied to the logit of the predicted probabilities. Selecting univariate and bivariate additive terms provides interpretability. The black box is then explained by replacing the original data columns with the ANOVA terms and selecting the most informative components with an appropriate statistical model, such as the Least Absolute Shrinkage and Selection Operator (LASSO) [19].
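The recipe can be sketched in a few lines of Python. The code below is a schematic illustration rather than the published pipeline: it fits a black box classifier, builds univariate main-effect features from the logit of its predictions (averaging over the training data, as described next), and selects terms with an L1 penalty; the dataset, model and grid size are placeholders.

```python
# Schematic sketch: explain a black box by mapping it onto additive terms
# and selecting the informative ones with a LASSO-type penalty.
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

# 1. Fit the black box (a small neural network here, purely as a placeholder)
black_box = make_pipeline(
    StandardScaler(),
    MLPClassifier(hidden_layer_sizes=(64,), max_iter=2000, random_state=0),
).fit(X_tr, y_tr)

def main_effect(model, X_ref, j, grid):
    """Average the predicted logit over the training data while sweeping
    feature j through `grid` (a partial-dependence-style main effect)."""
    curve = []
    for v in grid:
        X_tmp = X_ref.copy()
        X_tmp[:, j] = v
        p = np.clip(model.predict_proba(X_tmp)[:, 1], 1e-6, 1 - 1e-6)
        curve.append(np.log(p / (1 - p)).mean())
    return np.asarray(curve)

# 2. Evaluate each main effect on a coarse grid, then map every sample's own
#    feature value onto that curve, replacing raw columns with ANOVA-style terms
grids = [np.unique(np.quantile(X_tr[:, j], np.linspace(0, 1, 20)))
         for j in range(X_tr.shape[1])]
curves = [main_effect(black_box, X_tr, j, g) for j, g in enumerate(grids)]

def to_terms(X):
    return np.column_stack([np.interp(X[:, j], grids[j], curves[j])
                            for j in range(X.shape[1])])

# 3. LASSO-style selection of the most informative additive terms
selector = LogisticRegression(penalty="l1", solver="liblinear", C=0.5)
selector.fit(to_terms(X_tr), y_tr)
print("selected terms:", np.flatnonzero(selector.coef_[0]))
print("held-out accuracy:", selector.score(to_terms(X_te), y_te))
```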
Two measures can be applied in the ANOVA, both related to the commonly used partial dependence functions. The Dirac measure corresponds to a cut across the predicted surface, while the Lebesgue measure averages over the same surface, sampled over the training data: only the variables in the argument of each component function are set, and they are swept across their full range. In practice, the main difference between the two measures is a small variation in the models that are selected. This framework is remarkably stable, showing that partial dependence functions, normally used only for visualisation, work very well for model selection and are effective for prediction.
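For the main effect of a variable x_j, writing x_{-j} for the remaining variables, x^{*} for a fixed reference point and x^{(1)}, \dots, x^{(N)} for the training samples, the two choices can be written as

f_j^{\mathrm{Dirac}}(x_j) = f(x_j, x_{-j}^{*}), \qquad f_j^{\mathrm{Lebesgue}}(x_j) = \frac{1}{N} \sum_{n=1}^{N} f(x_j, x_{-j}^{(n)}),

the latter being the familiar partial dependence function estimated from the training data.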
Once the black box has been mapped onto a GAM, the two measures yield exactly the same component functions from that point onwards. Interestingly, the Shapley additive values, already used in medicine [4], are exactly the terms in the GAM expansion [20].
A natural next step is to replicate the interpretable model derived from the black box by implementing it as a Generalised Additive Neural Network (GANN), also known as a Self-Explaining Neural Network (SENN). This ensures that the univariate and bivariate component functions can be further optimised given the selected structure. The model can be refined by a renewed application of the ANOVA decomposition, this time to separate and orthogonalise the first- and second-order terms of the GANN/SENN [20] rather than those of the original multilayer perceptron (MLP). The result is a streamlined model that is optimised for the final sparse structure. A schematic of the model inference process is shown in Fig. 1.
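A minimal PyTorch sketch of such an additive network is given below: one small subnetwork per input variable, with the outputs summed into a logit. The layer sizes, synthetic data and training loop are purely illustrative; second-order terms would be added as subnetworks taking pairs of inputs.

```python
# Minimal sketch of a generalised additive neural network (GANN/SENN-style).
import torch
import torch.nn as nn

class FeatureNet(nn.Module):
    """Learns one univariate component function f_j(x_j)."""
    def __init__(self, hidden=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1))

    def forward(self, x):
        return self.net(x)

class AdditiveNet(nn.Module):
    """Sums the per-feature components plus a bias to form the logit."""
    def __init__(self, n_features):
        super().__init__()
        self.components = nn.ModuleList([FeatureNet() for _ in range(n_features)])
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        terms = [f(x[:, j:j + 1]) for j, f in enumerate(self.components)]
        return torch.cat(terms, dim=1).sum(dim=1, keepdim=True) + self.bias

# Synthetic placeholder data: 10 features, an additive non-linear target
torch.manual_seed(0)
x = torch.randn(256, 10)
y = (x[:, 0] + x[:, 1] ** 2 > 0.5).float().unsqueeze(1)

model = AdditiveNet(n_features=10)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(200):  # short illustrative training loop
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

# Each model.components[j] can now be evaluated over a grid of x_j values
# to plot and inspect the learned component function.
```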
Second-order terms appear to be sufficient to achieve strong performance [20], no doubt due to the inherent noise in the data. Moreover, starting from a black box model, the structure and form of the interpretable model derived from it are generally very close to those of a GANN/SENN estimated de novo by re-initialising and re-training, as are the predictive performances of the two models [20].
The derived GAMs make clinically plausible predictions for real-world data and buck the performance-transparency trade-off, even against deep learning [21]. They remove one of the biggest hurdles for AI by enabling physicians and other end-users to interpret the results of the models with ease. Arguably, transparency has arrived for tabular data, setting a new benchmark for the clinical application of flexible classifiers.