Some tools:

  • Stochastic variational inference
  • Variance reduction
  • Normalizing flows
  • Gaussian processes
  • Scalable MCMC algorithms
  • Semi-implicit variational inference

Bayesian framework

  • Bayes' theorem:

$$\text{conditional} = \frac{\text{joint}}{\text{marginal}}, \qquad p(x \mid y) = \frac{p(x, y)}{p(y)}$$

It defines a rule for updating uncertainty when new information arrives:

$$\text{posterior} = \frac{\text{likelihood} \times \text{prior}}{\text{evidence}}$$
  • Product rule: any joint distribution can be expressed with conditional distributions
$$p(x, y, z) = p(x \mid y, z)\, p(y \mid z)\, p(z)$$
  • Sum rule: any marginal distribution can be obtained from the joint distribution by integrating out
$$p(y) = \int p(x, y)\, dx$$
  • Statistical inference

    Problem: given i.i.d. data $X = \{x_i\}_{i=1}^{n}$ from a distribution $p(x \mid \theta)$, estimate $\theta$.

    1. Frequentist framework: use maximum likelihood estimation (MLE)
    $$\theta_{ML} = \arg\max_\theta p(X \mid \theta) = \arg\max_\theta \prod_{i=1}^{n} p(x_i \mid \theta) = \arg\max_\theta \sum_{i=1}^{n} \log p(x_i \mid \theta)$$

    Applicability: $n \gg d$ (many more data points than parameters)

    2. Bayesian framework: encode uncertainty about $\theta$ in a prior $p(\theta)$ and apply Bayesian inference
    $$p(\theta \mid X) = \frac{\prod_{i=1}^{n} p(x_i \mid \theta)\, p(\theta)}{\int \prod_{i=1}^{n} p(x_i \mid \theta)\, p(\theta)\, d\theta}$$

    Applicability: any relationship between $n$ and $d$

    Advantages:

    - we can encode prior knowledge/desired properties into the prior distribution
    - the prior is a form of regularization
    - in addition to the point estimate of $\theta$, the posterior contains information about the uncertainty of the estimate
    - the frequentist case is a limiting case of the Bayesian one:
      $$\lim_{n/d \to \infty} p(\theta \mid x_1, \ldots, x_n) = \delta(\theta - \theta_{ML})$$

    (A small worked example follows below.)
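To make the contrast concrete, here is a minimal sketch for a Bernoulli likelihood with a conjugate Beta prior; the data, the prior hyperparameters `a0, b0`, and the variable names are illustrative choices, not part of the notes.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Illustrative data: n coin flips with true bias 0.7
n, theta_true = 20, 0.7
x = rng.binomial(1, theta_true, size=n)
k = x.sum()

# Frequentist: maximum likelihood estimate
theta_ml = k / n

# Bayesian: conjugate Beta(a0, b0) prior -> Beta(a0 + k, b0 + n - k) posterior
a0, b0 = 2.0, 2.0                                    # mild prior preference for theta near 0.5
posterior = stats.beta(a0 + k, b0 + n - k)

print("theta_ML        :", theta_ml)
print("posterior mean  :", posterior.mean())          # pulled toward the prior (regularization)
print("posterior 95% CI:", posterior.interval(0.95))  # uncertainty, not just a point estimate
```

As $n$ grows, the Beta posterior concentrates around $\theta_{ML}$, matching the limit above.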

Bayesian ML models

In ML, we have features $x$ (observed variables), class labels or hidden representations $y$ (hidden or latent variables), and model parameters $\theta$ (e.g. the weights of a linear model).

  • Discriminative approach: models $p(y, \theta \mid x)$

    • It cannot generate new objects, since it needs $x$ as an input, and it assumes that the prior over $\theta$ does not depend on $x$: $p(y, \theta \mid x) = p(y \mid x, \theta)\, p(\theta)$
    • Examples: 1) classification/regression (small hidden space); 2) machine translation (complex hidden space)
  • Generative approach: models $p(x, y, \theta) = p(x, y \mid \theta)\, p(\theta)$

    • It can generate objects (pairs $(x, y)$), but it can be hard to train, since the observed space is most often more complicated than the hidden one.
    • Example: Generation of text, speech, images, etc.
  • Training

    Given data points $(X_{tr}, Y_{tr})$ and a discriminative model $p(y, \theta \mid x)$.

    Use the Bayesian framework:

    $$p(\theta \mid X_{tr}, Y_{tr}) = \frac{p(Y_{tr} \mid X_{tr}, \theta)\, p(\theta)}{\int p(Y_{tr} \mid X_{tr}, \theta)\, p(\theta)\, d\theta}$$

    This results in an ensemble of algorithms rather than a single one ($\theta_{ML}$). Ensembles usually perform better than a single model.

    In addition, the posterior captures all dependencies from the training data and can be used later as a new prior.

  • Testing

    We have the posterior $p(\theta \mid X_{tr}, Y_{tr})$ and a new data point $x$. We can use the predictive distribution over its hidden value $y$:

    $$p(y \mid x, X_{tr}, Y_{tr}) = \int p(y \mid x, \theta)\, p(\theta \mid X_{tr}, Y_{tr})\, d\theta$$
  • Full Bayesian inference

During training the evidence $\int p(Y_{tr} \mid X_{tr}, \theta)\, p(\theta)\, d\theta$, or at test time the predictive distribution $\int p(y \mid x, \theta)\, p(\theta \mid X_{tr}, Y_{tr})\, d\theta$, may be intractable, so it can be impractical or impossible to perform full Bayesian inference. In other words, there is no closed form.
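When these integrals have no closed form but we can draw (approximate) posterior samples, a common workaround is a plain Monte Carlo average. The sketch below is generic and assumes two hypothetical callables, `sample_posterior` and `likelihood`, supplied by the user; it is not tied to any specific model.

```python
import numpy as np

def predictive_mc(x_new, sample_posterior, likelihood, n_samples=1000):
    """Monte Carlo estimate of the predictive distribution p(y | x_new, Xtr, Ytr).

    sample_posterior(n) -> n draws theta_s from (an approximation to) p(theta | Xtr, Ytr)
    likelihood(y, x, theta) -> p(y | x, theta)
    Both callables are hypothetical placeholders for this sketch.
    """
    thetas = sample_posterior(n_samples)

    def p_y(y):
        # p(y | x_new, Xtr, Ytr) ~ (1/S) * sum_s p(y | x_new, theta_s)
        return float(np.mean([likelihood(y, x_new, t) for t in thetas]))

    return p_y
```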

Conjugacy

Conjugate distributions

Distributions $p(y)$ and $p(x \mid y)$ are conjugate $\iff$ the posterior $p(y \mid x)$ belongs to the same parametric family as $p(y)$:

$$p(y) \in \mathcal{A}(\alpha), \quad p(x \mid y) \in \mathcal{B}(y) \;\Rightarrow\; p(y \mid x) \in \mathcal{A}(\alpha')$$
  • If there is no conjugacy, we can perform MAP estimation to approximate the posterior with $\theta_{MP}$, since this does not require the normalization constant; however, we cannot compute the true posterior.
$$\theta_{MP} = \arg\max_\theta p(\theta \mid X_{tr}, Y_{tr}) = \arg\max_\theta p(Y_{tr} \mid X_{tr}, \theta)\, p(\theta)$$

During testing:

$$p(y \mid x, X_{tr}, Y_{tr}) = \int p(y \mid x, \theta)\, p(\theta \mid X_{tr}, Y_{tr})\, d\theta \approx p(y \mid x, \theta_{MP})$$

Conditional conjugacy

Given the model: $p(x, \theta) = p(x \mid \theta)\, p(\theta)$, where $\theta = [\theta_1, \ldots, \theta_m]$.

Conditional conjugacy of the likelihood and the prior holds for each $\theta_j$, conditional on all other $\{\theta_i\}_{i \neq j}$:

$$p(\theta_j \mid \theta_{i \neq j}) \in \mathcal{A}(\alpha), \quad p(x \mid \theta_j, \theta_{i \neq j}) \in \mathcal{B}(\theta_j) \;\Rightarrow\; p(\theta_j \mid x, \theta_{i \neq j}) \in \mathcal{A}(\alpha')$$

Checking conditional conjugacy in practice: for each $\theta_j$

  • Fix all other $\{\theta_i\}_{i \neq j}$ (treat them as constants)
  • Check whether $p(x \mid \theta)$ and $p(\theta)$ are conjugate w.r.t. $\theta_j$

Variational Inference

Given the model $p(x, \theta) = p(x \mid \theta)\, p(\theta)$, find a posterior approximation $p(\theta \mid x) \approx q(\theta) \in \mathcal{Q}$ such that:

$$\mathrm{KL}(q(\theta) \,\|\, p(\theta \mid x)) \to \min_{q(\theta) \in \mathcal{Q}}$$

KL is a good mismatch measure between two distributions over the same domain (see figure). It has the following properties:

  1. $\mathrm{KL}(q \,\|\, p) \ge 0$
  2. $\mathrm{KL}(q \,\|\, p) = 0 \iff q = p$
  3. $\mathrm{KL}(q \,\|\, p) \neq \mathrm{KL}(p \,\|\, q)$ (it is not symmetric)

*Figure: KL divergence.*
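To see property 3 concretely, here is a quick numerical comparison of the two KL directions for a pair of univariate Gaussians (closed-form KL; the parameter values are illustrative):

```python
import numpy as np

def kl_gauss(mu1, s1, mu2, s2):
    """KL( N(mu1, s1^2) || N(mu2, s2^2) ), closed form for univariate Gaussians."""
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2)**2) / (2 * s2**2) - 0.5

# Illustrative pair of distributions q = N(0, 1) and p = N(1, 4)
print("KL(q || p):", kl_gauss(0.0, 1.0, 1.0, 2.0))   # ~0.443
print("KL(p || q):", kl_gauss(1.0, 2.0, 0.0, 1.0))   # ~1.307 -- KL is not symmetric
```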

Evidence Lower Bound (ELBO) derivation

  • Posterior: $p(\theta \mid x)$
  • Evidence: $p(x)$, the total probability of observing the data.
  • Lower bound: $\log p(x) \ge \mathcal{L}(q(\theta))$

$$\begin{aligned}
\log p(x) &= \int q(\theta) \log p(x)\, d\theta = \int q(\theta) \log \frac{p(x, \theta)}{p(\theta \mid x)}\, d\theta = \int q(\theta) \log \frac{p(x, \theta)\, q(\theta)}{p(\theta \mid x)\, q(\theta)}\, d\theta \\
&= \int q(\theta) \log \frac{p(x, \theta)}{q(\theta)}\, d\theta + \int q(\theta) \log \frac{q(\theta)}{p(\theta \mid x)}\, d\theta = \mathcal{L}(q(\theta)) + \mathrm{KL}(q(\theta) \,\|\, p(\theta \mid x))
\end{aligned}$$

Note:

  • $\log p(x)$ does not depend on $q$
  • $\mathcal{L}$ and $\mathrm{KL}$ depend on $q$
  • minimizing the KL is the same as maximizing $\mathcal{L}$:
$$\mathrm{KL}(q(\theta) \,\|\, p(\theta \mid x)) \to \min_{q(\theta) \in \mathcal{Q}} \quad\iff\quad \mathcal{L}(q(\theta)) \to \max_{q(\theta) \in \mathcal{Q}}$$

Optimizing the ELBO $\mathcal{L}$

Goal: $\mathcal{L}(q(\theta)) \to \max_{q(\theta) \in \mathcal{Q}}$

$$\mathcal{L}(q(\theta)) = \int q(\theta) \log \frac{p(x, \theta)}{q(\theta)}\, d\theta = \int q(\theta) \log p(x \mid \theta)\, d\theta + \int q(\theta) \log \frac{p(\theta)}{q(\theta)}\, d\theta = \mathbb{E}_{q(\theta)} \log p(x \mid \theta) - \mathrm{KL}(q(\theta) \,\|\, p(\theta))$$

  • Data term: $\mathbb{E}_{q(\theta)} \log p(x \mid \theta)$
  • Regularizer: $\mathrm{KL}(q(\theta) \,\|\, p(\theta))$

We need to perform optimization w.r.t. a distribution, $\max_{q(\theta) \in \mathcal{Q}} \mathcal{L}(q(\theta))$, which is a hard problem. In VI we therefore restrict $q$ to a tractable approximating family, either factorized or parametric.

  1. Mean field approximation: factorized family, $q(\theta) = \prod_{j=1}^{m} q_j(\theta_j)$, where $\theta = [\theta_1, \ldots, \theta_m]$
  2. Parametric approximation: parametric family, $q(\theta) = q(\theta \mid \lambda)$

Mean Field Approximation

Mean field assumes that $\theta_1, \ldots, \theta_m$ are independent under $q$.

  • Apply the product rule to the distribution $q$: $q(\theta) = \prod_{j=1}^{m} q_j(\theta_j \mid \theta_{<j})$
  • Apply the independence assumption: $q(\theta) = \prod_{j=1}^{m} q_j(\theta_j)$

The optimization problem becomes:

$$\max_{q(\theta) = \prod_{j=1}^{m} q_j(\theta_j) \in \mathcal{Q}} \mathcal{L}(q(\theta))$$

This can be solved with block coordinate ascent as follows: at each step, fix all factors $\{q_i(\theta_i)\}_{i \neq j}$ except one and optimize w.r.t. that one, $\max_{q_j(\theta_j)} \mathcal{L}(q(\theta))$.

Derivation

$$\begin{aligned}
\mathcal{L}(q(\theta)) &= \mathbb{E}_{q(\theta)} \log p(x, \theta) - \mathbb{E}_{q(\theta)} \log q(\theta) \\
&= \mathbb{E}_{q(\theta)} \log p(x, \theta) - \sum_{k=1}^{m} \mathbb{E}_{q_k(\theta_k)} \log q_k(\theta_k) \\
&= \mathbb{E}_{q(\theta)} \log p(x, \theta) - \mathbb{E}_{q_j(\theta_j)} \log q_j(\theta_j) + \mathrm{const} \\
&= \left\{ r_j(\theta_j) = \tfrac{1}{Z_j} \exp\!\left( \mathbb{E}_{q_{i \neq j}} \log p(x, \theta) \right) \right\} \\
&= \mathbb{E}_{q_j(\theta_j)} \log r_j(\theta_j) - \mathbb{E}_{q_j(\theta_j)} \log q_j(\theta_j) + \mathrm{const} \\
&= -\mathrm{KL}(q_j(\theta_j) \,\|\, r_j(\theta_j)) + \mathrm{const}
\end{aligned}$$

So, the optimization problem for step j is:

$$\max_{q_j(\theta_j)} \mathcal{L}(q(\theta)) = \max_{q_j(\theta_j)} \left[ -\mathrm{KL}(q_j(\theta_j) \,\|\, r_j(\theta_j)) + \mathrm{const} \right]$$

The maximum is attained when:

$$q_j(\theta_j) = r_j(\theta_j) = \frac{1}{Z_j} \exp\!\left( \mathbb{E}_{q_{i \neq j}} \log p(x, \theta) \right) \quad\Longleftrightarrow\quad \log q_j(\theta_j) = \mathbb{E}_{q_{i \neq j}} \log p(x, \theta) + \mathrm{const}$$

Block coordinate ascent can be described in two steps: 1) initialize; 2) iterate.

  1. Initialize: $q(\theta) = \prod_{j=1}^{m} q_j(\theta_j)$
  2. Iterate (repeat until the ELBO converges):
    • Update each factor $q_1, \ldots, q_m$: $q_j(\theta_j) = \frac{1}{Z_j} \exp\!\left( \mathbb{E}_{q_{i \neq j}} \log p(x, \theta) \right)$
    • Compute the ELBO $\mathcal{L}(q(\theta))$

Note: mean-field updates can be applied when we can compute $\mathbb{E}_{q_{i \neq j}} \log p(x, \theta)$ analytically; in other words, when we have conditional conjugacy. A worked sketch follows below.
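As an illustration (not from the notes), here is a minimal coordinate-ascent sketch for the classic conditionally conjugate model $x_i \sim \mathcal{N}(\mu, \tau^{-1})$ with prior $p(\mu \mid \tau) = \mathcal{N}(\mu_0, (\lambda_0 \tau)^{-1})$, $p(\tau) = \mathrm{Gamma}(a_0, b_0)$ and mean-field factorization $q(\mu)\, q(\tau)$; the data and hyperparameter values are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(loc=2.0, scale=1.5, size=200)     # illustrative data
N, xbar = len(x), x.mean()

# Prior hyperparameters (arbitrary choices for this sketch)
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

# q(tau) = Gamma(a_N, b_N), q(mu) = N(mu_N, 1 / lam_N)
a_N, b_N = a0, b0
for _ in range(50):                              # block coordinate ascent
    E_tau = a_N / b_N                            # E_q[tau]
    # Update q(mu) using E_{q(tau)} log p(x, mu, tau)
    mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    E_mu, E_mu2 = mu_N, mu_N**2 + 1.0 / lam_N    # E_q[mu], E_q[mu^2]
    # Update q(tau) using expectations under q(mu)
    a_N = a0 + (N + 1) / 2.0
    b_N = b0 + 0.5 * (np.sum(x**2) - 2 * np.sum(x) * E_mu + N * E_mu2
                      + lam0 * (E_mu2 - 2 * mu0 * E_mu + mu0**2))

print("q(mu)  mean, var:", mu_N, 1.0 / lam_N)
print("q(tau) mean     :", a_N / b_N, "(true precision:", 1 / 1.5**2, ")")
```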

Parametric Approximation

Select a parametric family of variational distributions, $q(\theta) = q(\theta \mid \lambda)$, where $\lambda$ is a variational parameter.

The restriction is that we need to select a family of some fixed form, and as a result:

  • it might be too simple and insufficient to model the data
  • if it is complex enough then there is no guarantee we can train it well to fit the data

The ELBO is:

$$\mathcal{L}(q(\theta \mid \lambda)) = \int q(\theta \mid \lambda) \log \frac{p(x, \theta)}{q(\theta \mid \lambda)}\, d\theta \to \max_\lambda$$

If we are able to calculate derivatives of the ELBO w.r.t. $\lambda$, then we can solve this problem with a standard numerical optimization method, as sketched below.
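A minimal sketch of this idea, assuming a Gaussian variational family $q(\theta \mid \lambda) = \mathcal{N}(m, s^2)$ with $\lambda = (m, \log s)$, a Gaussian likelihood with known noise, and a Gaussian prior; the reparameterization $\theta = m + s\varepsilon$ used here is one common way to obtain the gradients (the model, step size, and sample count are illustrative).

```python
import numpy as np

rng = np.random.default_rng(2)

# Illustrative model: x_i ~ N(theta, sigma^2), prior theta ~ N(mu0, tau0^2)
sigma, mu0, tau0 = 1.0, 0.0, 10.0
x = rng.normal(3.0, sigma, size=30)

def grad_log_joint(theta):
    # d/dtheta [ log p(x | theta) + log p(theta) ]
    return np.sum(x - theta) / sigma**2 - (theta - mu0) / tau0**2

# Variational parameters lambda = (m, log s)
m, log_s = 0.0, 0.0
lr, S = 0.01, 32                                  # step size, samples per iteration
for _ in range(2000):
    s = np.exp(log_s)
    eps = rng.normal(size=S)
    theta = m + s * eps                           # reparameterization: theta = m + s * eps
    g = np.array([grad_log_joint(t) for t in theta])
    grad_m = g.mean()                             # d ELBO / d m
    grad_log_s = (g * s * eps).mean() + 1.0       # d ELBO / d log s (entropy term gives +1)
    m, log_s = m + lr * grad_m, log_s + lr * grad_log_s

# Exact conjugate Gaussian posterior for comparison
post_var = 1.0 / (len(x) / sigma**2 + 1.0 / tau0**2)
post_mean = post_var * (x.sum() / sigma**2 + mu0 / tau0**2)
print("VI (approx):", m, np.exp(log_s)**2)
print("true       :", post_mean, post_var)
```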

Inference methods

So we have:

  1. Full Bayesian inference: $p(\theta \mid x)$
  2. MAP inference: $p(\theta \mid x) \approx \delta(\theta - \theta_{MP})$
  3. Mean field variational inference: $p(\theta \mid x) \approx q(\theta) = \prod_{j=1}^{m} q_j(\theta_j)$
  4. Parametric variational inference: $p(\theta \mid x) \approx q(\theta) = q(\theta \mid \lambda)$

Latent variable model

Mixture of Gaussians

Introduce a latent variable $z_i$ for each data point $x_i$ that indicates which Gaussian component generated that point.

Model:

$$p(X, Z \mid \theta) = \prod_{i=1}^{n} p(x_i, z_i \mid \theta) = \prod_{i=1}^{n} p(x_i \mid z_i, \theta)\, p(z_i \mid \theta) = \prod_{i=1}^{n} \pi_{z_i}\, \mathcal{N}(x_i \mid \mu_{z_i}, \sigma_{z_i}^2)$$

where $\pi_j = p(z_i = j)$ is the prior probability of the $j$-th Gaussian and $\theta = \{\mu_j, \sigma_j^2, \pi_j\}_{j=1}^{K}$ are the parameters to estimate.
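For concreteness, a minimal sketch of sampling from this generative model (the component parameters below are arbitrary illustrative values):

```python
import numpy as np

rng = np.random.default_rng(3)

# Illustrative mixture parameters: K = 3 components
pi = np.array([0.5, 0.3, 0.2])        # mixing weights pi_j = p(z_i = j)
mu = np.array([-2.0, 0.0, 3.0])       # component means
sigma = np.array([0.5, 1.0, 0.7])     # component standard deviations

n = 500
z = rng.choice(len(pi), size=n, p=pi)         # z_i ~ Categorical(pi)
x = rng.normal(loc=mu[z], scale=sigma[z])     # x_i | z_i ~ N(mu_{z_i}, sigma_{z_i}^2)
```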

Note: if both $X$ and $Z$ were known, we could use ML. For instance:

$$\theta_{ML} = \arg\max_\theta p(X, Z \mid \theta) = \arg\max_\theta \log p(X, Z \mid \theta)$$

  • Since $Z$ is a latent variable, we instead need to maximize the log of the incomplete likelihood $\log p(X \mid \theta)$ w.r.t. $\theta$.
  • Instead of optimizing $\log p(X \mid \theta)$ directly, we optimize the variational lower bound w.r.t. both $\theta$ and $q(Z)$.
  • This can be solved by a block-coordinate algorithm, a.k.a. the EM algorithm.

Variational Lower Bound: $g(\xi, x)$ is a variational lower bound function for $f(x)$ iff:

  1. For all $\xi$ and all $x$: $f(x) \ge g(\xi, x)$
  2. For any $x_0$ there exists $\xi(x_0)$ such that $f(x_0) = g(\xi(x_0), x_0)$

If we find such a variational lower bound, then instead of solving $f(x) \to \max_x$ directly, we can iteratively perform block coordinate updates of $g(\xi, x)$:

  1. $x_n = \arg\max_x g(\xi_{n-1}, x)$
  2. $\xi_n = \xi(x_n) = \arg\max_\xi g(\xi, x_n)$

Expectation Maximization algorithm. We want to solve:

$$\arg\max_{q, \theta} \mathcal{L}(q, \theta) = \arg\max_{q, \theta} \int q(Z) \log \frac{p(X, Z \mid \theta)}{q(Z)}\, dZ$$

Algorithm: set an initial point $\theta_0$.

Repeat steps 1 and 2 iteratively until convergence:

  1. E-step, find: $q(Z) = \arg\max_q \mathcal{L}(q, \theta_0) = \arg\min_q \mathrm{KL}(q(Z) \,\|\, p(Z \mid X, \theta_0)) = p(Z \mid X, \theta_0)$
  2. M-step, solve: $\theta_* = \arg\max_\theta \mathcal{L}(q, \theta) = \arg\max_\theta \mathbb{E}_{Z \sim q(Z)} \log p(X, Z \mid \theta)$
    • Set $\theta_0 = \theta_*$ and go to step 1

EM monotonically increases the lower bound and converges to a stationary point of $\log p(X \mid \theta)$; see the figure below.

*Figure: EM algorithm.*

Benefits of EM

  • In some cases the E-step and the M-step can be solved in closed form
  • Allows building more complicated models
  • If the true posterior $p(Z \mid X, \theta)$ is intractable, we may search for the closest $q(Z)$ among tractable distributions by solving an optimization problem
  • Allows processing missing data by treating it as latent variables
  • Can deal with both discrete and continuous latent variables

Categorical latent variables. Since $z_i \in \{1, \ldots, K\}$, the marginal of a mixture of Gaussians is a finite mixture of distributions:

$$p(x_i \mid \theta) = \sum_{k=1}^{K} p(x_i \mid k, \theta)\, p(z_i = k \mid \theta)$$

  • The E-step is closed-form: $q(z_i = k) = p(z_i = k \mid x_i, \theta) = \dfrac{p(x_i \mid k, \theta)\, p(z_i = k \mid \theta)}{\sum_{l=1}^{K} p(x_i \mid l, \theta)\, p(z_i = l \mid \theta)}$
  • The M-step is a sum of finitely many terms: $\mathbb{E}_Z \log p(X, Z \mid \theta) = \sum_{i=1}^{n} \mathbb{E}_{z_i} \log p(x_i, z_i \mid \theta) = \sum_{i=1}^{n} \sum_{k=1}^{K} q(z_i = k) \log p(x_i, k \mid \theta)$

(A minimal EM sketch for a 1-D Gaussian mixture follows below.)
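A compact sketch of these E/M updates for a 1-D mixture of Gaussians; the data and the number of components are illustrative, and a small floor is added to the variances for numerical stability.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(4)

# Illustrative 1-D data from two clusters
x = np.concatenate([rng.normal(-2, 0.8, 300), rng.normal(3, 1.2, 200)])
n, K = len(x), 2

# Initialize theta = {mu_k, sigma2_k, pi_k}
mu = rng.choice(x, K)
sigma2 = np.full(K, x.var())
pi = np.full(K, 1.0 / K)

for _ in range(100):
    # E-step: responsibilities q(z_i = k) proportional to pi_k * N(x_i | mu_k, sigma2_k)
    dens = np.stack([pi[k] * norm.pdf(x, mu[k], np.sqrt(sigma2[k])) for k in range(K)], axis=1)
    q = dens / dens.sum(axis=1, keepdims=True)

    # M-step: maximize E_Z log p(X, Z | theta) -> responsibility-weighted ML updates
    Nk = q.sum(axis=0)
    mu = (q * x[:, None]).sum(axis=0) / Nk
    sigma2 = (q * (x[:, None] - mu) ** 2).sum(axis=0) / Nk + 1e-6
    pi = Nk / n

print("pi     :", pi)
print("mu     :", mu)
print("sigma^2:", sigma2)
```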

Continuous latent variables. A mixture of continuous distributions:

$$p(x_i \mid \theta) = \int p(x_i \mid z_i, \theta)\, p(z_i \mid \theta)\, dz_i$$

  • E-step: only available in closed form for conjugate distributions; otherwise the true posterior is intractable

    $$q(z_i) = p(z_i \mid x_i, \theta) = \frac{p(x_i \mid z_i, \theta)\, p(z_i \mid \theta)}{\int p(x_i \mid z_i, \theta)\, p(z_i \mid \theta)\, dz_i}$$

Typically, continuous latent variables are used for dimensionality reduction, a.k.a. representation learning.

Log-derivative trick

$$\nabla_x p(y \mid x) = p(y \mid x)\, \nabla_x \log p(y \mid x)$$

For example, we commonly encounter expressions of the form:

$$\begin{aligned}
\nabla_x \int p(y \mid x)\, h(x, y)\, dy &= \int \nabla_x \left[ p(y \mid x)\, h(x, y) \right] dy = \int \left( h(x, y)\, \nabla_x p(y \mid x) + p(y \mid x)\, \nabla_x h(x, y) \right) dy \\
&= \int p(y \mid x)\, \nabla_x h(x, y)\, dy + \int h(x, y)\, \nabla_x p(y \mid x)\, dy \\
&= \int p(y \mid x)\, \nabla_x h(x, y)\, dy + \int p(y \mid x)\, h(x, y)\, \nabla_x \log p(y \mid x)\, dy
\end{aligned}$$

Now the first term is an expectation under $p(y \mid x)$ and can be replaced by a Monte Carlo estimate. Using the log-derivative trick, the second term is also an expectation under $p(y \mid x)$ and can likewise be estimated via Monte Carlo.
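A small numerical check of this idea, with the illustrative choices $p(y \mid x) = \mathcal{N}(y \mid x, 1)$ and $h(x, y) = y^2$: here $h$ does not depend on $x$, so the gradient reduces to the second (log-derivative) term, and the true value $\nabla_x \mathbb{E}[y^2] = \nabla_x (x^2 + 1) = 2x$ is known in closed form.

```python
import numpy as np

rng = np.random.default_rng(5)

x = 1.5
S = 200_000
y = rng.normal(loc=x, scale=1.0, size=S)     # y ~ p(y | x) = N(x, 1)

h = y**2                                     # h(x, y) = y^2 (does not depend on x)
score = y - x                                # grad_x log N(y | x, 1) = (y - x)

grad_mc = np.mean(h * score)                 # Monte Carlo estimate of the second term
print("MC estimate:", grad_mc)               # approximately 2 * x
print("true value :", 2 * x)
```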

Score function

The score function is the gradient of the log-likelihood with respect to the parameter vector. Since it has zero mean, the values of $\nabla_\phi \log q(z_i \mid \phi)$ oscillate around zero.

$$\nabla_\phi \log q(z_i \mid \phi)$$

Proof that it has zero mean:

$$\int q(z_i \mid \phi)\, \nabla_\phi \log q(z_i \mid \phi)\, dz_i = \int q(z_i \mid \phi)\, \frac{\nabla_\phi q(z_i \mid \phi)}{q(z_i \mid \phi)}\, dz_i = \int \nabla_\phi q(z_i \mid \phi)\, dz_i = \nabla_\phi \int q(z_i \mid \phi)\, dz_i = \nabla_\phi 1 = 0$$
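A quick numerical check of the zero-mean property for the illustrative choice $q(z \mid \phi) = \mathcal{N}(z \mid \phi, 1)$, whose score is $\nabla_\phi \log q(z \mid \phi) = z - \phi$:

```python
import numpy as np

rng = np.random.default_rng(6)

phi = 0.7
z = rng.normal(loc=phi, scale=1.0, size=100_000)   # z ~ q(z | phi) = N(phi, 1)
score = z - phi                                    # grad_phi log N(z | phi, 1)

print("empirical mean of the score:", score.mean())   # close to 0
```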

REINFORCE