
Representation Learning with Mutual Information Maximization

f-GAN

2016 NIPS - f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization 1

f-divergence

Suppose we want to train a generative model $Q$ that generates data as realistic as possible, i.e., as close to the true data distribution $P$ as possible. In other words, we wish to minimize the f-divergence

$$D_f(P\|Q)=\int_{\mathcal{X}}q(x)\,f\!\left(\frac{p(x)}{q(x)}\right)dx=\mathbb{E}_{q(x)}\left[f\!\left(\frac{p(x)}{q(x)}\right)\right]$$

where $f:\mathbb{R}_+\to\mathbb{R}$ is convex, lower-semicontinuous, and satisfies $f(1)=0$.

Below is a table of the f-divergence family.

| Name | $D_f(P\Vert Q)$ | $f(u)$ | $T^*(u)$ |
|---|---|---|---|
| Total variation | $\frac{1}{2}\int\lvert p(x)-q(x)\rvert\,dx$ | $\frac{1}{2}\lvert u-1\rvert$ | $\frac{1}{2}\operatorname{sign}(u-1)$ |
| Kullback-Leibler (KL) | $\int p(x)\log\frac{p(x)}{q(x)}\,dx$ | $u\log u$ | $1+\log u$ |
| Reverse KL | $\int q(x)\log\frac{q(x)}{p(x)}\,dx$ | $-\log u$ | $-\frac{1}{u}$ |
| Pearson $\chi^2$ | $\int\frac{(q(x)-p(x))^2}{p(x)}\,dx$ | $(1-u)^2$ | $2(u-1)$ |
| Neyman $\chi^2$ | $\int\frac{(p(x)-q(x))^2}{q(x)}\,dx$ | $\frac{(1-u)^2}{u}$ | $1-\frac{1}{u^2}$ |
| Squared Hellinger | $\int\left(\sqrt{p(x)}-\sqrt{q(x)}\right)^2dx$ | $\left(\sqrt{u}-1\right)^2$ | $\left(\sqrt{u}-1\right)\frac{1}{\sqrt{u}}$ |
| Jeffrey | $\int(p(x)-q(x))\log\frac{p(x)}{q(x)}\,dx$ | $(u-1)\log u$ | $1+\log u-\frac{1}{u}$ |
| Jensen-Shannon | $\frac{1}{2}\int p(x)\log\frac{2p(x)}{p(x)+q(x)}+q(x)\log\frac{2q(x)}{p(x)+q(x)}\,dx$ | $-\frac{u+1}{2}\log\frac{1+u}{2}+\frac{u}{2}\log u$ | $\frac{1}{2}\log\frac{2u}{u+1}$ |
| $\alpha$-divergence | $\frac{1}{\alpha(\alpha-1)}\int p(x)\left[\left(\frac{q(x)}{p(x)}\right)^{\alpha}-1\right]-\alpha(q(x)-p(x))\,dx$ | $\frac{1}{\alpha(\alpha-1)}\left(u^{\alpha}-1-\alpha(u-1)\right)$ | $\frac{1}{\alpha-1}\left(u^{\alpha-1}-1\right)$ |
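As a quick sanity check, here is a small NumPy sketch (our own illustration, not from the paper) that evaluates $D_f$ directly from the definition on two discrete distributions and confirms the $f(u)$ entries for the forward and reverse KL:

```python
import numpy as np

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])

def f_divergence(p, q, f):
    """D_f(P || Q) = sum_x q(x) f(p(x) / q(x)) for discrete distributions."""
    return np.sum(q * f(p / q))

kl = f_divergence(p, q, lambda u: u * np.log(u))    # f(u) = u log u
rkl = f_divergence(p, q, lambda u: -np.log(u))      # f(u) = -log u
assert np.isclose(kl, np.sum(p * np.log(p / q)))    # matches forward KL
assert np.isclose(rkl, np.sum(q * np.log(q / p)))   # matches reverse KL
```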

Fenchel conjugate

The Fenchel conjugate of a function $f(x)$ is defined as $f^*(x^*)=\sup_{x\in\mathrm{dom}_f}\left\{\langle x,x^*\rangle-f(x)\right\}$.

We can easily verify that $f^*$ is convex and lower-semicontinuous. When $f$ is also convex and lower-semicontinuous, $f^{**}=f$.

Variational representation of the f-divergence

We now derive the variational lower bound on f-divergence:

$$\begin{aligned}
D_f(P\|Q)&=\int_{\mathcal{X}}q(x)\sup_{t\in\mathrm{dom}_{f^*}}\left\{t\,\frac{p(x)}{q(x)}-f^*(t)\right\}dx\\
&\geq\sup_{T\in\mathcal{T}}\left\{\int_{\mathcal{X}}p(x)\,T(x)\,dx-\int_{\mathcal{X}}q(x)\,f^*(T(x))\,dx\right\}\\
&=\sup_{T\in\mathcal{T}}\left\{\mathbb{E}_{p(x)}\left[T(x)\right]-\mathbb{E}_{q(x)}\left[f^*(T(x))\right]\right\}
\end{aligned}$$

where $\mathcal{T}$ is a class of functions $T:\mathcal{X}\to\mathbb{R}$. It is straightforward to see that the optimal $T^*(x)=f'\left(\frac{p(x)}{q(x)}\right)$ (the derivative of $f$, not to be confused with the conjugate $f^*$): substituting it into the bound, the supremum in the definition of $f^*$ is attained at $u=\frac{p(x)}{q(x)}$, so that $f^*(T^*(x))=\frac{p(x)}{q(x)}T^*(x)-f\left(\frac{p(x)}{q(x)}\right)$ and

$$D_f(P\|Q)\geq\mathbb{E}_{p(x)}\left[T^*(x)\right]-\mathbb{E}_{q(x)}\left[\frac{p(x)}{q(x)}T^*(x)-f\left(\frac{p(x)}{q(x)}\right)\right]=D_f(P\|Q)$$
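To make the tightness concrete, here is a short NumPy check (our own illustration) for the KL case, where $f(u)=u\log u$, $f'(u)=1+\log u$, and $f^*(t)=e^{t-1}$:

```python
import numpy as np

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])

t_star = 1.0 + np.log(p / q)                         # optimal critic T*(x) = f'(p/q)
bound = np.sum(p * t_star) - np.sum(q * np.exp(t_star - 1.0))
assert np.isclose(bound, np.sum(p * np.log(p / q)))  # bound is tight at T*
```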

The critical value $f'(1)$ can be interpreted as a classification threshold applied to $T(x)$ to distinguish between true and generated samples.

f-GAN Objective

We parameterize the generator $Q$ with parameters $\theta$ and the discriminator $T$ with parameters $\omega$. The f-GAN objective is then defined as

$$\min_\theta\max_\omega F(\theta,\omega)=\mathbb{E}_{p(x)}\left[T_\omega(x)\right]-\mathbb{E}_{q_\theta(x)}\left[f^*(T_\omega(x))\right].$$

To account for the different $\mathrm{dom}_{f^*}$ of the various $f$-divergences, we further decompose $T_\omega(x)$ as $T_\omega(x)=g_f(V_\omega(x))$, where $V_\omega:\mathcal{X}\to\mathbb{R}$ is a neural network and $g_f:\mathbb{R}\to\mathrm{dom}_{f^*}$ is an output activation function.

| Name | $g_f$ | $\mathrm{dom}_{f^*}$ | $f^*(t)$ | $f'(1)$ |
|---|---|---|---|---|
| Total variation | $\frac{1}{2}\tanh(v)$ | $[-\frac{1}{2},\frac{1}{2}]$ | $t$ | $0$ |
| Kullback-Leibler (KL) | $v$ | $\mathbb{R}$ | $e^{t-1}$ | $1$ |
| Reverse KL | $-e^{-v}$ | $\mathbb{R}_-$ | $-1-\log(-t)$ | $-1$ |
| Pearson $\chi^2$ | $v$ | $\mathbb{R}$ | $\frac{1}{4}t^2+t$ | $0$ |
| Neyman $\chi^2$ | $1-e^{-v}$ | $(-\infty,1)$ | $2-2\sqrt{1-t}$ | $0$ |
| Squared Hellinger | $1-e^{-v}$ | $(-\infty,1)$ | $\frac{t}{1-t}$ | $0$ |
| Jeffrey | $v$ | $\mathbb{R}$ | $W(e^{1-t})+\frac{1}{W(e^{1-t})}+t-2$ | $0$ |
| Jensen-Shannon | $\frac{\log 2}{2}-\frac{1}{2}\log(1+e^{-v})$ | $(-\infty,\frac{\log 2}{2})$ | $-\frac{1}{2}\log(2-e^{2t})$ | $0$ |
| $\alpha$-div. ($\alpha\in(0,1)$) | $\frac{1}{1-\alpha}-\log(1+e^{-v})$ | $(-\infty,\frac{1}{1-\alpha})$ | $\frac{1}{\alpha}\left(t(\alpha-1)+1\right)^{\frac{\alpha}{\alpha-1}}-\frac{1}{\alpha}$ | $0$ |
| $\alpha$-div. ($\alpha>1$) | $v$ | $\mathbb{R}$ | $\frac{1}{\alpha}\left(t(\alpha-1)+1\right)^{\frac{\alpha}{\alpha-1}}-\frac{1}{\alpha}$ | $0$ |

where $W$ is the Lambert $W$ (product logarithm) function.
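The following PyTorch sketch puts the pieces together for the (halved) Jensen-Shannon entry of the table above, with $g_f(v)=\frac{\log 2}{2}-\frac{1}{2}\log(1+e^{-v})$ and $f^*(t)=-\frac{1}{2}\log(2-e^{2t})$. It is our own minimal illustration; the networks, dimensions, learning rates, and data batch are placeholders, not the paper's setup:

```python
import math
import torch
import torch.nn.functional as F

G = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.ReLU(),
                        torch.nn.Linear(64, 2))            # generator q_theta
V = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.ReLU(),
                        torch.nn.Linear(64, 1))            # variational net V_omega
opt_g = torch.optim.Adam(G.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(V.parameters(), lr=1e-4)

def g_f(v):                                   # output activation, maps into dom f*
    return 0.5 * math.log(2.0) - 0.5 * F.softplus(-v)

def f_star(t):                                # conjugate of the halved JS divergence
    return -0.5 * torch.log(2.0 - torch.exp(2.0 * t))

def objective(x_real):                        # F(theta, omega) on one batch
    t_real = g_f(V(x_real))
    t_fake = g_f(V(G(torch.randn(x_real.size(0), 16))))
    return t_real.mean() - f_star(t_fake).mean()

x_real = torch.randn(32, 2)                   # stand-in for a real data batch
loss_d = -objective(x_real)                   # discriminator ascends F
opt_d.zero_grad(); loss_d.backward(); opt_d.step()
loss_g = objective(x_real)                    # generator descends F
opt_g.zero_grad(); loss_g.backward(); opt_g.step()
```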

Mutual Information Neural Estimator (MINE)

2018 ICML - MINE: Mutual Information Neural Estimation 2

MINE has two variants termed MINE and MINE-f. The former uses the Donsker-Varadhan representation of the KL divergence, which results in a tighter estimator; the latter uses the f-divergence representation described above.

The Donsker-Varadhan representation of KL

$$D_{KL}(P\|Q)\geq\sup_{T\in\mathcal{T}}\left\{\mathbb{E}_{p(x)}\left[T(x)\right]-\log\mathbb{E}_{q(x)}\left[e^{T(x)}\right]\right\}$$

Proof: Consider the Gibbs distribution $g(x)=\frac{1}{Z}q(x)e^{T(x)}$, where $Z=\mathbb{E}_{q(x)}\left[e^{T(x)}\right]$. Then

$$\begin{aligned}
\Delta&:=D_{KL}(P\|Q)-\left(\mathbb{E}_{p(x)}\left[T(x)\right]-\log\mathbb{E}_{q(x)}\left[e^{T(x)}\right]\right)\\
&=D_{KL}(p(x)\|q(x))-\mathbb{E}_{p(x)}\left[T(x)\right]+\log Z\\
&=D_{KL}(p(x)\|q(x))-\mathbb{E}_{p(x)}\left[T(x)-\log Z\right]\\
&=\mathbb{E}_{p(x)}\log\frac{p(x)}{q(x)}-\mathbb{E}_{p(x)}\log\frac{g(x)}{q(x)}\\
&=D_{KL}(p(x)\|g(x))\geq 0
\end{aligned}$$

where $\mathcal{T}$ is a class of functions $T:\mathcal{X}\to\mathbb{R}$ such that the two expectations are finite, and the fourth line uses $T(x)-\log Z=\log\frac{g(x)}{q(x)}$. Equality holds when $g(x)\equiv p(x)$, i.e., $T(x)=\log\frac{p(x)}{q(x)}+C$ for some constant $C$.

The f-divergence representation of KL

Adopting the variational lower bound for f-divergence, we have

$$D_{KL}(P\|Q)\geq\sup_{T\in\mathcal{T}}\left\{\mathbb{E}_{p(x)}\left[T(x)\right]-\mathbb{E}_{q(x)}\left[e^{T(x)-1}\right]\right\}$$

and the optimal $T^*(x)=1+\log\frac{p(x)}{q(x)}$.
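The two representations can be compared numerically. For any fixed $T$, the Donsker-Varadhan bound dominates this one, since $\log z\leq z/e$ for all $z>0$ implies $\log\mathbb{E}_{q(x)}\left[e^{T(x)}\right]\leq\mathbb{E}_{q(x)}\left[e^{T(x)-1}\right]$; this is why MINE is the tighter estimator. A small NumPy check (our own illustration), using a deliberately suboptimal critic:

```python
import numpy as np

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])
t = 1.0 + np.log(p / q) + 0.3 * np.array([1.0, -1.0, 0.5])  # perturbed critic

dv_bound = np.sum(p * t) - np.log(np.sum(q * np.exp(t)))    # Donsker-Varadhan
f_bound = np.sum(p * t) - np.sum(q * np.exp(t - 1.0))       # f-representation
kl = np.sum(p * np.log(p / q))
assert f_bound <= dv_bound <= kl                            # DV is tighter
```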

Estimating Mutual Information

$$I(X;Z)=D_{KL}\left(p(x,z)\,\|\,p(x)p(z)\right)\geq\sup_{\theta\in\Theta}\left\{\mathbb{E}_{p(x,z)}\left[T_\theta(x,z)\right]-\log\mathbb{E}_{p(x)p(z)}\left[e^{T_\theta(x,z)}\right]\right\}$$

We estimate the expectations with empirical samples

$$\widehat{I(X;Z)}_n=\sup_{\theta\in\Theta}V(\theta)=\sup_{\theta\in\Theta}\left\{\mathbb{E}_{p^{(n)}(x,z)}\left[T_\theta(x,z)\right]-\log\mathbb{E}_{p^{(n)}(x)\,\hat{p}^{(n)}(z)}\left[e^{T_\theta(x,z)}\right]\right\}$$

When using stochastic gradient descent (SGD), the gradient update of MINE

$$\nabla_\theta V(\theta)=\mathbb{E}_B\left[\nabla_\theta T_\theta\right]-\frac{\mathbb{E}_B\left[e^{T_\theta}\,\nabla_\theta T_\theta\right]}{\mathbb{E}_B\left[e^{T_\theta}\right]}$$

where the expectations are taken over a minibatch $B$, is a biased estimate of the full-batch gradient: the minibatch estimate of $\mathbb{E}\left[e^{T_\theta}\right]$ appears inside a division, and the expectation of a ratio is not the ratio of expectations. This is corrected by replacing the denominator with an exponential moving average across batches. For MINE-f, the SGD gradient is unbiased.
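A minimal PyTorch sketch of MINE with this moving-average correction (our own paraphrase of the idea; the network size, learning rate, and decay constant are placeholder choices):

```python
import torch

T = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.ReLU(),
                        torch.nn.Linear(64, 1))   # statistics network T_theta
opt = torch.optim.Adam(T.parameters(), lr=1e-4)
ema, decay = None, 0.99                           # moving average of E[e^T]

def mine_step(x, z):
    global ema
    z_shuf = z[torch.randperm(z.size(0))]         # pairs drawn from p(x)p(z)
    t_joint = T(torch.cat([x, z], dim=1)).mean()
    e_t_marg = T(torch.cat([x, z_shuf], dim=1)).exp().mean()
    ema = e_t_marg.detach() if ema is None else \
        decay * ema + (1 - decay) * e_t_marg.detach()
    # the gradient of (t_joint - e_t_marg / ema) is the bias-corrected update:
    # E_B[grad T] - E_B[e^T grad T] / EMA[e^T]
    loss = -(t_joint - e_t_marg / ema)
    opt.zero_grad(); loss.backward(); opt.step()
    return (t_joint - e_t_marg.log()).item()      # current MI estimate

x = torch.randn(128, 1)
z = x + 0.1 * torch.randn(128, 1)                 # strongly dependent (x, z)
estimate = mine_step(x, z)
```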

Contrastive Predictive Coding (CPC) and the InfoNCE Loss

2010 AISTATS - Noise-contrastive estimation: A new estimation principle for unnormalized statistical models 3

2018 NeurIPS - Representation Learning with Contrastive Predictive Coding 4

Noise-Contrastive Estimation (NCE)

Suppose we have observed data $x\sim p_d(\cdot)$ and we want to estimate a model from a family $\{p_m(\cdot;\alpha)\}_\alpha$, where $\alpha$ is the model parameter. The challenge is that it is often more convenient to define an unnormalized model $p_m^0$ such that

$$p_m(\cdot;\alpha)=\frac{p_m^0(\cdot;\alpha)}{Z(\alpha)},\qquad\text{where }Z(\alpha)=\int p_m^0(u;\alpha)\,du.$$

The integral $Z(\alpha)$ is rarely analytically tractable, and if the data is high-dimensional, numerical integration is difficult. We therefore include the normalization constant as an additional model parameter $c$, which stands for $-\log Z(\alpha)$, so that

$$\log p_m(\cdot;\theta)=\log p_m^0(\cdot;\alpha)+c,\qquad\text{where }\theta=\{\alpha,c\}.$$

Performing Maximum Likelihood Estimation (MLE) on this objective is not feasible, as $c$ would be pushed to infinity. Instead, we learn to discriminate between the data $x$ and some artificially generated noise $y\sim p_n$. With $T$ positive (data) and $T$ negative (noise) examples, we aim to correctly classify each of them, and thus define the NCE objective as

$$J_T(\theta)=\frac{1}{2T}\sum_{t=1}^{T}\left[\log h(x_t;\theta)+\log\left(1-h(y_t;\theta)\right)\right]=\frac{1}{2T}\sum_{t=1}^{T}\left[\log\frac{p_m(x_t;\theta)}{p_m(x_t;\theta)+p_n(x_t)}+\log\frac{p_n(y_t)}{p_m(y_t;\theta)+p_n(y_t)}\right]$$

where $h(u;\theta)=\sigma\left(\log p_m(u;\theta)-\log p_n(u)\right)$ and $\sigma$ is the logistic sigmoid.

This blog post (in Chinese) shows by a gradient calculation that when the number of negative samples approaches infinity, the NCE gradient equals the MLE gradient.
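A toy PyTorch sketch of the NCE objective (our own illustration; the 1-D data, the standard-normal noise $p_n$, and the unnormalized Gaussian model with learnable log-normalizer $c$ are all placeholder choices):

```python
import torch
import torch.nn.functional as F

alpha = torch.nn.Parameter(torch.zeros(2))        # model params: mean, log-std
c = torch.nn.Parameter(torch.zeros(1))            # learned log-normalizer
noise = torch.distributions.Normal(0.0, 1.0)      # noise distribution p_n

def log_pm(u):                                    # log p_m^0(u; alpha) + c
    mean, log_std = alpha[0], alpha[1]
    return -0.5 * ((u - mean) / log_std.exp()) ** 2 + c

def nce_loss(x, y):
    # h(u; theta) = sigmoid(log p_m(u; theta) - log p_n(u))
    logit_x = log_pm(x) - noise.log_prob(x)       # should be classified as data
    logit_y = log_pm(y) - noise.log_prob(y)       # should be classified as noise
    return -(F.logsigmoid(logit_x).mean() + F.logsigmoid(-logit_y).mean()) / 2

x = 0.5 + 0.8 * torch.randn(256)                  # observed data
y = noise.sample((256,))                          # same number of noise samples
loss = nce_loss(x, y)
loss.backward()                                   # then step any optimizer over {alpha, c}
```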

Contrastive Predictive Coding

Let $\{x_t\}$ be a sequence of observations, $z_t=g_{\mathrm{enc}}(x_t)$ the encoded latent representation at time step $t$, and $c_t=g_{\mathrm{ar}}(z_{\leq t})$ the summarized context (global, ar for auto-regressive) latent representation at time step $t$. Given a set $X=\{x_1,\ldots,x_N\}$ of $N$ random samples containing one positive sample from $p(x_{t+k}|c_t)$ and $N-1$ negative samples from the 'proposal' distribution $p(x_{t+k})$, we wish to preserve the mutual information between the $k$-step-later input $x_{t+k}$ and the current context $c_t$, by trying to identify the positive sample among all the samples:

$$p(d=i\,|\,X,c_t)=\frac{p(x_i|c_t)\prod_{l\neq i}p(x_l)}{\sum_{j=1}^{N}p(x_j|c_t)\prod_{l\neq j}p(x_l)}=\frac{\frac{p(x_i|c_t)}{p(x_i)}}{\sum_{j=1}^{N}\frac{p(x_j|c_t)}{p(x_j)}}=\frac{f_k(x_i,c_t)}{\sum_{j=1}^{N}f_k(x_j,c_t)}$$

where $f_k(x_{t+k},c_t)=C\cdot\frac{p(x_{t+k}|c_t)}{p(x_{t+k})}$ and $C$ is an arbitrary constant. Note that $f_k$ is unnormalized and can be parameterized by a simple log-bilinear model

$$f_k(x_{t+k},c_t)=\exp\left(z_{t+k}^\top W_k\,c_t\right).$$

To maximize our contrastive predictive capabilities, we minimize the following InfoNCE loss:

$$\mathcal{L}_{\mathrm{InfoNCE}}=-\mathbb{E}_X\left[\log\frac{f_k(x_{t+k},c_t)}{\sum_{x_j\in X}f_k(x_j,c_t)}\right]$$
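In practice this is just a cross-entropy over similarity scores. A minimal PyTorch sketch with the log-bilinear critic (our own illustration; the batch construction and dimensions are placeholders, with row $i$ of `z` being the positive sample for row $i$ of `c`):

```python
import math
import torch
import torch.nn.functional as F

d_z, d_c, N = 32, 16, 64
W = torch.nn.Parameter(0.01 * torch.randn(d_z, d_c))   # W_k of the critic
z = torch.randn(N, d_z)                                # encodings z_{t+k}
c = torch.randn(N, d_c)                                # contexts c_t

logits = z @ W @ c.t()                     # logits[i, j] = z_i^T W c_j
labels = torch.arange(N)                   # positive pairs sit on the diagonal
loss = F.cross_entropy(logits, labels)     # InfoNCE loss
mi_lower_bound = math.log(N) - loss        # the bound I >= log N - L below
```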

Relation with Mutual Information

$$I(x_{t+k};c_t)\geq\log N-\mathcal{L}_{\mathrm{InfoNCE}}$$

Proof:

$$\begin{aligned}
\mathcal{L}_{\mathrm{InfoNCE}}&=\mathbb{E}_X\log\left[1+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}\sum_{j\neq t+k}\frac{p(x_j|c_t)}{p(x_j)}\right]\\
&\approx\mathbb{E}_X\log\left[1+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}(N-1)\,\mathbb{E}_{x_j\in X_{\mathrm{neg}}}\frac{p(x_j|c_t)}{p(x_j)}\right]\\
&=\mathbb{E}_X\log\left[1+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}(N-1)\right]\\
&\geq\mathbb{E}_X\log\left[N\,\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}\right]\\
&=-I(x_{t+k};c_t)+\log N
\end{aligned}$$

Note that the approximation becomes more accurate, and the bound correspondingly tighter, as the number of negative samples $N$ increases.

Relation with MINE

Let $F(x,c)=\log f(x,c)$; then

$$\begin{aligned}
\mathcal{L}_{\mathrm{InfoNCE}}&=-\mathbb{E}_X\left[\log\frac{f_k(x_{t+k},c_t)}{\sum_{x_j\in X}f_k(x_j,c_t)}\right]\\
&=-\mathbb{E}_X\,F(x_{t+k},c_t)+\mathbb{E}_X\log\left(e^{F(x_{t+k},c_t)}+\sum_{x_j\in X_{\mathrm{neg}}}e^{F(x_j,c_t)}\right)\\
&\approx-\mathbb{E}_X\,F(x_{t+k},c_t)+\mathbb{E}_{c_t}\log\sum_{x_j\in X_{\mathrm{neg}}}e^{F(x_j,c_t)}\\
&=-\mathbb{E}_X\,F(x_{t+k},c_t)+\mathbb{E}_{c_t}\left[\log\frac{1}{N-1}\sum_{x_j\in X_{\mathrm{neg}}}e^{F(x_j,c_t)}\right]+\log(N-1)
\end{aligned}$$

which, up to the additive constant $\log(N-1)$, is equivalent to the MINE estimator:

$$\widehat{I(X;Z)}_n=\sup_{\theta\in\Theta}V(\theta)=\sup_{\theta\in\Theta}\left\{\mathbb{E}_{p^{(n)}(x,z)}\left[T_\theta(x,z)\right]-\log\mathbb{E}_{p^{(n)}(x)\,\hat{p}^{(n)}(z)}\left[e^{T_\theta(x,z)}\right]\right\}$$

Deep InfoMax (DIM)

2019 ICLR - Learning deep representations by mutual information estimation and maximization 5

Deep InfoMax is a principled framework for training a continuous and (almost everywhere) differentiable encoder $E_\psi:\mathcal{X}\to\mathcal{Z}$, with neural network parameters $\psi\in\Psi$, to maximize mutual information between its input and output.

Assume that we are given a set of training examples on an input space, $\mathbf{X}:=\{x^{(i)}\in\mathcal{X}\}_{i=1}^{N}$, with empirical probability distribution $\mathbb{P}$. We define $\mathbb{U}_{\psi,\mathbb{P}}$ as the marginal distribution of $z=E_\psi(x)$ where $x$ is sampled from $\mathbb{P}$, i.e., the push-forward measure $\mathbb{U}_{\psi,\mathbb{P}}=\mathbb{P}\circ E_\psi^{-1}$.

We assert our encoder should be trained according to the following criteria:

  • Local and global mutual information maximization
  • Statistical constraints (matching a prior $v(z)$ in the latent space):

$$\min_\psi\;-I(X;Z)+\lambda\,D_{KL}\left(v(z)\,\|\,u(z)\right)$$

As a preliminary, we introduce the local feature encoder $C_\psi$, the global feature encoder $E_\psi=f_\psi\circ C_\psi$, and the discriminator $T_{\psi,\omega}=D_\omega\circ g\circ(C_\psi,E_\psi)$, where $D_\omega$ is a neural classifier and $g$ is a function that combines the local and global features.

The overall DIM objective consists of three parts, global MI, local MI and statistical constraints.

$$\max_{\omega_1,\omega_2,\psi}\left(\alpha\,\widehat{I}_{\omega_1,\psi}\left(X;E_\psi(X)\right)+\frac{\beta}{M^2}\sum_{i=1}^{M^2}\widehat{I}_{\omega_2,\psi}\left(X^{(i)};E_\psi(X)\right)\right)+\min_\psi\max_\phi\;\gamma\,\widehat{D}_\phi\left(\mathbb{V}\,\|\,\mathbb{U}_{\psi,\mathbb{P}}\right)$$

In the following sections, we first introduce how to enforce the statistical constraints $\widehat{D}_\phi$ and local MI maximization, then discuss objectives for general MI maximization $\widehat{I}$.

Statistical Constraints

Matching the output of the encoder to a prior.

Why use adversarial objectives for KL regularization?

Here we could also use VAE-style prior regularization, $\min D_{KL}\left(q(z|x)\,\|\,p(z)\right)$, but this requires that for every data point $x$, its posterior $q(z|x)$ be close to $p(z)$. This encourages $q(z)$ to pick the modes of $p(z)$ rather than to cover the whole distribution of $p(z)$. See the Adversarial AutoEncoders paper for more details.

DIM imposes statistical constraints onto learned representations by implicitly training the encoder so that the push-forward distribution $\mathbb{U}_{\psi,\mathbb{P}}$ matches a prior $\mathbb{V}$. Following the variational representation of the Jensen-Shannon divergence, we optimize this objective by

$$\min_\psi\max_\phi\,D_{JS}\left(\mathbb{V}\,\|\,\mathbb{U}_{\psi,\mathbb{P}}\right)=\mathbb{E}_{v(z)}\left[\log D_\phi(z)\right]+\mathbb{E}_{p(x)}\left[\log\left(1-D_\phi(E_\psi(x))\right)\right]$$

Note that the discriminator $D_\phi$ operates in the latent space rather than the input space.
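A compact PyTorch sketch of this adversarial prior matching (our own illustration; the encoder and discriminator architectures, dimensions, and the uniform prior are placeholder choices):

```python
import torch
import torch.nn.functional as F

enc = torch.nn.Sequential(torch.nn.Linear(784, 256), torch.nn.ReLU(),
                          torch.nn.Linear(256, 64))      # encoder E_psi
disc = torch.nn.Sequential(torch.nn.Linear(64, 256), torch.nn.ReLU(),
                           torch.nn.Linear(256, 1))      # discriminator D_phi
opt_e = torch.optim.Adam(enc.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4)

x = torch.rand(32, 784)                   # stand-in input batch
z_prior = torch.rand(32, 64)              # samples from the prior V (uniform)
z_enc = enc(x)                            # samples from U_{psi, P}

# discriminator step: maximize E_v[log D(z)] + E_p[log(1 - D(E(x)))]
loss_d = F.binary_cross_entropy_with_logits(disc(z_prior), torch.ones(32, 1)) + \
         F.binary_cross_entropy_with_logits(disc(z_enc.detach()), torch.zeros(32, 1))
opt_d.zero_grad(); loss_d.backward(); opt_d.step()

# encoder step: push U_{psi, P} toward V by fooling the discriminator
loss_e = F.binary_cross_entropy_with_logits(disc(z_enc), torch.ones(32, 1))
opt_e.zero_grad(); loss_e.backward(); opt_e.step()
```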

Local MI Maximization

Maximizing the mutual information between the encoder's input and output may not be sufficient for learning useful representations. We propose to maximize instead the average MI between the high-level representation and local patches of the image. Because the same global representation is encouraged to have high MI with all the patches, this favours encoding aspects of the data that are shared across patches.

Maximizing mutual information between local features and global features.

First we encode the input to a feature map $C_\psi(x)=\{C_\psi^{(i)}\}_{i=1}^{M\times M}$ that reflects useful structure in the data (e.g., spatial locality). Next, we summarize this local feature map into a global feature $E_\psi(x)=f_\psi\circ C_\psi(x)$. We then define our MI estimator on global/local pairs, maximizing the average estimated MI:

$$\max_{\omega,\psi}\frac{1}{M^2}\sum_{i=1}^{M^2}\widehat{I}_{\omega,\psi}\left(C_\psi^{(i)}(X);E_\psi(X)\right).$$

MI Maximization Objectives

The Donsker-Varadhan Objective

This lower bound on the MI is based on the Donsker-Varadhan representation of the KL divergence. It is the tightest possible bound on the KL divergence, but it is less stable to optimize and requires many negative samples.

$$I(X;Z)\geq\widehat{I}_{\psi,\omega}^{(DV)}(X;Z)=\mathbb{E}_{p(x,z)}\left[T_{\psi,\omega}(x,z)\right]-\log\mathbb{E}_{p(x)p(z)}\left[e^{T_{\psi,\omega}(x,z)}\right]$$

The Jensen-Shannon Objective

Since we are not concerned with the precise value of the mutual information, but are primarily interested in maximizing it, we can instead optimize the Jensen-Shannon divergence. This objective is stable to optimize and requires few negative samples, but it is a looser bound on the true mutual information.

Following the f-GAN formulation, with output activation $g_f(v)=-\log(1+e^{-v})$ and conjugate function $f^*(t)=-\log(1-e^t)$, we define the following objective:

$$\begin{aligned}
\widehat{I}_{\psi,\omega}^{(JS)}\left(X;E_\psi(X)\right)&=\widehat{D}_{JS}\left(p(x,E_\psi(x))\,\|\,p(x)p(E_\psi(x))\right)\\
&=\mathbb{E}_{p(x)}\left[\tilde{T}_{\psi,\omega}(x,E_\psi(x))\right]-\mathbb{E}_{p(x)p(x')}\left[f^*\left(\tilde{T}_{\psi,\omega}(x',E_\psi(x))\right)\right]\\
&=\mathbb{E}_{p(x)}\left[-\log\left(1+e^{-T_{\psi,\omega}(x,E_\psi(x))}\right)\right]-\mathbb{E}_{p(x)p(x')}\left[\log\left(1+e^{T_{\psi,\omega}(x',E_\psi(x))}\right)\right]\\
&=\mathbb{E}_{p(x)}\left[\log\sigma\left(T_{\psi,\omega}(x,E_\psi(x))\right)\right]+\mathbb{E}_{p(x)p(x')}\left[\log\left(1-\sigma\left(T_{\psi,\omega}(x',E_\psi(x))\right)\right)\right]
\end{aligned}$$

where $\tilde{T}=g_f\circ T$ is the discriminator output after the activation $g_f$. In section A.1 of the DIM paper 5, the authors show theoretically and empirically that $\max_{\psi,\omega}D_{JS}\left(p(x,z)\,\|\,p(x)p(z)\right)$ is indeed a good maximizer of $I(X;Z)$.
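In code, the last line above is just two softplus terms, since $\log\sigma(T)=-\mathrm{softplus}(-T)$ and $\log(1-\sigma(T))=-\mathrm{softplus}(T)$. A minimal sketch (our own illustration; `t_pos` and `t_neg` stand in for discriminator scores on joint and product-of-marginals pairs):

```python
import torch
import torch.nn.functional as F

def js_mi_objective(t_pos, t_neg):
    """E[log sigma(T)] + E[log(1 - sigma(T'))], to be maximized."""
    return -F.softplus(-t_pos).mean() - F.softplus(t_neg).mean()

t_pos = torch.randn(64, 1)   # T_{psi,omega}(x, E_psi(x)) on joint pairs
t_neg = torch.randn(64, 1)   # T_{psi,omega}(x', E_psi(x)) on negative pairs
objective = js_mi_objective(t_pos, t_neg)   # ascend w.r.t. psi and omega
```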

The InfoNCE Objective

This objective uses noise-contrastive estimation to bound mutual information. It obtains strong results, but requires many negative samples.

$$\widehat{I}_{\psi,\omega}^{(\mathrm{InfoNCE})}\left(X;E_\psi(X)\right)=\mathbb{E}_{p(x,z)}\left[T_{\psi,\omega}(x,E_\psi(x))-\mathbb{E}_{p(x')}\left[\log\sum_{x'}e^{T_{\psi,\omega}(x',E_\psi(x))}\right]\right]$$

Deep Graph Infomax

2019 ICLR - Deep Graph Infomax 6

Deep Graph Infomax (DGI) is a general approach for learning node representations within graph-structured data in an unsupervised manner. It relies on maximizing mutual information between patch representations and corresponding high-level summaries of graphs.

We first introduce

  • The encoder $\mathcal{E}:\mathbb{R}^{N\times F_{in}}\times\mathbb{R}^{N\times N}\to\mathbb{R}^{N\times F}$ such that $\mathcal{E}(\mathbf{X},\mathbf{A})=\mathbf{H}=(\vec{h}_1,\ldots,\vec{h}_N)$ produces node embeddings (or patch representations), where $\vec{h}_i$ summarizes a patch of the graph centered around node $i$.
  • The readout function $\mathcal{R}:\mathbb{R}^{N\times F}\to\mathbb{R}^{F}$, which summarizes the obtained patch representations into a graph-level representation $\vec{s}=\mathcal{R}(\mathcal{E}(\mathbf{X},\mathbf{A}))$. It is implemented as a sigmoid after a mean: $\mathcal{R}(\mathbf{H})=\sigma\left(\frac{1}{N}\sum_{i=1}^{N}\vec{h}_i\right)$.
  • The discriminator $\mathcal{D}:\mathbb{R}^{F}\times\mathbb{R}^{F}\to\mathbb{R}$ such that $\mathcal{D}(\vec{h}_i,\vec{s})$ represents the logit score assigned to this patch-summary pair (it should be higher for patches contained within the summary). It is implemented as a bilinear function $\mathcal{D}(\vec{h}_i,\vec{s})=\vec{h}_i^\top\mathbf{W}\vec{s}$.
  • Negative samples are generated by pairing the summary vector $\vec{s}$ of a graph with patch representations $\tilde{\vec{h}}_j$ from another graph $(\tilde{\mathbf{X}},\tilde{\mathbf{A}})$. This alternative graph is obtained as another element of the training set in a multi-graph setting, or by an explicit corruption function $(\tilde{\mathbf{X}},\tilde{\mathbf{A}})=\mathcal{C}(\mathbf{X},\mathbf{A})$ that permutes the node feature matrix $\mathbf{X}$ row-wise.

Next we introduce the DGI objective for one training graph $G=(\mathbf{X},\mathbf{A})$, based on the Jensen-Shannon objective for Deep InfoMax:

$$\max\;\mathcal{L}=\frac{1}{N+M}\left(\sum_{i=1}^{N}\mathbb{E}_{x_i\in V(G)}\left[\log\sigma\left(\mathcal{D}(\vec{h}_i,\vec{s})\right)\right]+\sum_{j=1}^{M}\mathbb{E}_{\tilde{x}_j\in V(\tilde{G})}\left[\log\left(1-\sigma\left(\mathcal{D}(\tilde{\vec{h}}_j,\vec{s})\right)\right)\right]\right)$$
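A minimal sketch of this loss in PyTorch (our own paraphrase; the GNN encoder is omitted, so `H` and `H_tilde` below are random stand-ins for the patch representations of the real and corrupted graphs):

```python
import torch
import torch.nn.functional as F

N, M, F_dim = 100, 100, 64
H = torch.randn(N, F_dim, requires_grad=True)     # h_i from (X, A)
H_tilde = torch.randn(M, F_dim)                   # h~_j from (X~, A~) = C(X, A)
W = torch.nn.Parameter(0.01 * torch.randn(F_dim, F_dim))

s = torch.sigmoid(H.mean(dim=0))                  # readout R(H)
logits_pos = H @ W @ s                            # D(h_i, s) = h_i^T W s
logits_neg = H_tilde @ W @ s
loss = F.binary_cross_entropy_with_logits(logits_pos, torch.ones(N)) \
     + F.binary_cross_entropy_with_logits(logits_neg, torch.zeros(M))
loss.backward()   # minimizing this maximizes L above (up to the 1/(N+M) scale)
```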

InfoGraph

2020 ICLR - InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization 7

InfoGraph studies learning the representations of whole graphs (rather than nodes as in DGI) in both unsupervised and semi-supervised scenarios. Its unsupervised version is similar to DGI except for

  • Batch-wise generation of negative samples rather than random-sampling- or corruption-based negative samples.
  • GIN (Graph Isomorphism Network) methodologies for better graph-level representation learning.

In the semi-supervised setting, directly adding a supervised loss would likely result in negative transfer. The authors alleviate this problem by separating the parameters of the supervised encoder $\varphi$ from those of the unsupervised encoder $\phi$, and adding a student-teacher loss that encourages mutual information maximization between the two encoders at all levels. The overall loss is:

$$\mathcal{L}=\sum_{G\in\mathcal{G}^l}\ell\left(\tilde{y}_\varphi(G),y_G\right)+\sum_{G\in\mathcal{G}^l\cup\mathcal{G}^u}\frac{1}{|V|}\sum_{u\in V}\widehat{I}\left(h_\phi^u(G),H_\phi(G)\right)-\lambda\sum_{G\in\mathcal{G}^l\cup\mathcal{G}^u}\frac{1}{|V|}\sum_{k=1}^{K}\widehat{I}\left(H_\varphi^{(k)}(G),H_\phi^{(k)}(G)\right)$$

In practice, to reduce the computational overhead, at each training step, we enforce mutual-information maximization on a randomly chosen layer of the encoder.
