A Technical and Regulatory Perspective on Generative Adversarial Networks in Medical Devices

March 19, 2021
Jacob Reinhold and Yujan Shrestha, MD
Introduction 🔗

A fundamental task in image processing and medical image analysis is to improve images for some downstream use. For example, a noisy image may be denoised or a low-resolution image may be super-resolved; in both cases, and in general, the enhancement can serve a human viewer or another image processing method. In recent years, deep neural networks (DNNs) have become a common tool for these types of enhancement because they often produce empirically better results. One uniquely useful DNN framework for image enhancement, as well as a slew of other use cases, is that of Generative Adversarial Networks (GANs).

At their inception, GANs were used to generate, from scratch, realistic images resembling some dataset (that is, to sample from the dataset’s underlying distribution). Research into GANs continues to showcase results like this; you’ve likely seen face images of “people that don’t exist,” which were most likely generated with GANs. But the adversarial framework that is central to GANs has other uses.

As GANs have matured, the medical imaging community has adopted them for various tasks, such as data augmentation, semi-supervised learning, segmentation, and anomaly detection. In this article, however, we’ll focus on two image enhancement methods: image-to-image translation and super-resolution.

Image-to-image translation is the task of making one type of image appear like another type of image. In medical imaging, it is often used to fill in data that was not acquired for a subject (for example, making an MR image appear as a CT image to enable better intracranial volume estimation from the MR image alone).

Super-resolution is the task of making an image appear as if it were acquired at a higher resolution. While the true high-frequency information is not actually recovered in most applications of super-resolution, a super-resolved image is often preferable for certain tasks. For example, suppose you have a whole-brain segmentation DNN that was trained on 1x1x1 mm³ images and you want to segment an image acquired at 1x1x3 mm³; because DNN performance is sensitive to the characteristics of the input, the output segmentation may be better if you first super-resolve the image to appear as if it were acquired at the same 1x1x1 mm³ resolution as the training images.

In the remainder of the article, we’ll first discuss the GAN framework as originally conceived and how that framework is modified for image enhancement. Then we’ll more specifically address how GANs are used for image-to-image translation and super-resolution. Finally, we’ll give some regulatory perspective on GANs.

Generating realistic samples with GANs 🔗

As can be inferred from their name, GANs generate realistic images by adversarially pitting two neural networks against each other. These two networks are called the generator and the discriminator.

The task of the generator is to map a noise vector to a realistic image. The task of the discriminator is to detect the fake images created by the generator when presented with a mix of generated outputs and real data.

In this way, the generator and discriminator are playing a game against one another. An often-used analogy is that of counterfeiters and the police: counterfeiters try to make fake money that is perceived to be real, and the police try to distinguish the fake money from real money. To pass off the fake money, the counterfeiters must create increasingly realistic versions of the money to fool increasingly sophisticated police. As the game progresses, the counterfeiters end up making “supernotes” and the police end up looking for optically variable ink and 3D security ribbons. The hope with GANs is that a similar arms race will take place, with the generator playing the role of the counterfeiters and the discriminator playing the role of the police, such that, in the end, the generator produces images indistinguishable from real images.

To induce this arms race, the generator and discriminator are trained in tandem. Every epoch, the discriminator is trained to classify the generated images as fake and the generator is trained to fool the discriminator. Ideally, this game converges to an equilibrium where the discriminator does no better than a random guess when presented with a sample from the generator.

This arms race can be formulated as a minimax optimization problem whose desired point of convergence is an equilibrium at which neither the generator nor the discriminator can do (strictly) better by changing its parameters. The original paper used this equation:

\[\min_{G} \max_{D} V(G,D) = \mathbb{E}_{x\sim p_\text{data}}[\log(D(x))] + \mathbb{E}_{z\sim p_\text{noise}}[\log(1-D(G(z)))]\]

where $G$ is the generator, $D$ is the discriminator, $x\sim p_\text{data}$ is a sample from the dataset, and $z \sim p_\text{noise}$ is a vector of noise (for example, a vector of length $N$ whose entries are drawn from a standard normal distribution). When the discriminator is powerful enough, this optimization drives the generator to closely match the true data distribution (by minimizing the Jensen-Shannon divergence between the generated and true distributions).
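Concretely, the value function can be estimated from a batch of real samples and noise vectors. Below is a minimal PyTorch sketch of such an estimate, assuming a discriminator that outputs raw logits:

import torch

def value_fn(discriminator, generator, x_real, z):
    # Monte Carlo estimate of V(G, D) over a batch; the sigmoid maps
    # the discriminator's raw logits to probabilities
    d_real = torch.sigmoid(discriminator(x_real))
    d_fake = torch.sigmoid(discriminator(generator(z)))
    return torch.log(d_real).mean() + torch.log(1 - d_fake).mean()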

In image enhancement, this optimization procedure is modified so that, instead of a noise vector, the generator takes an image as input. The generator then tries to modify the image such that the desired property is enhanced. For image-to-image translation, the discriminator will be used to induce the generator to more realistically modify, for example, an MR image such that the underlying anatomy looks as if it were acquired from a CT image. In the case of super-resolution, the discriminator will induce the generator to realistically modify a low-resolution image such that it appears as if it were acquired at a higher resolution.

Problems with training GANs 🔗

One major issue with training traditional GANs, and consequently a major research focus, is “mode collapse”: a situation where the generator produces only a few realistic-looking outputs when many diverse ones are desired. This happens because the few generated samples are realistic enough to fool the discriminator yet varied enough that the discriminator can’t simply memorize and reject them. In the worst case, mode collapse can cause the generator to produce only one output. Ideally, the generator should represent the full diversity of the true distribution.

Another aspect of training can be particularly tricky; namely, vanishing gradients. As can be seen in the loss function written above, the generator is trained by gradients backpropagated through the discriminator. If the discriminator produces a (near) zero gradient somewhere, then the generator will fail to update its parameters. In practice, the generator loss written above is not used because it saturates early in training, when the discriminator easily rejects the generator’s poor samples; instead, (mostly heuristic) variants of the loss function are used. The most common, suggested in the original paper, is the “non-saturating” loss: rather than minimizing $\log(1-D(G(z)))$, the generator maximizes $\log(D(G(z)))$.
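The difference is easy to see in code. Here is a minimal sketch of the two generator losses (again assuming a discriminator that outputs raw logits); note that the toy example below uses the non-saturating form:

import torch
import torch.nn.functional as F

def saturating_g_loss(d_fake_logits):
    # the generator's term from the minimax equation, log(1 - D(G(z))):
    # its gradient vanishes when the discriminator confidently rejects
    # the fakes, which is typical early in training
    return torch.log(1 - torch.sigmoid(d_fake_logits)).mean()

def non_saturating_g_loss(d_fake_logits):
    # the heuristic replacement, -log(D(G(z))): same fixed point, but
    # strong gradients exactly when the generator is doing poorly
    return F.binary_cross_entropy_with_logits(
        d_fake_logits, torch.ones_like(d_fake_logits))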

In the image enhancement setting, the adversarial loss is used as an auxiliary term and these problems are mostly avoided. However, in certain settings like unpaired image-to-image translation, the training can run into the above problems.

Training a GAN on a toy dataset 🔗

The problems mentioned above make training traditional GANs a very finicky process; getting even toy examples to work can be laborious in spite of advances in training methods. In this section, we’ll quickly walk through an example of training a (vanilla) GAN in PyTorch to produce samples from a very simple distribution: a mixture of two 2D Gaussians. A Jupyter notebook with the full code is given here; note that the actual implementation differs slightly from what is shown below, and the notebook also includes implementations of several GAN variants to experiment with. While this example is very simple, it should provide a clear picture of what GANs do and the problems associated with them.

The problem setup is that we have a set of data that can be described with the 2D histogram shown in Figure 1.

Figure 1: The true distribution, a mixture of two 2D Gaussians, which we will try to learn with a GAN

We want to train a generator to produce samples according to this distribution. In this case, both the generator and discriminator only need a small series of fully-connected layers to capture the distribution; however, if the samples consisted of high-quality real images (for example, high-resolution face photos, as are commonly used in modern GAN research), then the networks would need to be much more complicated (convolutional) neural networks with many tricks to make the generator produce realistic samples.
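For this toy problem, the networks and optimizers might look something like the following sketch; the exact architectures and hyperparameters here are illustrative and differ from those in the notebook:

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# small fully-connected networks suffice for 2D samples
generator = nn.Sequential(
    nn.Linear(2, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 2),  # outputs a 2D point
).to(device)

discriminator = nn.Sequential(
    nn.Linear(2, 64), nn.LeakyReLU(0.2),
    nn.Linear(64, 64), nn.LeakyReLU(0.2),
    nn.Linear(64, 1),  # raw logit; the loss applies the sigmoid
).to(device)

D_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
G_opt = torch.optim.Adam(generator.parameters(), lr=1e-3)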

Let x_real be a batch of data from the above distribution; note that the following steps are taken every epoch. As mentioned before, the generator maps a noise vector to a realistic sample, so we’ll first create a batch of noise vectors like so:

# one noise vector per real sample, drawn from a standard normal
z = torch.randn_like(x_real)

Then, to train the discriminator network, we’ll compute losses on both the real samples and the discriminated fake samples and update the weights with their sum:

# discriminate real samples (target label: 1)
D_real = discriminator(x_real)
real_labels = torch.ones(D_real.size(0), 1, device=device)
loss_real = F.binary_cross_entropy_with_logits(D_real, real_labels)

# discriminate fake samples (target label: 0); no_grad keeps these
# gradients from flowing back into the generator
with torch.no_grad():
    x_fake = generator(z)
D_fake = discriminator(x_fake)
fake_labels = torch.zeros(D_fake.size(0), 1, device=device)
loss_fake = F.binary_cross_entropy_with_logits(D_fake, fake_labels)

# update discriminator weights (clearing stale gradients first)
D_opt.zero_grad()
D_loss = loss_fake + loss_real
D_loss.backward()
D_opt.step()

To update the generator, we’ll take the following steps:

# regenerate the fakes, this time tracking gradients, and train the
# generator to make the discriminator label them as real
# (the non-saturating loss discussed earlier)
x_fake = generator(z)
D_fake = discriminator(x_fake)
G_loss = F.binary_cross_entropy_with_logits(D_fake, real_labels)

# update generator weights (clearing stale gradients first)
G_opt.zero_grad()
G_loss.backward()
G_opt.step()

Ultimately, we’ll end up with something like what we see in Figure 2.

Figure 2: An example learned distribution where the random seed was set to 8 for training. The plot was generated by sampling noise vectors from a standard normal distribution and passing them through the generator.

Note that the generator captures one mode but not the other; this is the previously mentioned problem of mode collapse. If you change the random seed at the beginning of the notebook from 8 to 9, the generator captures the other mode of the distribution.

Figure 3: An example learned distribution where the random seed was set to 9 for training. The plot was generated with the same procedure as Figure 2.

In reality, the code above didn’t work well (the mode collapse was worse than what is shown in Figures 2 and 3), and a different network and training scheme had to be used to get the diversity shown in those plots; the principle of the training scheme, however, is the same. The plots themselves are actually fairly nice: in both cases, the mode that the generator learns matches the true distribution well.

To get a better feel for GANs, play around with the code in the Jupyter notebook and see how finicky the training process can be and how the degenerate conditions arise.

GANs for image-to-image translation 🔗

The adversarial framework of GANs is amenable to the task of image-to-image translation. There are two main directions in image-to-image translation, which we’ll discuss separately: paired and unpaired image-to-image translation.

Paired image-to-image translation 🔗

Paired image-to-image translation is used when you have examples of the exact translation that you want the generator network to perform. For example, if you have MR and CT images from the same subject, you can register the two images, and the network can learn the pixel- or voxel-wise correspondence between the two types of images. In this case, the GAN is used in an auxiliary role, and the generator is trained with two terms in its loss function. One is the adversarial loss term described previously: it trains the generator according to how well its outputs fool the discriminator, which tries to classify samples from the generator as fake. The other is a content loss term, similar to mean squared error, which takes advantage of the fact that paired examples provide a pixel- or voxel-wise correspondence to learn.

The two terms work together to make the translated image respect the geometry (i.e., anatomy) present in the input (due to the content term) while looking more realistic (due to the adversarial term). The presence of the content loss also makes training the network much easier; to the best of our knowledge, there are no real problems with mode collapse when using this scheme.

The most frequently used model in this space of image-to-image translation is pix2pix, which employs almost exactly what was described above but with an $L_1$ content loss term instead of mean squared error. ($L_1$ loss is commonly used in image-to-image translation because the resulting images are often qualitatively and quantitatively better.) Code for this model is publicly available and in frequent use, in both PyTorch and TensorFlow. See, for example, this paper for a use case of this framework for paired image-to-image translation.
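To make the two-term loss concrete, here is a minimal sketch of a pix2pix-style generator loss. The conditional discriminator (which sees the input and output concatenated along the channel dimension) and the weighting follow the pix2pix paper, but the names and details are illustrative:

import torch
import torch.nn.functional as F

lambda_l1 = 100.0  # pix2pix weights the L1 content term by 100

def generator_loss(generator, discriminator, x, y):
    # x: input image (e.g., MR); y: registered target image (e.g., CT)
    y_fake = generator(x)
    # adversarial term: reward outputs that the conditional
    # discriminator labels as real
    d_fake = discriminator(torch.cat([x, y_fake], dim=1))
    adv = F.binary_cross_entropy_with_logits(
        d_fake, torch.ones_like(d_fake))
    # content term: pixel-/voxel-wise L1 against the paired ground truth
    content = F.l1_loss(y_fake, y)
    return adv + lambda_l1 * content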

Unpaired image-to-image translation 🔗

More often than not, paired examples of the correspondence you want to learn are difficult to obtain; that is where unpaired image-to-image translation techniques come in. Suppose you want to learn how to turn an MR image into a CT image and only have access to an unpaired set of MR and CT images. You can approximate the mapping with a model called CycleGAN.

CycleGAN translates between two domains, $A$ and $B$, using two generator-discriminator pairs. One generator learns the mapping $G:A\to B$ and the other learns $F: B\to A$. One discriminator, $D_A$, learns to detect real examples of images from set $A$, and the other, $D_B$, detects real examples from set $B$. This framework might appear to be enough to perform the translation, but that is not the case, because a generator can arbitrarily change the input image to fool its discriminator; that is, it doesn’t need to respect the geometry (i.e., the anatomy) present in the image.

To encourage the generators to respect the geometry in the input image, the scheme adds an additional loss term, called a “cycle-consistency” loss, to the normal adversarial loss. The idea is that an input image should be recovered when cycled through both generators. In terms of math, for a sample $x$ from set $A$, $F(G(x)) \approx x$, and similarly, for a sample $y$ from set $B$, $G(F(y)) \approx y$. The cycle-consistency term is then the $L_1$ loss between $F(G(x))$ and $x$ as well as between $G(F(y))$ and $y$.
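A minimal sketch of that term in PyTorch, with generator names mirroring the math above (the weight is the one used in the CycleGAN paper; treat the rest as illustrative):

import torch.nn.functional as nnf  # aliased to keep F free for the generator

lambda_cyc = 10.0  # cycle-consistency weight from the CycleGAN paper

def cycle_consistency_loss(G, F, x, y):
    # G: A -> B, F: B -> A; x is from domain A, y is from domain B
    loss_a = nnf.l1_loss(F(G(x)), x)  # F(G(x)) should recover x
    loss_b = nnf.l1_loss(G(F(y)), y)  # G(F(y)) should recover y
    return lambda_cyc * (loss_a + loss_b)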

This clever trick has been shown to work well, but it has some practical difficulties. CycleGANs are still prone to mode collapse, and there are deeper problems with the method; namely, cycle consistency doesn’t guarantee that the generator will respect the geometry of the input image. Transforming a circle to a square and back, for example, is a cycle-consistent transformation. A recent paper has shown this problem in application: the authors trained a CycleGAN to perform an image-to-image translation task on healthy images and then input an image with a tumor, and the tumor was removed in the resulting image!

There has been a slew of papers exploring how to reduce the severity of this problem. Most encourage the network to maintain the geometry of the input image by adding other loss terms; however, there are no guarantees. Images produced with this method must be handled with care.

A paper using a CycleGAN for the problem of MR to CT image translation is here. Like pix2pix, this model is also publicly available (in fact, they are in the same package). See, for example, this paper for a use-case of a CycleGAN for unpaired image-to-image translation.

GANs for super-resolution 🔗

Super-resolution with GANs, like the paired image-to-image translation problem, comes down to optimizing the generator with two loss terms capturing different characteristics desired in the output; that is, the content and adversarial terms. The main approach to the problem is to take high-resolution examples, low-pass filter them so that they appear to be low-resolution images, and learn the mapping from the simulated low-resolution images to the high-resolution ones. The underlying idea is that a real low-resolution image can then be put into the network that has learned this mapping, and an approximate high-resolution image will be output.

The method then uses an almost identical approach to paired image-to-image translation, where the loss terms even express the same desiderata; an example can be seen here. However, unlike image-to-image translation, there are methods that use a single image for super-resolution (see this paper). In fact, because medical image volumes like MR and CT images are frequently acquired with anisotropic resolution, medical images are particularly well-suited for single-image super-resolution. For example, if an MR volume were acquired at 1x1x3 mm³ resolution, then the high-resolution in-plane images can be used to train a network using the simulated low-resolution method described above, and the network can then be applied in the low-resolution through-plane direction. See this paper for more information.
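As a rough illustration of the simulated-low-resolution idea, here is a sketch of how training pairs might be built from high-resolution in-plane slices; the function and the simple averaging filter are assumptions for illustration (a real pipeline would model the scanner’s slice profile):

import torch
import torch.nn.functional as F

def make_training_pair(hr_slices, factor=3):
    # hr_slices: (N, 1, H, W) high-resolution in-plane slices
    # factor: anisotropy factor (e.g., 3 for 1x1x3 mm data)
    # crude low-pass filter and downsampling along one axis
    lr = F.avg_pool2d(hr_slices, kernel_size=(factor, 1))
    # upsample back to the original grid so the generator learns a
    # same-size mapping from blurry inputs to sharp targets
    lr = F.interpolate(lr, size=hr_slices.shape[-2:], mode="bilinear",
                       align_corners=False)
    return lr, hr_slices  # (input, target) pair for training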

Super-resolved images, however, should be handled with caution. Real super-resolution, outside of very limited circumstances, is theoretically impossible: the acquisition procedure is non-invertible, so the true high-resolution image cannot be recovered with any meaningful guarantee. A super-resolution network can fill in likely high-frequency information (making the image appear as if it were acquired at high resolution), but these methods cannot provide the guarantees necessary to use super-resolved images in high-stakes situations like diagnosis. A super-resolved image can remove important real data from an image or add misleading data to it, even though the result looks like it contains more information because of the perceived higher resolution. In spite of these limitations, super-resolution is useful in image analysis pipelines; for example, to super-resolve an image so that it is the “same” resolution as the data a segmentation network was trained on, potentially improving segmentation performance.

Takeaways 🔗

We discussed GANs, their purpose, and the problems they face. In doing so, we implemented a toy example that showcased the utility of GANs and provided a sandbox in which to experiment with simple GAN methods. As can be seen in the implementation and as discussed above, GANs are finicky models to train and must be carefully corralled to be used as intended.

Image enhancement is an area in which GANs shine. They can add highly detailed features back into an image that wouldn’t otherwise be present with other, more traditional methods. While GANs cannot perform magic and truly enhance an image beyond what is theoretically possible, they can help enhance images to the point where the images are more useful for human viewing or image processing pipelines.

Regulatory Perspective 🔗

As of March 2021, we are not aware of any FDA guidance on the usage of GANs in medical devices. However, from our current understanding of medical device software and AI/ML regulations, we feel comfortable making the following recommendations:

  1. It is risky to use the generative portion of a GAN in a medical device. For example, using a GAN to fill in missing slices or to convert a T2 MRI sequence into a T2-FLAIR sequence is probably unacceptable.
  2. It is likely acceptable to use the discriminator part of a GAN in a medical device. For example, a GAN can be trained to discriminate between benign and malignant tumors; the generator part of the GAN can be thrown away, leaving only the discriminator to be used in a clinical environment. From a regulatory standpoint, we believe such a GAN is not fundamentally different from other binary classification architectures, such as one-shot learning with Siamese networks or the ubiquitous CNN followed by fully connected layers.
  3. It is probably acceptable to use a GAN for data augmentation, provided it is validated by clinical experts. You need to make sure the GAN outputs clinically representative images; for example, you need to make sure it is not getting too creative and generating CTs with three kidneys. We also advise you to turn off all data augmentation on your final testing data partition to reduce the likelihood of overfitting to the GAN-augmented images.
  4. Keep in mind that your product may not even be a medical device at all. For example, a GAN used to generate brain tumor images for a USMLE question bank is not considered a medical device. FDA oversight does not apply in these cases, but you will still need to prove your product works to potential customers. We still advise you to follow good engineering principles such as those described in our previous article.

Further reading 🔗

Yi, Xin, Ekta Walia, and Paul Babyn. “Generative adversarial network in medical imaging: A review.” Medical image analysis 58 (2019): 101552. Link to pre-print

Kazeminia, Salome, et al. “GANs for medical image analysis.” Artificial Intelligence in Medicine (2020): 101938. Link to pre-print

Wang, Zhengwei, Qi She, and Tomas E. Ward. “Generative adversarial networks in computer vision: A survey and taxonomy.” arXiv preprint arXiv:1906.01529 (2019). Link to paper

Goodfellow, Ian. “NIPS 2016 tutorial: Generative adversarial networks.” arXiv preprint arXiv:1701.00160 (2016). Link to paper

Creswell, Antonia, et al. “Generative adversarial networks: An overview.” IEEE Signal Processing Magazine 35.1 (2018): 53-65. Link to pre-print
