Performance of Machine Learning Algorithms for Predicting Progression to Dementia in Memory Clinic Patients

Key Points
Question: Can machine learning algorithms accurately predict 2-year dementia incidence in memory clinic patients, and how do these predictions compare with existing models?
Findings: In this prognostic study of data from 15 307 memory clinic attendees without dementia, machine learning algorithms were superior in their ability to predict incident dementia within 2 years compared with 2 existing predictive models. Machine learning algorithms required only 6 variables to reach an accuracy of at least 90% and had an area under the receiver operating characteristic curve of 0.89.
Meaning: These findings suggest that machine learning algorithms could be used to accurately predict 2-year dementia risk and may form the basis for a clinical decision-making aid.

Logistic Regression (LR) is a probabilistic model, meaning it assigns a class probability to each participant.1 Probabilities are calculated using a logistic function, which maps a linear combination of each participant's variables to a value between 0 and 1 that may be viewed as a class probability. If the class probability is greater than a given "decision threshold", the participant is classified as belonging to class 1 (Dementia); participants with probabilities below the threshold are placed in class 0 (No Dementia). The decision threshold may be varied to adjust the balance between the sensitivity and specificity of the resulting classifier, as explored subsequently.
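As an illustration of threshold-dependent classification, the following sketch fits scikit-learn's LogisticRegression to synthetic data (a hypothetical stand-in for the memory clinic variables, not the study data) and varies the decision threshold:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in for the memory clinic data (hypothetical features/labels).
X, y = make_classification(n_samples=500, n_features=6, random_state=0)

model = LogisticRegression(max_iter=1000).fit(X, y)

# Class probabilities from the logistic function.
probs = model.predict_proba(X)[:, 1]

# Varying the decision threshold trades sensitivity against specificity.
for threshold in (0.3, 0.5, 0.7):
    predicted = (probs >= threshold).astype(int)
    sensitivity = (predicted[y == 1] == 1).mean()
    specificity = (predicted[y == 0] == 0).mean()
    print(f"threshold={threshold}: sens={sensitivity:.2f}, spec={specificity:.2f}")
```

Lowering the threshold labels more participants as class 1, which raises sensitivity at the cost of specificity.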
In contrast to LR, Support Vector Machine (SVM) is a non-probabilistic binary classifier.2 During training, an SVM classifier finds a boundary, or hyperplane, that spatially separates the classes; the class to which a point is assigned is determined by which side of the hyperplane it lies on. The separating hyperplane is chosen to give the largest margin, or largest separation, between the classes. Similar to LR, a score approximating the probability of class membership can be derived as a function of a point's distance from the hyperplane. Using a linear SVM allowed variable importance to be evaluated. Random Forest (RF) and Gradient-Boosted Trees (XGB) are examples of ensemble learning algorithms in which the underlying algorithm is a decision tree. During training, an RF selects a random sample of the training data with replacement and fits a decision tree to this sample.3,4 This process is repeated many times, the exact number being a parameter of the algorithm, to create an ensemble (or forest) of decision trees. XGB differs from RF by training decision trees sequentially, such that each new tree is trained to correct the errors of the previously trained tree.5 At the validation stage, the probability of a participant belonging to either class is determined by averaging the outcomes obtained from applying each individual decision tree to the participant.
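A minimal sketch of these model families on synthetic data, using scikit-learn throughout; note that GradientBoostingClassifier here is only a stand-in for the XGBoost library used in the study, chosen to keep the example self-contained:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.svm import SVC

X, y = make_classification(n_samples=400, n_features=6, random_state=1)

# Linear SVM: decision_function gives the signed distance from the
# separating hyperplane, from which a class-membership score can be derived.
svm = SVC(kernel="linear").fit(X, y)
distances = svm.decision_function(X)

# Random forest: n_estimators trees, each fit to a bootstrap sample;
# predicted probabilities are averages over the ensemble.
rf = RandomForestClassifier(n_estimators=100, random_state=1).fit(X, y)
rf_probs = rf.predict_proba(X)[:, 1]

# Sequentially trained boosted trees (stand-in for XGBoost here).
gb = GradientBoostingClassifier(n_estimators=100, random_state=1).fit(X, y)
```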
To apply the machine learning algorithms to our data, we used one-hot encoding of categorical variables, creating a new binary variable for each of the categories. We scaled the data such that each variable had a mean of zero and variance of one.
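The preprocessing described above can be sketched with scikit-learn; the category labels and ages below are hypothetical, not taken from the study data:

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Hypothetical categorical variable with three categories.
categories = np.array([["MCI"], ["SCD"], ["MCI"], ["Other"]])
# One-hot encoding: one new binary column per category.
one_hot = OneHotEncoder().fit_transform(categories).toarray()

# Standardise a continuous variable to zero mean and unit variance.
ages = np.array([[61.0], [74.0], [68.0], [80.0]])
scaled = StandardScaler().fit_transform(ages)
```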
To perform the 5-fold cross validation we used the 'StratifiedKFold' function in scikit-learn,6 carrying out a parameter search for the models within each fold using 'GridSearchCV', which performs an exhaustive search of a parameter grid for each model. The best set of parameters was used to specify each model, for each fold, resulting in up to 20 different models with different parameters. All code used for the modelling process is available online.7
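A sketch of this nested procedure for a single model, assuming a hypothetical parameter grid for LR (the actual grids used in the study are in the published code):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold

X, y = make_classification(n_samples=300, n_features=6, random_state=2)

outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=2)
param_grid = {"C": [0.01, 0.1, 1.0, 10.0]}  # hypothetical grid

best_params = []
for train_idx, test_idx in outer_cv.split(X, y):
    # Exhaustive search over the grid, within each outer fold.
    search = GridSearchCV(LogisticRegression(max_iter=1000), param_grid, cv=3)
    search.fit(X[train_idx], y[train_idx])
    best_params.append(search.best_params_)  # may differ from fold to fold
```

With 4 algorithms and 5 folds, this yields the "up to 20 different models" described above.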

Model evaluation
Performance measures were obtained by bootstrapping the data. Specifically, we selected a random sample of 1000 patients, with replacement, and calculated performance measures in this sample using the predicted classes and class probabilities obtained during 5-fold cross validation. The sampling was repeated 100 times to obtain a distribution of values for each measure. The standard deviation of the distribution of values of a measure is quoted as the error.
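The bootstrap procedure can be sketched as follows, using the AUC as an example measure; the out-of-fold labels and probabilities below are synthetic, generated only to make the sketch runnable:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
# Synthetic stand-ins for the labels and class probabilities
# obtained during 5-fold cross validation.
y_true = rng.integers(0, 2, size=2000)
y_prob = np.clip(y_true * 0.3 + rng.normal(0.35, 0.25, size=2000), 0, 1)

aucs = []
for _ in range(100):  # repeat the sampling 100 times
    # Draw a random sample of 1000 patients, with replacement.
    idx = rng.integers(0, len(y_true), size=1000)
    aucs.append(roc_auc_score(y_true[idx], y_prob[idx]))

# The SD of the bootstrap distribution is quoted as the error.
mean_auc, err = np.mean(aucs), np.std(aucs)
```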

Variable importance
To assess variable importance for each model we used the coef_ (LR and SVM) and feature_importances_ (RF and XGB) attributes in scikit-learn.6 For LR and SVM, the importance of a variable is determined by the magnitude of its coefficient when the model is fit to the training data; the larger the magnitude of the coefficient, the more important the variable is to the prediction. For RF and XGB, variable importance is determined by the Gini importance.8 Specifically, for each tree a variable's importance is the total decrease in impurity that occurs when the variable is used to split a node, weighted by the number of samples the node splits. The Gini importance is calculated for each tree and then averaged to give a final variable importance.
During the 5-fold cross validation, we determined variable importance for each fold. The final importance of each variable was calculated by averaging over the folds.
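A sketch of the per-fold importance calculation on synthetic data, taking LR and RF as representatives of the coefficient-based and Gini-based approaches:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

X, y = make_classification(n_samples=400, n_features=6, random_state=3)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=3)
lr_folds, rf_folds = [], []
for train_idx, _ in cv.split(X, y):
    # LR: importance is the magnitude of each fitted coefficient.
    lr = LogisticRegression(max_iter=1000).fit(X[train_idx], y[train_idx])
    lr_folds.append(np.abs(lr.coef_[0]))
    # RF: Gini importance, averaged over the trees in the forest.
    rf = RandomForestClassifier(n_estimators=50, random_state=3)
    rf_folds.append(rf.fit(X[train_idx], y[train_idx]).feature_importances_)

# Final importance of each variable: average over the 5 folds.
lr_importance = np.mean(lr_folds, axis=0)
rf_importance = np.mean(rf_folds, axis=0)
```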
© 2021 James C et al. JAMA Network Open.

Diagnostic Stability
We define a reversion as occurring when a participant who was diagnosed with dementia up to 2 years after their first memory clinic visit subsequently receives a diagnosis of no dementia (either MCI or unimpaired cognition) within 2 years of their initial dementia diagnosis.
To investigate the classification accuracy of ML models in participants with reversion, we removed all such participants from the training data in each fold. This is justified by the definition of a reversion; in the data these participants are labelled as 1 (having incident dementia), yet when their subsequent diagnosis of no dementia is taken into account this diagnosis of dementia, and therefore their label, is incorrect. By removing participants with reversion from the training data we ensured that the ML models were trained on, to our knowledge, correctly labelled data only.
We subsequently re-trained the ML models to perform the same classification task, identifying whether a participant would develop dementia within 2 years of their baseline assessment, without reversions in the training data. We assessed each model's ability to identify participants with reversion by examining the labels assigned to these participants. If a participant with a reversion is classified as dementia free (class 0), they are identified: the ML model has labelled them correctly, rather than misclassifying them as the original data labels do.
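The fold-wise handling of reversions can be sketched as follows; the label and reversion arrays here are hypothetical, standing in for the diagnostic records:

```python
import numpy as np

# Hypothetical arrays: labels (1 = incident dementia) and a boolean flag
# marking participants whose dementia diagnosis later reverted.
labels = np.array([1, 0, 1, 1, 0, 1])
reverted = np.array([False, False, True, False, False, True])

def split_fold(train_idx, test_idx):
    # Drop reversions from the training fold only; keep them in the
    # test fold so the labels the model assigns them can be inspected.
    clean_train = train_idx[~reverted[train_idx]]
    return clean_train, test_idx

train_idx = np.array([0, 1, 2, 3])
test_idx = np.array([4, 5])
clean_train, test = split_fold(train_idx, test_idx)
# A reversion in the test fold predicted as class 0 counts as identified.
```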