Generative models improve fairness of medical classifiers under distribution shifts – Nature.com

Our research complies with all relevant ethical regulations. We only repurposed existing assets and datasets and did not collect new assets for the purposes of our study, beyond annotations by dermatology experts for the generated images. The non-accessible data used in the study can be used for research purposes without further scrutiny or collection of consent from the source individuals.

In this section, we describe the datasets we used to train the downstream classifiers and diffusion models across the different modalities and medical contexts. Three different datasets were used, all of which are de-identified; informed consent was obtained from the participants in the original studies that collected these data.

We used data from the CAMELYON17 challenge21, which include labeled and unlabeled data from three different hospitals for training, as well as one in-distribution and one OOD validation hospital. Data from the different hospitals differ because of the staining procedure used. The task was to detect the presence of breast cancer metastases in the images, which are patches of whole-slide images of histological lymph node sections. The number of samples per hospital is given in Extended Data Table 1a; all subsets were approximately evenly split into patches containing tumors and patches without tumors. We used the training data (302,436 examples) and the unlabeled data (1.8 million examples) to train the diffusion model. We performed patch-based instead of whole-slide classification to align with the WILDS challenge22 and follow-up works that evaluated methods on the same setup.

In terms of label distribution, the training set contained 151,046 patches of healthy tissue and 151,390 patches of cancerous tissue. For the ID (validation) dataset, these counts were 16,952 and 16,608, respectively, while the OOD (validation) and OOD (test) splits contained 17,452 and 42,527 patches per class, respectively (that is, both OOD datasets were perfectly balanced).
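As a quick arithmetic check of the split sizes quoted above, the per-class counts can be summed and the class balance computed for each split. `SPLITS` and `split_summary` are hypothetical names introduced for this illustration; the numbers are copied from the text.

```python
# Per-class patch counts for each CAMELYON17 split, copied from the text:
# (healthy, cancerous). For the OOD splits the text gives one count per class.
SPLITS = {
    "train": (151_046, 151_390),
    "id_val": (16_952, 16_608),
    "ood_val": (17_452, 17_452),
    "ood_test": (42_527, 42_527),
}


def split_summary(splits):
    """Hypothetical helper: return {split: (total patches, fraction cancerous)}."""
    summary = {}
    for name, (healthy, cancerous) in splits.items():
        total = healthy + cancerous
        summary[name] = (total, cancerous / total)
    return summary


for name, (total, frac) in split_summary(SPLITS).items():
    print(f"{name:>8}: {total:>7,} patches, {frac:.4f} cancerous")
```

The training total reproduces the 302,436 examples stated earlier, and the two OOD splits come out at exactly 0.5 cancerous.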

We trained the cascaded diffusion model and the downstream discriminative model on a total of 201,055 samples from the CheXpert database23, of which 119,352 were annotated as male and 81,703 as female (the dataset only contained binary gender labels). We show the age and original label distributions in Extended Data Fig. 3a,b. The original CheXpert training set contained positive, negative, uncertain and unmentioned labels. Uncertain samples were not used to train the diagnostic model, but they were used to train the diffusion model. Unmentioned labels were treated as negatives (that is, the condition was not present), which yielded a highly imbalanced dataset. The National Institutes of Health (NIH) evaluation dataset24, denoted as OOD, consisted of 17,723 individuals, of whom 10,228 were male and 7,495 were female.
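The label handling described above can be sketched as a mapping from raw CheXpert label values to binary training targets. This is a minimal sketch, assuming the released CSV encoding (1.0 positive, 0.0 negative, −1.0 uncertain, blank unmentioned) and assuming uncertain labels are masked per condition rather than dropping the whole image; `diagnostic_target` is a hypothetical helper, and the diffusion-model data path (which keeps uncertain samples) is not shown.

```python
import math
from typing import Optional

# Assumed encoding of the released CheXpert label columns:
# 1.0 = positive, 0.0 = negative, -1.0 = uncertain, blank (NaN/None) = unmentioned.
UNCERTAIN = -1.0


def diagnostic_target(raw: Optional[float]) -> Optional[int]:
    """Hypothetical helper: binary target for the diagnostic model, or None to exclude."""
    if raw is None or math.isnan(raw):
        return 0  # unmentioned: treated as negative (condition absent)
    if raw == UNCERTAIN:
        return None  # uncertain: excluded from the diagnostic loss
    return 1 if raw == 1.0 else 0


raw_labels = [1.0, 0.0, -1.0, float("nan"), None]
print([diagnostic_target(r) for r in raw_labels])  # [1, 0, None, 0, 0]
```

Mapping unmentioned labels to 0 is what makes the resulting targets highly imbalanced, since most conditions are unmentioned in most reports.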

Extended Data Fig. 3c,d illustrates how often different conditions co-occurred in the training and evaluation samples. Capturing the characteristics of a single condition can be challenging because conditions frequently coexist in a single case. A typical example is pleural effusion, which was also present in approximately 50% of the cases diagnosed with atelectasis, consolidation or edema. The picture was different for the OOD ChestX-ray14 dataset, where the corresponding ratio was much lower for most pairs of conditions.
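One way to compute such a co-occurrence ratio is the fraction of cases carrying one condition that also carry another. This is a sketch under that assumption (the exact normalization behind Extended Data Fig. 3c,d is not stated here); `cooccurrence_ratio` and the toy cases are hypothetical.

```python
from typing import List, Set


def cooccurrence_ratio(cases: List[Set[str]], given: str, other: str) -> float:
    """Hypothetical helper: fraction of cases with `given` that also have `other`."""
    with_given = [case for case in cases if given in case]
    if not with_given:
        return float("nan")  # undefined when `given` never occurs
    return sum(other in case for case in with_given) / len(with_given)


# Toy cases (hypothetical): one set of findings per study.
cases = [
    {"atelectasis", "pleural_effusion"},
    {"atelectasis"},
    {"edema", "pleural_effusion"},
    {"edema", "pleural_effusion"},
    {"consolidation"},
]
print(cooccurrence_ratio(cases, "atelectasis", "pleural_effusion"))  # 0.5
print(cooccurrence_ratio(cases, "edema", "pleural_effusion"))  # 1.0
```

The ratio is asymmetric: the fraction of effusion cases that also show atelectasis (1/3 here) differs from the reverse (1/2), so a co-occurrence matrix built this way is not symmetric.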

The imaging samples in the dermatology datasets were often accompanied by metadata that include attributes such as biological sex, age and skin tone. Skin tone was labeled according to the Fitzpatrick scale, giving rise to six categories (plus unknown). The ground truth condition labels resulted from aggregating clinical assessments by multiple experts, each of whom provided a list of top-3 conditions along with a confidence score (between 1 and 5). A weighted aggregate of these labels gave rise to the soft labels that we used to train the generative and diagnostic models. The dermatology datasets were characterized by complex shifts with respect to each other, as the label distribution, demographic distribution and capture process may all vary across them. To demonstrate the severity of the prevalence shift across locations, we visualized the distribution of conditions in the evaluation datasets in Extended Data Fig. 4.
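The exact weighting used to aggregate the experts' top-3 lists is not given here, so the following is a minimal sketch under a hypothetical scheme: each listed condition receives its confidence score as a vote, votes are summed across raters and normalized into a soft label. `soft_label` is a hypothetical helper, and conditions are assumed to be already mapped to the model's classes.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# One rater's assessment: up to three (condition, confidence) pairs, confidence in 1..5.
Assessment = Sequence[Tuple[str, int]]


def soft_label(assessments: Sequence[Assessment], classes: Sequence[str]) -> List[float]:
    """Hypothetical confidence-weighted vote over raters, normalized to sum to 1."""
    scores: Dict[str, float] = defaultdict(float)
    for assessment in assessments:
        for condition, confidence in assessment[:3]:
            if not 1 <= confidence <= 5:
                raise ValueError(f"confidence must be in 1..5, got {confidence}")
            scores[condition] += confidence
    total = sum(scores[c] for c in classes)
    if total == 0:
        return [0.0] * len(classes)
    return [scores[c] / total for c in classes]


raters = [
    [("eczema", 4), ("psoriasis", 2)],
    [("eczema", 5), ("acne", 1)],
]
print(soft_label(raters, ["eczema", "psoriasis", "acne"]))  # approx. [0.75, 0.167, 0.083]
```

Other weightings (for example, by rank in the top-3 list) fit the same interface; only the increment inside the inner loop changes.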

To disentangle the effect of each of those shifts, we artificially skewed the source dataset along three sensitive attribute axes: sex, skin tone and age. Skewing the dataset allowed us to understand which methods performed better as the distribution shifts became more severe. For example, if our original dataset was skewed toward younger age groups, conditioning the generative model on age and (over)sampling from older ages could potentially help close the performance gap between younger and older populations. To study this aspect, we could not simply rebalance our datasets, because we had too few samples from the long tail of the distribution with regard to the label or sensitive attribute. Instead, we skewed the labeled training dataset to make it progressively more biased (by removing instances from the least represented subgroups) and investigated how performance suffered as a result. For each sensitive attribute, we created new versions of the in-distribution dataset that were progressively more skewed toward the high-data regions. We show how the resulting training dataset was skewed with respect to each of the sensitive attributes in Extended Data Table 1b–d. We also report similar demographic statistics for the three evaluation datasets in Extended Data Table 1e–g. The cascaded diffusion model was always trained on the union of the labeled training data and all unlabeled data across the three available domains. For consistency, the discriminative model was always evaluated on the same three evaluation datasets (one in-distribution held-out dataset and two OOD datasets).
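Progressive skewing can be sketched as keeping every example of the largest subgroup and only a fraction of each smaller subgroup. This is a simplification (it shrinks all non-majority subgroups by the same factor, whereas the actual per-subgroup counts are those in Extended Data Table 1b–d); `skew_indices` and the age bins are hypothetical.

```python
import random
from collections import Counter
from typing import List, Sequence


def skew_indices(attrs: Sequence[str], keep_frac: float, seed: int = 0) -> List[int]:
    """Hypothetical helper: keep the largest subgroup whole, `keep_frac` of the others."""
    if not 0.0 <= keep_frac <= 1.0:
        raise ValueError("keep_frac must lie in [0, 1]")
    counts = Counter(attrs)
    majority = counts.most_common(1)[0][0]
    rng = random.Random(seed)
    kept: List[int] = []
    for group in counts:
        idx = [i for i, a in enumerate(attrs) if a == group]
        if group != majority:
            rng.shuffle(idx)  # same seed -> same permutation -> nested subsets
            idx = idx[: round(keep_frac * len(idx))]
        kept.extend(idx)
    return sorted(kept)


attrs = ["18-39"] * 60 + ["40-59"] * 30 + ["60+"] * 10
for frac in (1.0, 0.5, 0.2):  # subgroup sizes 60/30/10, 60/15/5, 60/6/2
    subset = skew_indices(attrs, frac)
    print(frac, Counter(attrs[i] for i in subset))
```

With a fixed seed, each more skewed subset is contained in the less skewed one, so successive settings differ only by the removed minority examples.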

Generative models, especially generative adversarial networks (GANs)29, have been used by several studies to improve performance in different medical imaging tasks30,31,32,33,34 and, in particular, for underrepresented conditions35. Data obtained by exploring different latent image attributes through a generative model have also been shown to improve the adversarial robustness of image classifiers36. In the clinical setting, GANs have been used to improve performance in tasks such as disease diagnosis in scenarios where few labeled samples were available. Such models have been used to augment medical images for liver lesion classification30, classification of diabetic retinopathy from fundus images31 and breast mass diagnosis in mammography32. In dermoscopic imaging, ref. 33 introduced a progressive generative model to produce realistic high-resolution synthetic images, while ref. 34 focused on improving balanced multiclass accuracy and, in particular, sensitivity for high-risk underrepresented diagnostic labels such as melanoma37. Ref. 35 followed a similar approach for chest X-rays, combining real and synthetic images generated with GANs to improve classifier accuracy for rare diseases. Other work used conditional image generation to disentangle image content and image style information in scenarios where the conditioning vector was not always available, applying the method to dermoscopic images (HAM10000 dataset) corresponding to seven types of skin lesions and to lung computed tomography scans from the Lung Image Database Consortium-Image Database Resource Initiative.

Apart from whole-image downstream tasks, GAN-based augmentation techniques have been used to improve performance on pixel-wise classification tasks, for example, vessel contour segmentation on fundus images38 and brain lesion segmentation39. Given that pixel-wise downstream tasks were not within the scope of our study, we refer the reader to a more thorough review of GAN-based methods in medical image augmentation by Chen et al.40; Bissoto et al.41, in turn, provide an overview of GAN-based augmentation techniques with a main focus on skin lesion augmentation and anonymization.

Despite the wide variety of health applications that have adopted GAN-based generative models to produce learned augmentations of images, these models are often characterized by limited diversity and quality42. More recently, denoising diffusion probabilistic models (DDPMs)19,20,43,44,45 have shown outstanding performance in image generation tasks and have been probed for medical knowledge by Kather et al.46 in different medical domains. Other works extended diffusion models to three-dimensional magnetic resonance and computed tomography images47 and demonstrated that they can be conditioned on text prompts for chest X-ray generation48. Given the ethical questions around the use of synthetic images in medicine and healthcare46,49, it is important to distinguish between using generative models to augment the original training dataset and replacing real images with synthetic ones, especially in the absence of privacy guarantees. None of these works claimed that the latter would be preferable; rather, synthetic data came to the rescue when obtaining more abundant real data was either expensive or not feasible (for example, in the case of rare conditions), even if this solution is not a panacea42. While some studies view generative models as a means of replacing real data with anonymized synthetic data, we abstain from such claims because greater care needs to be taken to ensure that generative models are trained with privacy guarantees, as shown by Carlini et al.50 and Somepalli et al.51.

Many scholars have recently scrutinized ML systems and surfaced different types of bias that emerge along the ML pipeline, including problems due to data acquisition protocols, flawed human decision-making, missing features and label scarcity52. They identified and characterized various biases that can emerge during model development and be exacerbated during model deployment and in clinical interactions, and they argued that ensuring fairness in those contexts is essential to advance health equity. The literature discussed below was motivated by the realization that, if the performance of automated systems that rely on ML algorithms (for example, in computer vision or judicial systems) is broken down by certain demographic or socioeconomic traits, there can be vast discrepancies in predictive accuracy across these subgroups. This is alarming for applications that affect human lives and is particularly concerning in the context of computer-aided diagnosis and clinical decision-making.

One of the first studies to examine the effect of training data composition on model performance across the sexes when using chest X-rays to diagnose thoracic diseases was led by Larrazabal et al.12. They found that the prevalence of a particular sex in the training set is directly linked to the predictive accuracy of the model for the same group at test time. In other words, a model trained on a set highly skewed toward female patients would demonstrate higher accuracy for female patients at test time than a counterpart trained on a male-dominated set of images. Even though this finding might not come as a surprise, one would expect an ML model used in clinical practice across geographical locations to be robust to demographic shifts of this kind. In a similar vein, Seyyed-Kalantari et al.13 further explored how differences in age, race or ethnicity, and insurance type (as a proxy of socioeconomic status) manifest in the performance of a classifier operating on chest radiographs. A crucial finding was that the algorithm exhibited a higher false positive rate for the 'no finding' label, that is, it underdiagnosed ethnic minorities. These effects were compounded for intersectional identities (for example, the false positive rate was higher for Black female patients than for Black male patients). Similar findings were reported by Puyol-Antón et al.53 in a cardiac segmentation task with respect to sex and racial biases, and by Gianfrancesco et al.54 in a different modality (electronic health records) for patients with low socioeconomic status.

The method is illustrated in Fig. 1b and leverages diffusion models to learn augmentations of the data. The approach consists of three main steps: (1) we trained a generative model on the available labeled and unlabeled data; (2) we sampled from the generative model according to a sampling strategy; and (3) we enriched our original training dataset from the source (also called in-distribution) domain with the synthetic images sampled from the generative model and trained a diagnostic model (potentially for multiple labels, if more than one condition can be present at once). We treated the mixing ratio between real and synthetic data as a hyperparameter in all three settings and selected the best value based on model performance on the validation set. We provide more specific details about the experimental setting for each modality in the following section and the pseudocode for our method in Fig. 1a.

Algorithm 1: pseudocode of proposed method

Input: modality

if modality == "histopathology" then

Num_labels ← 2

A ∈ {"hospital_id"}

else if modality == "radiology" then

Num_labels ← 5

A ∈ {"sex", "race"}

else if modality == "dermatology" then

Num_labels ← 27

A ∈ {"sex", "age", "skin_tone"}

end if

Input: $X \in \mathbb{R}^{\mathrm{Batch} \times \mathrm{Height} \times \mathrm{Width} \times \mathrm{Channels}}$, $Y \in \mathbb{R}^{\mathrm{Batch} \times \mathrm{Num\_labels}}$

Train diffusion model $\hat{p} \sim \mathrm{DDPM}(X, Y, A)$

if modality ∈ {"radiology", "dermatology"} then

Train upsampler diffusion model $\hat{p}_{\mathrm{upsampler}} \sim \mathrm{DDPM}(X, Y, A)$

end if

Sample $X'$ from $\hat{p}$ (followed by $\hat{p}_{\mathrm{upsampler}}$, if trained) according to a fair distribution $\hat{p}(Y, A)$

We assume: $\hat{p}(A) \sim \mathrm{uniform}$, $\hat{p}(Y) = p(Y)$

Output: $X' \in \mathbb{R}^{\mathrm{Samples} \times \mathrm{Height} \times \mathrm{Width} \times \mathrm{Channels}}$, $Y' \in \mathbb{R}^{\mathrm{Samples} \times \mathrm{Num\_labels}}$ synthetic samples

Train diagnostic model $d(Y \mid X) = \mathrm{ResNet}(X)$ with mixing ratio $\alpha$, drawing each training example $(x_d, y_d)$ as follows:

Sample random number $rng \in [0, 1]$

if $rng < \alpha$ then

$(x_d, y_d) \in (X, Y)$

else

$(x_d, y_d) \in (X', Y')$

end if
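The final loop of Algorithm 1 (equivalently, the mixture $p'$ of equation (2)) can be sketched as an endless stream that draws each training example from the real set with probability $\alpha$ and from the synthetic set otherwise. This is a minimal sketch, assuming both sets fit in memory as Python sequences and that examples are drawn uniformly with replacement; `mixed_stream` is a hypothetical helper, not the authors' code.

```python
import itertools
import random
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def mixed_stream(real: Sequence[T], synthetic: Sequence[T],
                 alpha: float, seed: int = 0) -> Iterator[T]:
    """Hypothetical helper: yield from `real` with probability alpha, else `synthetic`."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    rng = random.Random(seed)
    while True:
        # rng.random() lies in [0, 1), so alpha = 1 always picks real data
        # and alpha = 0 always picks synthetic data.
        pool = real if rng.random() < alpha else synthetic
        yield pool[rng.randrange(len(pool))]


real = [("real", i) for i in range(100)]
synthetic = [("synthetic", i) for i in range(300)]
batch = list(itertools.islice(mixed_stream(real, synthetic, alpha=0.25), 10_000))
print(sum(src == "real" for src, _ in batch) / len(batch))  # close to 0.25
```

In terms of the settings reported in the following paragraphs, the 50:50 histopathology mixture corresponds to $\alpha = 0.5$, the dermatology setting (25% real) to $\alpha = 0.25$ and the purely synthetic radiology setting to $\alpha = 0$.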

For histopathology, we trained a diffusion model to generate images at 96×96 resolution, the lowest resolution among the three imaging modalities. The data used to train the diffusion model consisted of labeled and unlabeled data only from the in-distribution hospitals. To condition the diffusion model, we used either the diagnostic label alone (that is, cancer or no cancer) or the diagnostic label and hospital ID together. For the unlabeled data, which did not contain the diagnostic label, we padded the corresponding conditioning vector with zeros. We then sampled from the diffusion model assuming a uniform distribution across hospital IDs while preserving the diagnostic label distribution. The synthetic-to-real data ratio used in histopathology was 50:50, meaning that 50% of the total training samples corresponded to real patches and 50% to synthetic samples from the diffusion model. For the diagnostic model, we focused on a patch-based classification setup instead of whole-slide image classification. Both experimental design decisions, that is, the image resolution and the classification setup, were made to align with the WILDS challenge22 and the wealth of literature that evaluates ML methods on in-the-wild distribution shifts using the same setting55. We evaluated on the held-out in-distribution and OOD hospitals (results shown in Fig. 2).
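The zero-padded conditioning described above can be sketched as a one-hot diagnostic label, optionally concatenated with a one-hot hospital ID, where a missing label leaves its block all zeros. This is a sketch under stated assumptions: one-hot encoding, concatenation, label index 1 meaning tumor and three training hospitals are all illustrative choices, and `conditioning_vector` is a hypothetical helper.

```python
from typing import List, Optional


def conditioning_vector(label: Optional[int], hospital_id: Optional[int] = None,
                        num_labels: int = 2, num_hospitals: int = 3,
                        with_hospital: bool = True) -> List[float]:
    """Hypothetical helper: one-hot label, optionally followed by a one-hot hospital ID.

    A missing label (unlabeled patch) or missing hospital ID leaves its block all zeros.
    """
    label_block = [0.0] * num_labels
    if label is not None:
        label_block[label] = 1.0
    if not with_hospital:
        return label_block
    hospital_block = [0.0] * num_hospitals
    if hospital_id is not None:
        hospital_block[hospital_id] = 1.0
    return label_block + hospital_block


print(conditioning_vector(1, 0))  # [0.0, 1.0, 1.0, 0.0, 0.0]: tumor patch, hospital 0
print(conditioning_vector(None, 2))  # [0.0, 0.0, 0.0, 0.0, 1.0]: unlabeled, hospital 2
print(conditioning_vector(None, None, with_hospital=False))  # [0.0, 0.0]
```

Keeping the hospital block for unlabeled patches lets the model still learn hospital-specific staining from the much larger unlabeled pool.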

For chest radiology, we trained two diffusion models (one generating images at 64×64 resolution and one upsampling those generated images to 224×224 resolution) on labeled images from the in-distribution dataset. In this scenario, therefore, we did not have access to any unlabeled data or to data from the OOD dataset. This holds for both the diffusion models and the diagnostic model, that is, the OOD dataset was used only for evaluation. We conditioned both generative models on the diagnostic label only. While treating the synthetic-to-real data ratio as a hyperparameter, we found that training the downstream diagnostic model purely on synthetic data led to the best trade-off between accuracy and fairness. We did not alter the diagnostic label distribution, that is, we used the labels of the real data to condition the diffusion models when generating each synthetic sample. In this setting, the model backbone was shared across all conditions, while a separate (binary classification) head was trained for each condition, given that multiple conditions can be present at once.
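With a shared backbone and one binary head per condition, each head emits a logit and the per-condition binary cross-entropies are averaged. The sketch below assumes a masked average that skips conditions whose targets are unavailable (marked `None`); the masking and `masked_multilabel_bce` are illustrative assumptions, not the authors' implementation.

```python
import math
from typing import Optional, Sequence


def masked_multilabel_bce(logits: Sequence[float],
                          targets: Sequence[Optional[int]]) -> float:
    """Hypothetical helper: mean binary cross-entropy over heads, skipping None targets."""
    losses = []
    for z, t in zip(logits, targets):
        if t is None:
            continue
        # Numerically stable form of -[t*log(sigmoid(z)) + (1 - t)*log(1 - sigmoid(z))].
        losses.append(max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z))))
    return sum(losses) / len(losses) if losses else 0.0


# One image, five condition heads (Num_labels = 5); the third target is masked out.
print(masked_multilabel_bce([2.0, -1.0, 0.0, 3.0, -2.0], [1, 0, None, 1, 0]))
```

Treating each head independently (sigmoid rather than softmax) is what allows several conditions to be predicted as present for the same radiograph.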

For dermatology, we trained two diffusion models (one generating images at 64×64 resolution and one upsampling those generated images to 256×256 resolution) on labeled images from the in-distribution dataset and unlabeled images from the in-distribution and OOD datasets. At no stage of training did we have access to labeled samples from the OOD datasets. We conditioned both generative models on the diagnostic label (padded with zeros for the unlabeled samples) or the diagnostic label and a demographic attribute. While treating the ratio of synthetic-to-real data as a hyperparameter, we found that training the downstream diagnostic model on 75% synthetic images and 25% real images yielded the best results. When we artificially skewed the dataset against certain demographic subgroups, we ensured that both the generative models and the diagnostic model had access to the same labeled examples (that is, we trained a different diffusion model for each skewed setting). When we sampled from the diffusion model, we preserved the diagnostic label distribution and assumed a uniform demographic attribute distribution.

We motivated the use of generated data and demonstrated its utility in several toy settings, which simulate the problem of having only a small number of samples from the underlying distribution or from parts of it. The goal is to achieve high performance despite this lack of data. We demonstrated that synthetic data were useful even in these toy settings.

We assumed we had a dataset $D_{\mathrm{train}} = \{(\mathbf{x}_i, y_i, \mathbf{a}_i)\}_{i=1}^{N}$, where $(\mathbf{x}_i, y_i)$ is an image and label pair, $\mathbf{a}_i$ is a list of attributes of the data point and $N$ is the number of training samples. The attributes may include sex, skin type and age, or the hospital ID (in the case of histopathology). We also had an additional dataset $D_u = \{\hat{\mathbf{x}}_j\}_{j=1}^{M}$ of unlabeled images, $M$ being the number of samples, that could be used as desired. We had a generative model $\hat{p}$ with parameters $\tilde{\theta}$, trained with $D_{\mathrm{train}}$ and $D_u$ (we leave $\tilde{\theta}$ implicit in the following). Where obvious, we dropped subscripts for simplicity.

To achieve fairness, we assumed we had a fair dataset $D_{\mathrm{f}} = \{(\mathbf{x}_i, y_i, \mathbf{a}_i)\}_{i=1}^{F}$ whose data points were samples from the fair distribution $p_{\mathrm{f}}$, over which we aimed to minimize the expected loss. $f_{\theta}(\mathbf{x})$ was the classifier and $L$ the loss function (for example, binary cross-entropy). We aimed to optimize the following objective:

$$\min_{\theta}\ \mathbb{E}_{(\mathbf{x}, \mathbf{a}, y) \sim D_{\mathrm{f}}}\left[L\left(f_{\theta}(\mathbf{x}), y, \mathbf{a}\right)\right] \tag{1}$$

We can decompose the data generating process into $p_{\mathrm{f}}(\mathbf{x} \mid \mathbf{a}, y)\, p_{\mathrm{f}}(\mathbf{a} \mid y)\, p_{\mathrm{f}}(y)$. For example, we may have created $D_{\mathrm{f}}$ by sampling uniformly over an attribute (such as sex) and labels. We assumed that the training dataset $D_{\mathrm{train}} \subset D_{\mathrm{f}}$ was sampled from a distribution $p_{\mathrm{train}}$ where $p_{\mathrm{train}}(\mathbf{x} \mid \mathbf{a}, y) = p_{\mathrm{f}}(\mathbf{x} \mid \mathbf{a}, y)$. When $p_{\mathrm{train}}(y, \mathbf{a}) \neq p_{\mathrm{f}}(y, \mathbf{a})$, there is a distribution shift between the training and fair distributions (for example, the training distribution is more likely than the fair distribution to generate images with a particular attribute or combination of label and attribute).

We aimed to combine the training dataset $D_{\mathrm{train}}$ with synthetic data sampled from the generative model $\hat{p}$ so as to mimic the fair distribution as closely as possible and thereby improve fairness. We constructed a new dataset $\hat{D}$ according to a mixture distribution $p'$ of these two sources, controlled by a probability parameter $\alpha$:

$$(\mathbf{x}, \mathbf{a}, y) \sim p' = \begin{cases} (\mathbf{x}, \mathbf{a}, y) \sim D_{\mathrm{train}} & \text{with probability } \alpha \\ (\mathbf{x}, \mathbf{a}, y),\ \ \mathbf{x} \sim \hat{p}(\mathbf{x} \mid y, \mathbf{a}),\ \ (\mathbf{a}, y) \sim \hat{p}(\mathbf{a}, y) & \text{with probability } 1 - \alpha \end{cases} \tag{2}$$

So instead of minimizing equation (1), we minimized the following sum of expectations:

$$\min_{\theta}\ \alpha\, \mathbb{E}_{(\mathbf{x}, \mathbf{a}, y) \sim D_{\mathrm{train}}}\left[L\left(f_{\theta}(\mathbf{x}), y, \mathbf{a}\right)\right] + (1 - \alpha)\, \mathbb{E}_{(\mathbf{x}, \mathbf{a}, y) \sim \hat{p}}\left[L\left(f_{\theta}(\mathbf{x}), y, \mathbf{a}\right)\right] \tag{3}$$

The question is then how to choose $\alpha$ and $\hat{p}(\mathbf{a}, y)$. For all settings in the main article, we preserved the label distribution, $\hat{p}(y) = p(y)$, but sampled uniformly over the attribute $\mathbf{a}$. We validated this choice on dermatology in the Supplementary Information. We treated $\alpha$ as a hyperparameter in all settings.
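Drawing conditioning pairs from $\hat{p}(\mathbf{a}, y)$ with $\hat{p}(y) = p(y)$ and a uniform attribute can be sketched as resampling labels from the training set and drawing the attribute independently. This is a minimal sketch; independence between the attribute and the label is an assumption (the text specifies only the two marginals), and `sample_fair_conditions` and the skin tone groups are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def sample_fair_conditions(train_labels: Sequence[str], attribute_values: Sequence[str],
                           n: int, seed: int = 0) -> List[Tuple[str, str]]:
    """Hypothetical helper: n (a, y) pairs with p_hat(y) = p(y) and a ~ Uniform.

    Labels are resampled from the training labels, which preserves the empirical
    label distribution; the attribute is drawn independently of the label.
    """
    rng = random.Random(seed)
    labels = rng.choices(train_labels, k=n)
    return [(rng.choice(attribute_values), y) for y in labels]


train_labels = ["eczema"] * 70 + ["melanoma"] * 30
pairs = sample_fair_conditions(train_labels, ["I-II", "III-IV", "V-VI"], n=30_000)
```

Each pair would then be turned into a conditioning vector for the diffusion model, so the synthetic set keeps the label prevalence of the training data while covering every attribute group equally.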

Whenever we required an upsampler (that is, in radiology and dermatology), we trained it by preprocessing the original images using the following steps: (1) upsampling images from the 64×64 input resolution to the desired output resolution with bilinear interpolation, using anti-aliasing with probability 0.5; (2) adding random Gaussian noise with probability 0.2 and σ = 4.0 (in the [0, 255] range); (3) applying random Gaussian blurring with a 7×7 kernel, mean = 0 and s.d. = 0.2; (4) quantizing the image to 256 bins; and (5) normalizing the image to the (−1 to 1) range.
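Steps (2), (4) and (5) above can be sketched on a flattened image as follows; steps (1) and (3) (bilinear upsampling and Gaussian blurring) are omitted to keep the example dependency-free. This assumes the steps are applied to the low-resolution conditioning input of the upsampler, that quantizing to 256 bins means rounding to integer levels, and that values are clipped to [0, 255]; `degrade_for_upsampler` is a hypothetical helper.

```python
import random
from typing import List, Sequence


def degrade_for_upsampler(pixels: Sequence[float], rng: random.Random,
                          noise_prob: float = 0.2, noise_std: float = 4.0) -> List[float]:
    """Hypothetical helper: steps (2), (4) and (5) for a flattened image in [0, 255].

    Steps (1) bilinear upsampling and (3) Gaussian blurring are omitted here.
    """
    out = [float(p) for p in pixels]
    # (2) With probability noise_prob, add Gaussian noise with s.d. noise_std.
    if rng.random() < noise_prob:
        out = [p + rng.gauss(0.0, noise_std) for p in out]
    # (4) Quantize to 256 integer levels; clipping to [0, 255] is an assumption.
    out = [min(255.0, max(0.0, float(round(p)))) for p in out]
    # (5) Map [0, 255] linearly onto [-1, 1].
    return [p / 127.5 - 1.0 for p in out]


rng = random.Random(0)
print(degrade_for_upsampler([0.0, 127.4, 255.0], rng, noise_prob=0.0))
```

Randomly degrading the conditioning input during training makes the upsampler less sensitive to artifacts in the low-resolution samples it receives from the base model at generation time.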

For both the generative model and the upsampler, we filled the conditioning vectors with zeros (indicating an invalid vector) for the unlabeled data. This allowed us to use classifier-free guidance20 to make images more canonical with respect to a given label or property.
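In the standard classifier-free guidance formulation, the denoiser is evaluated twice at each sampling step, once with the real conditioning vector and once with the null (here, all-zeros) vector, and the two noise predictions are extrapolated as $\tilde{\epsilon} = (1 + w)\,\epsilon(\mathbf{x}_t, \mathbf{c}) - w\,\epsilon(\mathbf{x}_t, \varnothing)$. The sketch below shows only that combination step on plain lists; the guidance weight `w` is an illustrative parameter (the values used are not given in this section), and `classifier_free_guidance` is a hypothetical helper.

```python
from typing import List, Sequence


def classifier_free_guidance(eps_cond: Sequence[float], eps_uncond: Sequence[float],
                             w: float) -> List[float]:
    """Hypothetical helper combining conditional and unconditional noise predictions.

    eps_cond:   denoiser output given the real conditioning vector
    eps_uncond: denoiser output given the all-zeros ("null") conditioning vector
    w:          guidance weight; w = 0 recovers the purely conditional prediction
    """
    return [(1.0 + w) * c - w * u for c, u in zip(eps_cond, eps_uncond)]


print(classifier_free_guidance([1.0, 2.0], [0.0, 1.0], w=1.0))  # [2.0, 3.0]
```

Larger `w` pushes samples further toward what distinguishes the conditioned prediction from the unconditional one, which is the sense in which guidance makes images more canonical for a label, typically at some cost in diversity.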

In this section, we describe the exact model architectures used for the trained diffusion models and classifiers, as well as the hyperparameters used for the presented results. Hyperparameters were selected based on the baseline model performance on the respective in-distribution validation sets and held constant for the remaining methods. This meant that we did not fine-tune hyperparameters separately for each method other than the baseline. We used the DDPM as presented in refs. 19,20,43 for both the generative model and the upsampler (only the radiology and dermatology datasets required higher-resolution images). The backbone model was always a UNet architecture. The hyperparameters used for the cascaded diffusion models were based on standard values from the literature with minimal modifications. We present all hyperparameters in Extended Data Table 2.

For histopathology, augmentations included brightness, contrast, saturation and hue jitter. Hue and saturation jitter were sufficient to achieve the high-quality results described by Tellez et al.56.

For radiology, the heuristic augmentations considered included: random horizontal flipping; random cropping to 202×202 resolution; resizing to 224×224 with bilinear interpolation and anti-aliasing; random rotation by up to ±15 degrees; shifting luminance by a value sampled uniformly from the (−0.1 to 0.1) range; and shifting contrast using a value sampled uniformly from the (0.8 to 1.2) range (that is, pixel values were multiplied by the shift value and clipped to remain within the (0 to 1) range).
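The luminance and contrast shifts can be sketched for a flattened image with values in [0, 1] as follows. This is a sketch under assumptions: the luminance shift is applied additively before the contrast factor, and a single clip is applied at the end; `luminance_contrast_shift` and `random_luminance_contrast` are hypothetical helpers.

```python
import random
from typing import List, Sequence


def luminance_contrast_shift(pixels: Sequence[float], delta: float,
                             factor: float) -> List[float]:
    """Hypothetical helper: add a luminance shift, scale by a contrast factor, clip to [0, 1]."""
    return [min(1.0, max(0.0, (p + delta) * factor)) for p in pixels]


def random_luminance_contrast(pixels: Sequence[float], rng: random.Random) -> List[float]:
    """Draw the shifts from the ranges given in the text and apply them."""
    delta = rng.uniform(-0.1, 0.1)  # luminance shift
    factor = rng.uniform(0.8, 1.2)  # contrast shift
    return luminance_contrast_shift(pixels, delta, factor)


print(luminance_contrast_shift([0.5, 0.9, 0.0], delta=0.1, factor=1.2))
```

Separating the deterministic transform from the random draw makes the augmentation easy to test and to reproduce from a fixed seed.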

For dermatology, we used the following heuristic augmentations: random horizontal and vertical flipping; adjusting image brightness by a random factor (maximum $\delta = 0.1$); adjusting image saturation by a random factor (within the (0.8 to 1.2) range); adjusting the hue by a random factor (maximum $\delta = 0.02$); adjusting image contrast by a random factor (within the (0.8 to 1.2) range); random rotation within the (−150 to 150) range; and random Gaussian blurring with a standard deviation sampled uniformly from the following values: {0.001, 0.01, 0.1, 1.0, 3.0, 5.0, 7.0}.

In all contexts, we considered the strongest heuristic augmentations as a baseline. These augmentations (heuristic or learned) can be combined with any alternative learning algorithm that aims to improve model generalization. For our experiments, we used empirical risk minimization57 because no single method consistently outperforms it under distribution shifts55. Even though our experiments and analysis focus on DDPMs for generation, any conditional generative model that produces high-quality and diverse samples can be used. In general, the risk, that is, the expected loss of a model under the true data distribution $P(x, y)$, cannot be computed because this distribution is unknown to the learning algorithm. However, we can compute an approximation, called the empirical risk, by averaging the loss function over the training set samples.
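Written out with the standard definitions (a worked statement added for clarity, not taken from the original text), the risk of a classifier $f_{\theta}$, its empirical estimate on $N$ training samples and the empirical risk minimizer are:

$$R(\theta) = \mathbb{E}_{(x, y) \sim P(x, y)}\left[L\left(f_{\theta}(x), y\right)\right], \qquad \hat{R}(\theta) = \frac{1}{N} \sum_{i=1}^{N} L\left(f_{\theta}(x_i), y_i\right), \qquad \hat{\theta} = \arg\min_{\theta} \hat{R}(\theta)$$

Because a batch drawn from the mixture $p'$ of equation (2) contains real examples with probability $\alpha$ and synthetic ones otherwise, averaging the loss over such a batch is an unbiased estimate of the $\alpha$-weighted objective in equation (3), so training on mixed real and synthetic data remains an instance of empirical risk minimization.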

For histopathology, all models used the same ResNet-152 backbone. We compared (1) a baseline using no augmentation (Baseline) and (2) one using standard color augmentations (Color augm.) as applied in standard ImageNet training. This augmentation included brightness, contrast, saturation and hue jitter. Hue and saturation jitter were sufficient augmentations to achieve the highest-quality results reported by Tellez et al.56; hence, we did not evaluate other heuristic augmentations. Our baseline did not use pretraining because, as reported by Wiles et al.55, pretraining did not yield any benefit on this particular dataset. We also compared the models to those applying heuristic color augmentations on top of the synthetic data.

For radiology, all models used the same BiT-ResNet-152 backbone58. We considered baselines that use (1) different pretraining, (2) different heuristic augmentations and combinations thereof, and (3) focal loss. We investigated using JFT59 and ImageNet-21K60 for pretraining to explore how much different pretraining datasets affected the final results. We investigated using RandAugment61, the ImageNet Augmentations described above, and RandAugment+ImageNet Augmentations to determine how much performance could be gained by using heuristic augmentations. Finally, we considered using focal loss62, which was developed to improve performance on imbalanced datasets.
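Focal loss down-weights well-classified examples by scaling the cross-entropy of the true-class probability $p_t$ by $(1 - p_t)^{\gamma}$. The sketch below shows the binary form for a single prediction; the default $\gamma = 2$ is an illustrative assumption (the value used is not stated here), the optional class-balancing weight is omitted, and `binary_focal_loss` is a hypothetical helper.

```python
import math


def binary_focal_loss(prob: float, target: int, gamma: float = 2.0) -> float:
    """Hypothetical helper: focal loss -(1 - p_t)^gamma * log(p_t) for one prediction.

    prob is the predicted probability of the positive class; gamma = 0 recovers
    binary cross-entropy. The optional class-balancing weight alpha_t is omitted.
    """
    p_t = prob if target == 1 else 1.0 - prob
    p_t = min(max(p_t, 1e-12), 1.0)  # guard against log(0)
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


for p in (0.5, 0.9, 0.99):
    print(p, binary_focal_loss(p, 1, gamma=0.0), binary_focal_loss(p, 1))
```

On imbalanced data, most easy negatives have $p_t$ close to 1, so their contribution shrinks sharply and the gradient is dominated by hard and rare examples.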

For dermatology, all models used the same BiT-ResNet backbone58. We considered baselines that (1) used different pretraining, (2) used different heuristic augmentations, (3) resampled the dataset and (4) used focal loss. We investigated using JFT59 and ImageNet-21K60 for pretraining. We investigated using RandAugment61, ImageNet Augmentations and RandAugment+ImageNet Augmentations. We then resampled the dataset so that the distribution over attributes was even (we upsampled samples from low-data regions so that they occurred more frequently in the dataset). Finally, we considered using focal loss62, which was developed to improve performance on imbalanced datasets.
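The resampling baseline can be sketched by weighting each example inversely to the size of its attribute subgroup and drawing with replacement, so that every subgroup is equally likely and low-data subgroups are repeated more often. This is a minimal sketch, assuming a single attribute and a fixed number of draws; `balanced_resample` is a hypothetical helper.

```python
import random
from collections import Counter
from typing import Hashable, List, Sequence


def balanced_resample(attrs: Sequence[Hashable], n: int, seed: int = 0) -> List[int]:
    """Hypothetical helper: n indices drawn with replacement, subgroups equally likely.

    Each example is weighted by 1 / (size of its subgroup), so every subgroup has
    total weight 1 and examples from low-data subgroups are upsampled.
    """
    counts = Counter(attrs)
    weights = [1.0 / counts[a] for a in attrs]
    rng = random.Random(seed)
    return rng.choices(range(len(attrs)), weights=weights, k=n)


attrs = ["female"] * 90 + ["male"] * 10
idx = balanced_resample(attrs, n=20_000)
print(Counter(attrs[i] for i in idx))  # roughly 10,000 of each
```

Unlike synthetic augmentation, this only repeats the few existing minority examples, which is why it can overfit when a subgroup has very little data.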

To account for potential variations with respect to model initialization, we evaluated all versions of our model and baselines with five different initialization seeds and report the average and standard deviation across those runs for all metrics. We ran all experiments on tensor processing units.

Different definitions of fairness have been proposed in the literature, and they are often at odds with each other63. In this section we discuss our choice of fairness metrics for each modality. In histopathology, we used the gap between the best and worst performance among the in-distribution hospitals. For radiology, we considered AUC parity, namely the parity of the area under the receiver operating characteristic curve (AUC) across demographic subgroups defined by the sensitive attribute $A$, which can be seen as the analog of equality of accuracy64. Accordingly, for this modality we report the AUC gap between male and female patients in Fig. 3a. We considered this metric the most relevant because the ratio of positive to negative samples was highly imbalanced for all conditions.
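The AUC gap can be sketched by computing the AUC separately within each subgroup, using its probabilistic interpretation (the chance that a random positive scores above a random negative, with ties counting one half), and taking the difference between the largest and smallest value. This quadratic-time version is for illustration only; `auc` and `auc_gap` are hypothetical helpers.

```python
from typing import Dict, Hashable, Sequence


def auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Hypothetical helper: P(random positive outscores random negative), ties = 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("each subgroup needs positive and negative samples")
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


def auc_gap(scores: Sequence[float], labels: Sequence[int],
            groups: Sequence[Hashable]) -> float:
    """Hypothetical helper: largest minus smallest per-subgroup AUC."""
    per_group: Dict[Hashable, float] = {}
    for g in set(groups):
        idx = [i for i, gi in enumerate(groups) if gi == g]
        per_group[g] = auc([scores[i] for i in idx], [labels[i] for i in idx])
    return max(per_group.values()) - min(per_group.values())


scores = [0.9, 0.1, 0.5, 0.5]
labels = [1, 0, 1, 0]
groups = ["male", "male", "female", "female"]
print(auc_gap(scores, labels, groups))  # 0.5: AUC 1.0 for male vs. 0.5 for female
```

Because AUC depends only on the ranking of positives against negatives, it is insensitive to the strong class imbalance that would dominate plain accuracy.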

In dermatology, we report in Fig. 4 the gap between the best and worst subgroup performance, where subgroups are defined by the sensitive attribute under consideration. We also report the central best estimate of the a posteriori difference in performance (that is, top-3 accuracy) between a group and its outgroup. The values plotted in Supplementary Fig. 7 were obtained as follows: (1) we defined a group (and its matching outgroup) as the set of instances with a particular value $v$ of a sensitive attribute $A$, that is, $\mathrm{group} = \{(\mathbf{x}_i, c_i) \mid a_i = v\}$ and $\mathrm{outgroup} = \{(\mathbf{x}_i, c_i) \mid a_i \neq v\}$, with $A \in \{\text{sex}, \text{skin type}, \text{age}\}$; (2) we assumed a uniform Beta distribution, Beta(1, 1), as the prior for the top-3 accuracies $\mathrm{top3}_{\mathrm{group}}$ and $\mathrm{top3}_{\mathrm{outgroup}}$ and fitted it to the observed data; (3) we drew n = 100,000 samples from the estimated posterior of the difference between $\mathrm{top3}_{\mathrm{group}}$ and $\mathrm{top3}_{\mathrm{outgroup}}$ and report its spread, that is, the standard deviation of the maximum a posteriori estimates, which can be interpreted as the central best estimate for fairness.
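With a Beta(1, 1) prior and $k$ top-3 hits out of $n$ cases, each group's posterior accuracy is Beta(1 + k, 1 + n − k) by Beta–binomial conjugacy, so the posterior of the difference can be approximated by sampling both and subtracting. The sketch below returns the mean and standard deviation of the sampled differences; this is one reading of the procedure above, the exact summary plotted in Supplementary Fig. 7 may differ, and `posterior_top3_difference` is a hypothetical helper.

```python
import random
import statistics
from typing import Tuple


def posterior_top3_difference(k_group: int, n_group: int, k_out: int, n_out: int,
                              draws: int = 100_000, seed: int = 0) -> Tuple[float, float]:
    """Hypothetical helper: mean and s.d. of sampled top3_group - top3_outgroup.

    Beta(1, 1) prior on each accuracy; k successes out of n trials give the
    posterior Beta(1 + k, 1 + n - k).
    """
    rng = random.Random(seed)
    diffs = [rng.betavariate(1 + k_group, 1 + n_group - k_group)
             - rng.betavariate(1 + k_out, 1 + n_out - k_out)
             for _ in range(draws)]
    return statistics.fmean(diffs), statistics.stdev(diffs)


# Toy counts (hypothetical): 80/100 correct in the group, 60/100 in the outgroup.
mean, sd = posterior_top3_difference(80, 100, 60, 100)
print(round(mean, 3), round(sd, 3))  # approx. 0.196 and 0.063
```

The posterior spread shrinks as subgroups grow, so a large apparent gap for a small subgroup comes with a correspondingly wide uncertainty.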

We computed domain mismatches in the space where decisions are made, that is, the output of the penultimate layer of each model. Thus, we projected each data point from the input space $\mathbb{R}^{64 \times 64}$ to a representation in $\mathbb{R}^{6144}$ and then computed the maximum mean discrepancy (MMD) between two distributions (that is, datasets). Given two distributions $U$ and $Z$, their respective samples $\hat{U} = \{u_1, \ldots, u_N\}$ and $\hat{Z} = \{z_1, \ldots, z_N\}$, and a kernel $K$, we used the following empirical estimate of the squared MMD:

$$\widehat{\mathrm{MMD}}^{2}(\hat{U}, \hat{Z}) = \frac{1}{N(N-1)} \sum_{i \neq j} K(u_i, u_j) + \frac{1}{N(N-1)} \sum_{i \neq j} K(z_i, z_j) - \frac{2}{N^{2}} \sum_{i, j = 1}^{N} K(u_i, z_j) \tag{4}$$

We used a cubic polynomial kernel to minimize the number of hyperparameters to be selected and to capture mismatches in up to the third-order moments of each distribution. We computed $S = 30$ estimates of the MMD between all pairs of domains, using representations from the different models and samples of size $n = 300$. A Mann–Whitney U-test at a significance level of 0.05 was then carried out to test the hypothesis that, for a fixed pair of distributions, the data augmentation strategy had a significant effect on the estimated MMD values. Importantly, all models were trained under the same experimental conditions, so that our analysis could isolate the effect of the data augmentation protocol on the estimated pairwise distribution shifts.
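Equation (4) with a cubic polynomial kernel can be sketched directly: within-sample kernel sums exclude the diagonal ($i \neq j$), and the cross term averages over all pairs. This is a minimal sketch on small vectors; the kernel offset $c = 1$ is an assumption (the exact kernel parameters are not given), the repeated $S = 30$ estimates and the Mann–Whitney test are not shown, and `cubic_kernel` and `mmd2_unbiased` are hypothetical helpers.

```python
import random
from typing import List, Sequence

Vector = Sequence[float]


def cubic_kernel(u: Vector, z: Vector, c: float = 1.0) -> float:
    """Hypothetical polynomial kernel K(u, z) = (<u, z> + c)^3; c = 1 is an assumption."""
    return (sum(a * b for a, b in zip(u, z)) + c) ** 3


def mmd2_unbiased(U: List[Vector], Z: List[Vector]) -> float:
    """Hypothetical helper: squared-MMD estimate of equation (4), equal sizes N >= 2."""
    n = len(U)
    if n != len(Z) or n < 2:
        raise ValueError("need two samples of equal size N >= 2")
    k_uu = sum(cubic_kernel(U[i], U[j]) for i in range(n) for j in range(n) if i != j)
    k_zz = sum(cubic_kernel(Z[i], Z[j]) for i in range(n) for j in range(n) if i != j)
    k_uz = sum(cubic_kernel(u, z) for u in U for z in Z)
    return (k_uu + k_zz) / (n * (n - 1)) - 2.0 * k_uz / (n * n)


rng = random.Random(0)
U = [[rng.gauss(0.0, 1.0)] for _ in range(100)]
Z_same = [[rng.gauss(0.0, 1.0)] for _ in range(100)]
Z_shift = [[rng.gauss(3.0, 1.0)] for _ in range(100)]
print(mmd2_unbiased(U, Z_same), mmd2_unbiased(U, Z_shift))
```

Because the cubic kernel's feature map contains the first three powers of the inputs, the estimate compares means, second moments and third moments of the two samples, matching the motivation given above. Being unbiased, it can be slightly negative when the two distributions are identical.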

In this section, our analysis focuses on the modality of dermatology and puts forward several properties of our synthetic data that may be important for our experimental results, which demonstrate the utility of synthetic data for improving performance.

First, Fig. 5 shows images generated at high resolution for this challenging natural setting across several dermatological conditions. Our conditional generative model captured the characteristics of multiple diverse conditions well, even for conditions that are rarer in the dataset, such as seborrheic dermatitis, alopecia areata and hidradenitis.

We further asked expert dermatologists to assess how realistic the generated images were, to validate that these images contained properties of the condition used for conditioning. Synthetic images did not need to be perfect, as we were interested in downstream diagnostic performance. However, the ability to generate realistic images indicates that the generative model captures the relevant features of the conditions. To evaluate this, we asked three dermatologists to each rate a total of 488 synthetic images, evenly sampled from the four most common classes (eczema, psoriasis, acne, seborrheic keratosis/irritated seborrheic keratosis) and four high-risk classes (melanoma, basal cell carcinoma, urticaria, SCC/SCCIS). They were first asked to determine whether the image was of sufficient quality to provide a diagnosis. They were then asked to provide up to three diagnoses from over 20,000 common conditions, each with an associated confidence score (out of 5, where 5 was most confident). These 20,000 conditions were mapped to the 27 classes used in this paper (one class, Other, encompasses all conditions not represented in the other 26 classes). We report the mean and standard deviation of all metrics across the three raters: 50.0 ± 12.6% of the images were of sufficient quality for diagnosis, and dermatologists had an average confidence of 4.13 ± 0.43 out of 5 for their top diagnosis. Their top-1 accuracy on the generated images was 56.0 ± 11.9% and their top-3 accuracy was 67.7 ± 12.5%.

We compared these numbers with those for a set of real images of the same eight conditions (for these images, most raters considered this condition to be the most prevalent diagnosis in the image). Among 101 board-certified dermatologists rating 789 real images in total, we found a top-1 accuracy of 54.0 ± 21.1% and a top-3 accuracy of 67.1 ± 22.7%; Liu et al.4 reported slightly higher top-1 (63%) and top-3 (75%) accuracy across a more diverse set of dermatological conditions. For this real-image analysis, if an image was rated by $n$ dermatologists, we computed each rater's accuracy with respect to the aggregated diagnosis of the remaining $n-1$ raters. This demonstrates that, when diagnosable according to the experts' evaluation, synthetic images are indeed representative of the condition they are expected to capture, comparably to real images. Although not all generated images were diagnosable, the same can be true of real samples, given that the images used to train the generative model did not necessarily include the body part or view that best reflected the condition.
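The leave-one-out protocol for real images can be sketched as below for a single image. The text does not say how the remaining $n-1$ ratings are aggregated; this sketch assumes the inverse-rank soft-label aggregation described in the next paragraph and takes its highest-scoring condition as the reference diagnosis. Each rater's diagnoses are given as `(condition, confidence)` tuples, and the function names are hypothetical. The function returns, for one image, the fraction of raters whose top-$k$ diagnoses contain that reference.

```python
from collections import defaultdict


def soft_label(ratings):
    """Inverse-rank aggregation: rank each rater's diagnoses by confidence,
    weight by 1 / rank, sum over raters and normalize to 1."""
    scores = defaultdict(float)
    for diagnoses in ratings:
        ranked = sorted(diagnoses, key=lambda d: d[1], reverse=True)
        for rank, (condition, _) in enumerate(ranked, start=1):
            scores[condition] += 1.0 / rank
    total = sum(scores.values())
    return {condition: score / total for condition, score in scores.items()}


def leave_one_out_accuracy(ratings, k=1):
    """Top-k accuracy of each rater against the aggregated diagnosis of the
    other raters on one image (needs at least two raters)."""
    hits = 0
    for i, diagnoses in enumerate(ratings):
        label = soft_label(ratings[:i] + ratings[i + 1:])
        reference = max(label, key=label.get)  # ties: first condition encountered
        ranked = sorted(diagnoses, key=lambda d: d[1], reverse=True)
        top_k = [condition for condition, _ in ranked[:k]]
        hits += reference in top_k
    return hits / len(ratings)
```

Calling it with `k=1` and `k=3` gives the per-image top-1 and top-3 agreement; how these are then averaged over images and raters is not specified in the text.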

We hypothesized that models are more robust to prevalence shifts because synthetic images are more canonical examples of the conditions. To understand how canonical ground truth images for a particular condition were, we investigated cases with a high degree of concordance in raters' assessments and compared them with synthetic images of the same condition. More specifically, we thresholded the aggregated ground truth values to select the images in the training data that experts were most confident presented a given condition. The aggregation function operates as follows: assume a set of four conditions $\{A,B,C,D\}$; if rater $R_1$ provides the (condition, confidence) diagnosis tuples $\{(A,4),(B,3)\}$ and rater $R_2$ provides $\{(A,3),(D,4)\}$, we obtain the soft labels $\{0.5, 0.167, 0, 0.333\}$. These follow from ranking each rater's diagnoses by confidence, weighting each condition by the inverse of its rank for that rater (for $R_1$, $A$ receives 1 and $B$ receives 1/2; for $R_2$, $D$ receives 1 and $A$ receives 1/2), summing across raters ($A$: 1.5, $B$: 0.5, $C$: 0, $D$: 1) and normalizing the scores to sum to 1. To find instances for which there is consensus among raters and high confidence that a condition is present, we can threshold the corresponding soft label with a strict threshold, for example, $t=0.9$. In our example, no condition passes this threshold; if we lowered the threshold to 0.5, however, condition $A$ would pass. Extended Data Fig. 5 shows an example for melanoma. For this diagnostic class, we generated multiple synthetic instances, whereas we recovered only five images (out of more than 15,000) that clinicians rated with high confidence, that is, $t_{\mathrm{melanoma}}=0.9$. The nearest neighbors in the training dataset, identified by $\ell^2$ distance, are also shown in Extended Data Fig. 5.
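The aggregation and thresholding steps can be written in a few lines of Python. This sketch reproduces the worked example above, ranking each rater's diagnoses by confidence; the text does not say how ties in confidence are broken, and here they simply keep the rater's listed order. The helper names `soft_label` and `confident_conditions` are hypothetical.

```python
from collections import defaultdict


def soft_label(ratings, conditions):
    """Aggregate per-rater (condition, confidence) tuples into a soft label.

    Each rater's diagnoses are ranked by confidence (highest first) and each
    condition gets weight 1 / rank; weights are summed over raters and
    normalized so that the soft label sums to 1.
    """
    scores = defaultdict(float)
    for diagnoses in ratings:
        ranked = sorted(diagnoses, key=lambda d: d[1], reverse=True)
        for rank, (condition, _) in enumerate(ranked, start=1):
            scores[condition] += 1.0 / rank
    total = sum(scores.values())
    return {c: scores[c] / total for c in conditions}


def confident_conditions(label, t):
    """Conditions whose soft-label value reaches the threshold t."""
    return [c for c, value in label.items() if value >= t]


# Worked example from the text: R1 = {(A, 4), (B, 3)}, R2 = {(A, 3), (D, 4)}.
ratings = [[("A", 4), ("B", 3)], [("A", 3), ("D", 4)]]
label = soft_label(ratings, ["A", "B", "C", "D"])
print(label)                             # A: 0.5, B: ~0.167, C: 0.0, D: ~0.333
print(confident_conditions(label, 0.9))  # []
print(confident_conditions(label, 0.5))  # ['A']
```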

Previous work on OOD generalization65,66,67 has pointed out that several factors can affect a model's performance on samples from domains beyond the training data. In this analysis, we examined models trained with our proposed learned augmentations in terms of changes in distribution alignment between all pairs of distributions, measured using MMD68. We computed domain mismatches in the space where decisions are made, projecting each data point from the input space to a representation. We found that learned augmentations yielded on average 18.6% lower MMD than heuristic augmentations (for more details, refer to Methods, Distribution shift estimation), which leads to the following conclusions: (1) data augmentation has a significant effect on distribution alignment, and the improvement in OOD performance suggests that this happens through learning better predictive features rather than capturing spurious correlations; (2) the generated data help the model better match different domains by attenuating the overall discrepancy between them; (3) given the minor decline in performance when adding generated data in the less skewed setting (Fig. 4), learning such features might conflict with learning spurious correlations that were helpful for in-distribution performance. In other words, introducing synthetic data allowed the diagnostic model to allocate more capacity to disease-specific features rather than domain-specific (for example, hospital) features.

To further compare the effect of different augmentation schemes on the features learned by the diagnostic model, we investigated the representation space occupied by all considered datasets, including samples obtained from the generative model. In practice, we projected $n$ randomly sampled instances from each dataset to the feature space learned by each model and applied principal component analysis69 to identify the most significant modes of variation. We then extracted the number of principal components required to represent different fractions of the variance across all instances. We observed that, for a fixed dataset, features from models trained with synthetic data required 5.4% fewer principal components to retain 90% of the variance in the latent feature space (results for other fractions are provided in Supplementary Fig. 3). This indicates that using synthetic data induces more compressed representations than augmenting the training data heuristically. Considering this finding in the context of the results in Extended Data Table 3, we posit that the observed effect is due to domain-specific information being attenuated in the feature space learned by models trained with synthetic data. This suggests that our approach can reduce the model's reliance on correlations between inputs and labels that do not generalize OOD. For example, if most images of melanoma in the training set correspond to individuals with light skin tones, the model could learn to predict skin tone instead of the condition.
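Counting the principal components needed to retain a given fraction of the variance can be sketched with NumPy's SVD as follows. This is an illustrative sketch, and `n_components_for_variance` is a hypothetical name; `features` is assumed to hold the projected instances row-wise. Comparing the returned counts for models trained with and without synthetic data mirrors the comparison described above.

```python
import numpy as np


def n_components_for_variance(features, fraction=0.9):
    """Smallest number of principal components whose cumulative explained
    variance reaches `fraction`; features has shape (n_samples, n_features)."""
    centered = features - features.mean(axis=0)
    # Squared singular values of the centered data are proportional to the
    # variances along the principal components (returned in descending order).
    variances = np.linalg.svd(centered, compute_uv=False) ** 2
    cumulative = np.cumsum(variances) / variances.sum()
    k = int(np.searchsorted(cumulative, fraction)) + 1
    return min(k, len(cumulative))
```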

Extended Data Fig. 2 presents examples of images generated by the class-conditioned diffusion models for healthy and abnormal patches of whole-slide images of histological lymph node sections.

The histopathology dataset was balanced, so it could not show whether synthetic data were useful in the presence of data imbalance. To understand the impact of the number of labeled examples on both in-distribution and OOD generalization, we created different variants of the labeled training set in which we varied the number $n$ of samples from two of the training hospitals, while keeping the number of labeled examples from the third hospital constant. For each value of $n$, we trained a diffusion model using the labeled and unlabeled data. We considered two settings for conditioning the diffusion model: (1) using only the diagnostic label, when available; and (2) using the diagnostic label together with the hospital ID.

We then drew synthetic samples from the diffusion model and trained a downstream classifier, which we evaluated on the held-out in-distribution and OOD datasets.

We trained the downstream classifier with five seeds and plotted the mean and standard deviation in Extended Data Fig. 1a. In-distribution, using synthetic data consistently outperformed both baselines across all values of $n$. The same held for the low-data regime in the OOD setting. Our approach achieves with only 110 samples the in-distribution performance that the baseline model achieves with 1,000 labeled samples (yielding 3× better label efficiency in terms of the low-data regions). We also applied color augmentation on top of the generated samples and found that this generalized best overall, leading to an OOD improvement of approximately 5% over the model trained with color augmentations in the high-data regime (1,000–10,000 samples) and approximately 4.3% in the low-data regime (one labeled sample).

Extended Data Fig. 2 presents examples of images generated by the class-conditioned diffusion models for healthy chest X-rays and those with thoracic conditions. Chest X-rays were generated at a higher resolution (224 × 224) than histopathology images (96 × 96), which required training a separate upsampler diffusion model for chest X-rays.

We show the models' AUC values for each method, in-distribution and OOD, in Extended Data Fig. 1b. Some conditions, for example, cardiomegaly, benefited significantly from synthetic data, while others, for example, effusion, benefited more OOD than in-distribution. Finally, for atelectasis, synthetic images were only marginally beneficial OOD.

We used the primary race labels obtained from https://stanfordaimi.azurewebsites.net/datasets/192ada7c-4d43-466e-b8bb-b81992bb80cf for the in-distribution CheXpert dataset. We plotted the difference in ROC-AUC between the best- and worst-performing groups against overall performance across conditions in Fig. 3b. The number of individuals associated with each racial label was as follows: white, 6,047; other, 1,623; white, non-Hispanic, 1,359; Asian, 1,254; unknown, 1,019; Black or African American, 557; race and ethnicity unknown, 513; other, Hispanic, 239; native Hawaiian or other Pacific Islander, 177; Asian, non-Hispanic, 166; Black, non-Hispanic, 133; white, Hispanic, 63; other, non-Hispanic, 39; patient refused, 31; American Indian or Alaska native, 30.
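The best-minus-worst gap plotted in Fig. 3b can be sketched as follows for a single condition. The text does not specify how ROC-AUC was computed; this sketch swaps in the standard pairwise (Mann–Whitney) formulation in plain NumPy, skips groups that lack either positives or negatives (where AUC is undefined), and uses hypothetical function names.

```python
import numpy as np


def roc_auc(labels, scores):
    """ROC-AUC as the probability that a random positive outscores a random
    negative, with ties counted as 1/2."""
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels], scores[~labels]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def auc_gap(labels, scores, groups):
    """Best-minus-worst subgroup ROC-AUC, plus the per-group values."""
    labels, scores, groups = map(np.asarray, (labels, scores, groups))
    per_group = {}
    for g in np.unique(groups):
        mask = groups == g
        # Skip groups with only one class present: AUC is undefined there.
        if labels[mask].min() == labels[mask].max():
            continue
        per_group[g] = roc_auc(labels[mask], scores[mask])
    return max(per_group.values()) - min(per_group.values()), per_group
```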

For each sensitive attribute and distribution shift, we ran all baselines with five random seeds. We then trained a diffusion model at 64 × 64 resolution (for faster iteration) using the labeled and unlabeled data for that specific shift, and combined synthetic and real data. We considered conditioning either on the label alone or on the label and the sensitive attribute. We plot the top-3 accuracy, balanced accuracy, fairness metric and high-risk sensitivity on the in-distribution and OOD datasets in Supplementary Figs. 5–8. For both accuracy and fairness, we plot the normalized metric, that is, the improvement over the baseline, for which we used 'Pretrained on JFT'.
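The two accuracy metrics can be sketched as follows, assuming the standard definitions: top-3 accuracy counts a case as correct if the true class is among the three highest-scoring classes, and balanced accuracy averages the per-class recall of the top-1 prediction over the classes present. `probs` is a hypothetical (cases × classes) score array and the function names are illustrative; the normalization against the 'Pretrained on JFT' baseline is not shown.

```python
import numpy as np


def top_k_accuracy(probs, labels, k=3):
    """Fraction of cases whose true class is among the k highest-scoring classes."""
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    top_k = np.argsort(-probs, axis=1)[:, :k]
    return float(np.mean(np.any(top_k == labels[:, None], axis=1)))


def balanced_accuracy(probs, labels):
    """Mean per-class recall of the top-1 prediction, over classes present in labels."""
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    preds = probs.argmax(axis=1)
    recalls = [np.mean(preds[labels == c] == c) for c in np.unique(labels)]
    return float(np.mean(recalls))
```

Balanced accuracy weights rare classes as heavily as common ones, which is why it is informative under the prevalence shifts discussed here.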

First, we discuss the results on the accuracy metrics. Across all distribution shifts and all datasets, using generated data either improved or maintained the accuracy metrics in dermatology. Generated data seemed to help most on the OOD dataset with the stronger prevalence shift relative to the training set, and on the balanced accuracy metric.

Using heuristic augmentation helped, in particular RandAugment, which consistently improved over the baseline. The other methods (oversampling and focal loss) gave minimal improvements.

Next, we investigated results on the fairness metrics in Supplementary Fig. 7. Using heuristic augmentation led to no consistent improvement over the baseline. However, for sex, skin tone and age, our approach of using generated data consistently improved on or maintained the performance of the baseline model. This was true even on the OOD datasets, particularly those with stronger shifts relative to the in-distribution dataset (OOD 2 was much more similar to the in-distribution dataset than OOD 1, where we observed the strongest improvements). This is notable because Schrouff et al.18 demonstrated that improving fairness on in-distribution datasets does not guarantee performance improvements on OOD datasets. (Note that there were no skin tone labels for the OOD datasets, so for skin tone we only report results on the in-distribution dataset.)

Finally, we investigated how using synthetic data affects high-risk sensitivity in Supplementary Fig. 8. In diagnostics, it is imperative not to miss someone with a high-risk condition. Thus, we investigated whether using synthetic data negatively or positively affected the model's ability to correctly identify images of a high-risk condition. Of the 27 classes, three were identified as high-risk conditions: basal cell carcinoma, melanoma and SCC/SCCIS. By adding synthetic data, we aimed to improve (or at least not harm) high-risk sensitivity. We investigated high-risk sensitivity on a held-out part of the training dataset and on the two OOD datasets. Across distribution shifts and datasets, using the additional synthetic data either maintained or improved high-risk sensitivity, most notably on the most OOD dataset. Moreover, synthetic data were consistently similar to or better than heuristic augmentation on this metric.

We found that, in dermatology, using synthetic data had a host of benefits. Additional synthetic data can, to some extent, improve balanced accuracy while maintaining overall accuracy, and can improve both fairness metrics and high-risk sensitivity on in-distribution and OOD datasets. This demonstrates that synthetic data used as an augmentation tool hold promise for improving fairness and the diagnosis of high-risk conditions.

We computed domain mismatches in the space where decisions are made, that is, the output of the penultimate layer of each model. Thus, we projected each data point from the input space to a representation. We computed $S$ estimates of the MMD between all pairs of domains using representations from the different models, with samples of size $n$. Models were trained under the same experimental conditions, so that our analysis isolated the effect of data augmentation on the estimated pairwise distribution shifts. In addition to the heuristic augmentation discussed in the main text, we included models trained with RandAugment in this analysis. All findings are summarized in Extended Data Table 3.

Of the three augmentation schemes considered, RandAugment yielded representations that were more aligned than those from the learned and heuristic augmentations for all pairs of domains. We hypothesized that this augmentation strategy would promote better in-distribution generalization by allowing domain-specific cues to be removed at the expense of learning spurious correlations. Evidence supporting this hypothesis can be found in Supplementary Fig. 7, which shows that models trained with RandAugment performed better in-distribution and in the OOD 2 domain, which is more similar to the training distribution than OOD 1 (Extended Data Fig. 4).

Inspired by a recent study by Bommasani et al.70 that examined how often the same individuals are underserved by ML models trained on the same data, we investigated whether the same individuals with high-risk conditions were consistently misclassified. In Extended Data Fig. 6, we show, for all sample IDs across the in-distribution and OOD evaluation datasets, whether particular individuals within each demographic subgroup (male or female) benefited more from the generated data than from other augmentation techniques. For each of the three setups, that is, (1) standard ImageNet augmentations, (2) RandAugment and (3) generated data, we performed five training runs and considered a test sample incorrectly classified for a setup if all five of its trained models misclassified it. For easier comparison, we reordered the sample indices so as to form contiguous blocks of correctly and incorrectly classified samples. While most individual predictions were the same across setups, each setup enabled some samples to be correctly classified that the other setups could not. In particular, in Extended Data Fig. 6a,d, training with generated data significantly reduced the number of consistently misclassified samples compared with standard ImageNet augmentations or RandAugment. Even though the training dataset was more skewed toward females, OOD males with high-risk conditions in panel d were more often correctly classified by a model trained with the generated data. Hence, using generated data reduced the number of underserved individuals compared with standard augmentation techniques, which only apply basic transforms to the original data. Finally, we observed that these training setups were complementary, as each had its own set of well-classified samples. This could open new research directions in model ensembling to create models that benefit from this diversity in individual predictions.
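The consistent-misclassification criterion and the block reordering can be sketched as follows. The input format is an assumption: for each setup, a boolean array of shape (runs, samples) that is True where a run's model classified the sample correctly (five runs in the study; the sketch accepts any number). The setup names and helper functions are hypothetical, and sorting with `numpy.lexsort` is one way, not necessarily the authors', to obtain contiguous blocks.

```python
import numpy as np


def consistently_misclassified(correct):
    """correct: (n_runs, n_samples) bool array, True where a run's model was right.
    A sample counts as misclassified for the setup only if every run got it wrong."""
    return ~np.asarray(correct, dtype=bool).any(axis=0)


def compare_setups(correct_by_setup):
    """Consistently misclassified masks per setup, plus the samples that only
    that setup gets right (every other setup consistently misclassifies them)."""
    missed = {name: consistently_misclassified(c) for name, c in correct_by_setup.items()}
    only_fixed = {}
    for name, mask in missed.items():
        others = np.stack([m for other, m in missed.items() if other != name])
        only_fixed[name] = np.flatnonzero(~mask & others.all(axis=0))
    return missed, only_fixed


def block_order(missed):
    """Sample order that groups identical misclassification patterns across
    setups into contiguous blocks (first setup is the primary sort key)."""
    patterns = np.stack(list(missed.values())).astype(int)  # (n_setups, n_samples)
    # lexsort uses the last key as the primary key, so reverse the setup order.
    return np.lexsort(patterns[::-1])
```

Non-empty `only_fixed` entries for every setup would correspond to the complementarity noted above.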

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
