Implementing KL divergence in PyTorch



KL divergence, or Kullback-Leibler divergence, is a measure used in information theory and statistics to quantify the difference between two probability distributions. It is non-negative, equal to zero only when the two distributions are identical, and not symmetric: KL(P ‖ Q) generally differs from KL(Q ‖ P). Intuitively, KL(P ‖ Q) measures how much information is lost, on average, when Q is used to approximate P. Because it is not symmetric, it is not a true distance metric, even though it is often described informally as a "distance" between distributions.
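To make the definition concrete, here is a minimal sketch of the discrete case, KL(P ‖ Q) = Σ_x P(x) ln(P(x) / Q(x)), in plain Python. The helper `kl_discrete` and the two probability vectors are hypothetical illustrations, not part of the original example. The sketch shows that the divergence is zero for identical distributions and changes when the arguments are swapped.

```python
import math

def kl_discrete(p, q):
    """KL(p || q) in nats for two discrete distributions given as lists of probabilities.

    Assumes both lists sum to 1 and q[i] > 0 wherever p[i] > 0.
    """
    total = 0.0
    for p_i, q_i in zip(p, q):
        if p_i > 0:  # terms with p_i == 0 contribute 0 by convention
            total += p_i * math.log(p_i / q_i)
    return total

# Hypothetical example distributions
p = [0.5, 0.3, 0.2]
q = [0.2, 0.5, 0.3]

print(kl_discrete(p, q))  # about 0.2238
print(kl_discrete(q, p))  # about 0.1938: swapping the arguments changes the value
print(kl_discrete(p, p))  # 0.0 for identical distributions
```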

Let's take example values for two 3-dimensional Gaussian distributions, P1 and P2, each with a diagonal covariance matrix (that is, independent components), and use the element-wise formula for that case to compute KL(P1 ‖ P2).

Gaussian Distribution 1:

  • Mean vector (μ1): [2, 4, 6]
  • Standard deviation vector (σ1): [1, 2, 3]

Gaussian Distribution 2:

  • Mean vector (μ2): [3, 5, 8]
  • Standard deviation vector (σ2): [1.5, 2.5, 4]

First, compute the variance vectors for both distributions:

  • Variance vector (σ1²): [1, 4, 9]
  • Variance vector (σ2²): [2.25, 6.25, 16]
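A quick check of the variance arithmetic in plain Python (the list names are illustrative):

```python
# Standard deviations from the example above
std_1 = [1, 2, 3]
std_2 = [1.5, 2.5, 4]

# Variance is the square of the standard deviation, element by element
var_1 = [s ** 2 for s in std_1]
var_2 = [s ** 2 for s in std_2]

print(var_1)  # [1, 4, 9]
print(var_2)  # [2.25, 6.25, 16]
```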

Because both covariances are diagonal, KL(P1 ‖ P2) is a sum of per-dimension terms:

KL(P1 ‖ P2) = 0.5 * Σ_i [ (μ1_i - μ2_i)² / σ2_i² + σ1_i² / σ2_i² - 1 - ln(σ1_i²) + ln(σ2_i²) ]

where ln is the natural logarithm, so the result is measured in nats.
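A minimal sketch of this formula as a plain-Python function (the name `kl_diag_gaussians` is hypothetical). With x = σ1_i² / σ2_i², each bracketed term equals (μ1_i - μ2_i)² / σ2_i² + x - 1 - ln x, and x - 1 - ln x ≥ 0 for every x > 0, so no term can be negative. The calls below check two properties from the definition: identical distributions give 0, and swapping the arguments changes the result.

```python
import math

def kl_diag_gaussians(mu1, sigma1, mu2, sigma2):
    """KL(P1 || P2) in nats for two Gaussians with diagonal covariance.

    Each argument is a sequence of per-dimension means or standard deviations.
    """
    total = 0.0
    for m1, s1, m2, s2 in zip(mu1, sigma1, mu2, sigma2):
        v1, v2 = s1 ** 2, s2 ** 2
        # Bracketed term of the formula for this dimension (natural log)
        total += (m1 - m2) ** 2 / v2 + v1 / v2 - 1 - math.log(v1) + math.log(v2)
    return 0.5 * total

print(kl_diag_gaussians([0, 0], [1, 1], [0, 0], [1, 1]))  # 0.0 for identical distributions
print(kl_diag_gaussians([0], [1], [1], [2]))  # about 0.4431
print(kl_diag_gaussians([1], [2], [0], [1]))  # about 1.3069, so the order matters
```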

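For each component i (the bracketed term, before multiplying by 0.5):

  1. i = 0 => ((2 - 3)² / 2.25) + (1 / 2.25) - 1 - ln(1) + ln(2.25) = 0.6998
  2. i = 1 => ((4 - 5)² / 6.25) + (4 / 6.25) - 1 - ln(4) + ln(6.25) = 0.2463
  3. i = 2 => ((6 - 8)² / 16) + (9 / 16) - 1 - ln(9) + ln(16) = 0.3879

Each term is non-negative, as it must be, since each one is twice the KL divergence between two one-dimensional Gaussians. Summing the results from all components and multiplying by 0.5:

0.5 * (0.6998 + 0.2463 + 0.3879) = 0.6670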
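A minimal plain-Python check of the hand calculation (the variable names and the helper `per_dim_terms` are illustrative, not from the original). It computes the bracketed terms with the natural logarithm, which the KL formula requires, and, for comparison, with a base-10 logarithm. Swapping the base rescales only the log-variance terms and not the others, so the base-10 total (0.1486) is not a KL divergence in any unit; the negative term it produces is a red flag, because each bracketed term is twice a one-dimensional KL divergence and can never be negative.

```python
import math

# Example parameters from the text
mu1, sigma1 = [2, 4, 6], [1, 2, 3]
mu2, sigma2 = [3, 5, 8], [1.5, 2.5, 4]

def per_dim_terms(log_fn):
    """Bracketed term of the formula for each dimension, using the given logarithm."""
    terms = []
    for m1, s1, m2, s2 in zip(mu1, sigma1, mu2, sigma2):
        v1, v2 = s1 ** 2, s2 ** 2
        terms.append((m1 - m2) ** 2 / v2 + v1 / v2 - 1 - log_fn(v1) + log_fn(v2))
    return terms

ln_terms = per_dim_terms(math.log)
print([round(t, 4) for t in ln_terms])  # [0.6998, 0.2463, 0.3879]
print(round(0.5 * sum(ln_terms), 4))    # 0.667 (nats)

# Using log10 only in the log terms gives a number that is not a valid KL value
log10_terms = per_dim_terms(math.log10)
print([round(t, 4) for t in log10_terms])  # [0.2411, -0.0062, 0.0624]
print(round(0.5 * sum(log10_terms), 4))    # 0.1486
```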

Now let's implement the same computation in PyTorch:


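import torch

# Parameters of the two diagonal Gaussians from the example
set_1_mean = torch.tensor([2, 4, 6], dtype=torch.float32)
std_1 = torch.tensor([1, 2, 3], dtype=torch.float32)
set_2_mean = torch.tensor([3, 5, 8], dtype=torch.float32)
std_2 = torch.tensor([1.5, 2.5, 4], dtype=torch.float32)

def kl_div(mean_1, std_1, mean_2, std_2):
    """KL(P1 || P2) in nats for two Gaussians with diagonal covariance."""
    var_1 = torch.pow(std_1, 2)
    var_2 = torch.pow(std_2, 2)
    # Sum the per-dimension terms, then halve.
    # torch.log is the natural log; KL is defined with ln, not log10.
    return 0.5 * torch.sum(
        torch.pow(mean_1 - mean_2, 2) / var_2
        + var_1 / var_2
        - 1.0
        - torch.log(var_1)
        + torch.log(var_2)
    )

print(kl_div(set_1_mean, std_1, set_2_mean, std_2))  # tensor(0.6670)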

The result matches the hand calculation.
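As an independent cross-check, `torch.distributions` provides a closed-form KL divergence for pairs of `Normal` distributions. Applied to `Normal` objects with vector parameters, `kl_divergence` returns one value per dimension; summing them gives the KL divergence between the diagonal Gaussians. A sketch, assuming a standard PyTorch install:

```python
import torch
from torch.distributions import Normal, kl_divergence

mean_1 = torch.tensor([2.0, 4.0, 6.0])
std_1 = torch.tensor([1.0, 2.0, 3.0])
mean_2 = torch.tensor([3.0, 5.0, 8.0])
std_2 = torch.tensor([1.5, 2.5, 4.0])

# One KL value per dimension (each is half of a bracketed term in the formula);
# summing them gives KL(P1 || P2) for the diagonal Gaussians, in nats.
per_dim = kl_divergence(Normal(mean_1, std_1), Normal(mean_2, std_2))
print(per_dim.sum())  # tensor(0.6670)
```

Relying on the library's closed form avoids hand-coding the formula, at the cost of building distribution objects.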


More articles by Pasha Sheikh
