Synthetically Rebalancing Healthcare Datasets via Conditional DDPM

This blog post was written by Keira Behal, Jiayi Chen, Caleb Fikes, and Sophia Xiao and published with minor edits. Our work was guided by Dr. Yuanzhe Xi. In addition to this post, the team has given a midterm presentation, produced a poster blitz video and a poster, written code, and written a paper.

Background

In recent years, machine learning algorithms have become increasingly important in healthcare for tasks like disease prediction, diagnostics, and treatment optimization. However, these algorithms can perpetuate harmful societal biases if the training data contains inherent biases or underrepresentation of certain groups, particularly minority communities. This bias issue is a significant concern in healthcare, where fair and equitable treatment is crucial. Even valuable data sources like Electronic Health Records (EHR), which provide a comprehensive overview of a patient’s health history, including diagnoses, treatments, and demographics, can suffer from underrepresentation of certain racial or ethnic minorities. This imbalance can lead to inequitable health outcomes, with minority groups receiving less accurate diagnoses or treatment recommendations due to their underrepresentation in the training data.

Our Approach

To address this challenge, we propose Minority Class Rebalancing through Augmented Data Generation (MCRAGE), a novel approach to augment imbalanced medical datasets using samples generated by a deep generative model. The MCRAGE process involves training a Conditional Denoising Diffusion Probabilistic Model (CDDPM) capable of generating high-quality synthetic EHR samples from underrepresented classes. We use this synthetic data to augment the imbalanced dataset, achieving a more balanced distribution across all classes. The rebalanced dataset can then be used to train a machine learning model that is less biased against underrepresented groups. Our work intends to promote fair and accurate healthcare predictions, enhancing patient care and supporting equity through data science and artificial intelligence.

In our project, we focused on data-driven methods of synthetic sample generation through deep generative models. These models aim to capture the underlying data distribution, so they can produce high-fidelity synthetic samples that retain the characteristics and variability of the original data. We chose a CDDPM to generate synthetic samples to augment our dataset for the following three reasons:

  1. Stability of Training: DDPMs are easier to train than GANs (Generative Adversarial Networks), which are more susceptible to mode collapse and vanishing gradients.
  2. High Fidelity Generation: Diffusion models can produce higher quality samples than other generative models. Many commercial implementations of diffusion models such as Stability AI’s Stable Diffusion and OpenAI’s DALL-E 2 have garnered significant funding and public interest.
  3. Class-specific Generation: The CDDPM offers control over the generation of specific classes of data, whereas a conventional DDPM is limited to generating samples that reflect the class distribution of its training set. With a CDDPM, we can use the knowledge gained from the majority groups to improve the quality of generated data for the minority groups.

The MCRAGE process is both intuitive and theoretically justifiable. The algorithm results in a synthetically rebalanced training set where each “intersectional group,” or a unique combination of sensitive demographic factors, is equally represented. By generating an artificial stratified random sample, the process promotes statistical parity, meaning that classifiers trained on this data should have the same distribution of decisions for each group. The goal is for the classifier’s performance to be nearly equivalent for each subgroup despite class imbalances in the original training data. The MCRAGE process is proposed under the assumption that such a result holds. In that case, the process results in the nearest approximation to a stratified sample given all the information in the training set. This project aims to empirically test the performance of MCRAGE compared to previous state-of-the-art methods for mitigating dataset imbalance.
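
As a rough illustration of the rebalancing step and of statistical parity, the sketch below counts how many synthetic samples each intersectional group would need to match the largest group, and measures the gap in positive-decision rates between groups. The function names, the group labels, and the choice to match the largest group are assumptions made for this example, not details taken from our implementation.

```python
from collections import Counter


def synthetic_quota(groups):
    """Number of synthetic samples each group needs to match the largest group."""
    counts = Counter(groups)
    target = max(counts.values())  # assumption: rebalance up to the largest group
    return {group: target - n for group, n in counts.items()}


def parity_gap(groups, decisions):
    """Largest difference in positive-decision rate between any two groups."""
    totals, positives = Counter(), Counter()
    for group, decision in zip(groups, decisions):
        totals[group] += 1
        positives[group] += int(decision == 1)
    rates = [positives[group] / totals[group] for group in totals]
    return max(rates) - min(rates)


# Hypothetical toy data: each label stands for one intersectional group.
groups = ["A"] * 6 + ["B"] * 3 + ["C"] * 1
quota = synthetic_quota(groups)
print(quota)
```

A parity gap of zero would mean that every group receives positive decisions at the same rate; in practice one would look for a small gap rather than an exact zero.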

Our Contribution

For conditional generation in the diffusion model, we use classifier-free guidance, which modifies the denoising update by incorporating class information. Unlike the classifier guidance method, which relies on a separate classifier for conditional sample generation, classifier-free guidance integrates the class information directly into the model. In our implementation, we added an extra class embedding to the time embedding in the conventional DDPM, so the denoising network receives the class label alongside the timestep at every step.
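
The conditioning described above can be sketched in a few lines. This is a minimal, dependency-free illustration rather than our actual model: the sinusoidal time embedding, the randomly initialized class table (standing in for learned embeddings), the use of `None` for the unconditional case, and the guidance weight `w` are all assumptions. The `guided_noise` function follows the standard classifier-free guidance combination of conditional and unconditional noise predictions.

```python
import math
import random


def time_embedding(t, dim):
    """Sinusoidal embedding of diffusion timestep t (dim is assumed even)."""
    half = dim // 2
    freqs = [math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


def make_class_table(num_classes, dim, seed=0):
    """Stand-in for a learned class-embedding table (random values here)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_classes)]


def conditioning(t, label, table, dim):
    """Time embedding plus class embedding; label=None means unconditional."""
    emb = time_embedding(t, dim)
    if label is None:
        return emb
    return [a + b for a, b in zip(emb, table[label])]


def guided_noise(eps_cond, eps_uncond, w):
    """Classifier-free guidance: (1 + w) * eps_cond - w * eps_uncond."""
    return [(1 + w) * c - w * u for c, u in zip(eps_cond, eps_uncond)]


table = make_class_table(num_classes=2, dim=8)
cond = conditioning(t=5, label=1, table=table, dim=8)
```

With `w = 0` the guided prediction reduces to the purely conditional one; larger values push samples further toward the requested class.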

Our Results

While developing our algorithm, we used a subset of MNIST to track our progress. Finally, we applied the algorithm to a tabular EHR dataset called Patient Treatment Classification with a random forest classifier. The dataset consists of Electronic Health Records collected from a private hospital in Indonesia. It contains eight laboratory test results for 3309 patients, used to determine patient treatment classification (inpatient care or outpatient care). In both cases, we compared the synthetically balanced datasets of our approach to the balanced dataset and those obtained from the SMOTE algorithm. We also compared the approaches quantitatively using several fitness metrics and found that the classifier trained on data augmented by the DDPM performed the best.

These promising results motivate us to conduct the same procedure on a more extensive EHR dataset with more features and multiple classes, possibly incorporating a multinomial diffusion.
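
For readers unfamiliar with the SMOTE baseline mentioned above, the following is a simplified sketch of its core idea: each synthetic point is placed at a random position on the line segment between a minority sample and one of its nearest minority neighbors. The function name `smote_like`, the parameter defaults, and the toy points are assumptions for illustration; this is not the implementation we used in our experiments.

```python
import math
import random


def smote_like(minority, n_new, k=3, seed=0):
    """Generate n_new points by interpolating between minority-class neighbors.

    Assumes minority holds at least two points, each a list of floats.
    """
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        x = rng.choice(minority)
        # k nearest minority neighbors of x, excluding x itself
        others = [p for p in minority if p is not x]
        neighbors = sorted(others, key=lambda p: math.dist(x, p))[:k]
        neighbor = rng.choice(neighbors)
        u = rng.random()  # interpolation factor in [0, 1)
        synthetic.append([a + u * (b - a) for a, b in zip(x, neighbor)])
    return synthetic


# Hypothetical minority-class samples with two features each.
points = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
new_points = smote_like(points, n_new=10)
```

Because every synthetic point lies between two existing samples, this kind of interpolation cannot leave the region spanned by the minority data, whereas a generative model samples from a learned distribution.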

Interested in Learning More?

Please see our poster and paper for more details.