
[NLP] 2. RNN Basics: Language Model

This post is my summary of what I read in Dive into Deep Learning.

🏉 Language Model

Given a text sequence of length $T$ consisting of tokens $x_1, x_2, \ldots, x_T$, the goal of a language model is to estimate the joint probability of the sequence, $P(x_1, x_2, \ldots, x_T)$. In other words, we want to be able to model an entire document, or any sequence of tokens.

🏉 Learning a Language Model

Let us start by applying basic probability rules:

$$P(x_1, x_2, \ldots, x_T) = \prod_{t=1}^{T} P(x_t \mid x_1, \ldots, x_{t-1})$$
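
For example, for a sequence of four tokens this factorization reads:

$$P(x_1, x_2, x_3, x_4) = P(x_1)\, P(x_2 \mid x_1)\, P(x_3 \mid x_1, x_2)\, P(x_4 \mid x_1, x_2, x_3)$$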

There are several classical approaches to estimating $P(x_t \mid x_1, \ldots, x_{t-1})$, such as:

  • Laplace smoothing
  • Markov Models and n-grams (see the bigram example after this list)
  • Natural Language Statistics
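
For instance, under a first-order Markov assumption (a bigram model), each token is assumed to depend only on the immediately preceding token, so the factorization above simplifies to:

$$P(x_1, x_2, x_3, x_4) \approx P(x_1)\, P(x_2 \mid x_1)\, P(x_3 \mid x_2)\, P(x_4 \mid x_3)$$

The conditional probabilities can then be estimated from bigram counts in a corpus, with Laplace smoothing added so that unseen pairs do not receive zero probability.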

However, as sequences get longer these approaches become impractical: the number of n-gram parameters to store and estimate grows very quickly with the context length, and most long contexts never appear in the training data. The Recurrent Neural Network (RNN) came as a solution to this problem.

🏉 Recurrent Neural Network

Rather than modelling $P(x_t \mid x_1, \ldots, x_{t-1})$ directly, it is preferable to use a latent variable model:

$$P(x_t \mid x_1, \ldots, x_{t-1}) \approx P(x_t \mid h_{t-1})$$

where $h_{t-1}$ is a hidden state that stores the sequence information up to time step $t-1$. The hidden state is updated as $h_t = f(x_t, h_{t-1})$, where $f$ is a function that outputs the new hidden state.

https://www.d2l.ai/_images/rnn.svg

A neural network that uses this kind of recurrent computation for its hidden states is called a Recurrent Neural Network (RNN). The hidden state $H_t$ can be expressed as

$$H_t = \phi(X_t W_{xh} + H_{t-1} W_{hh} + b_h)$$

where $X_t \in \mathbb{R}^{n \times d}$ is the minibatch of inputs at time step $t$, $H_t \in \mathbb{R}^{n \times h}$ is the hidden variable at time step $t$, and $\phi$ is an activation function such as tanh. The parameters of the RNN include the weights $W_{xh} \in \mathbb{R}^{d \times h}$, $W_{hh} \in \mathbb{R}^{h \times h}$ and the bias $b_h \in \mathbb{R}^{1 \times h}$.
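
As a concrete illustration of this update, here is a minimal sketch of the recurrent computation, assuming PyTorch and randomly initialized parameters; the dimensions are chosen arbitrarily for the example.

```python
import torch

n, d, h, T = 32, 28, 64, 10   # batch size, input dim, hidden dim, sequence length
phi = torch.tanh              # activation function

# Parameters: W_xh in R^{d x h}, W_hh in R^{h x h}, b_h in R^{1 x h}
W_xh = torch.randn(d, h) * 0.01
W_hh = torch.randn(h, h) * 0.01
b_h = torch.zeros(1, h)

X = torch.randn(T, n, d)      # one minibatch of inputs per time step
H = torch.zeros(n, h)         # initial hidden state H_0

hidden_states = []
for t in range(T):
    # H_t = phi(X_t W_xh + H_{t-1} W_hh + b_h)
    H = phi(X[t] @ W_xh + H @ W_hh + b_h)
    hidden_states.append(H)

print(hidden_states[-1].shape)  # torch.Size([32, 64])
```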

🏉 BPTT (Backpropagation Through Time)

Let's look at the RNN's basic structure in simplified form:

$$h_t = f(x_t, h_{t-1}, w_h)$$

$$o_t = g(h_t, w_o)$$

The goal is to accumulate gradients across time steps using the chain rule. For a sequence of length $T$, the overall loss is the average of the per-step losses $l(y_t, o_t)$:

$$L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_h, w_o) = \frac{1}{T} \sum_{t=1}^{T} l(y_t, o_t)$$

$$\frac{\partial L}{\partial w_h} = \frac{1}{T} \sum_{t=1}^{T} \frac{\partial l(y_t, o_t)}{\partial w_h}$$
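
Expanding each per-step term with the chain rule, using $o_t = g(h_t, w_o)$ and $h_t = f(x_t, h_{t-1}, w_h)$, makes the "through time" part explicit:

$$\frac{\partial l(y_t, o_t)}{\partial w_h} = \frac{\partial l(y_t, o_t)}{\partial o_t} \cdot \frac{\partial g(h_t, w_o)}{\partial h_t} \cdot \frac{\partial h_t}{\partial w_h}$$

$$\frac{\partial h_t}{\partial w_h} = \frac{\partial f(x_t, h_{t-1}, w_h)}{\partial w_h} + \frac{\partial f(x_t, h_{t-1}, w_h)}{\partial h_{t-1}} \cdot \frac{\partial h_{t-1}}{\partial w_h}$$

The second equation is recursive: $\partial h_t / \partial w_h$ depends on $\partial h_{t-1} / \partial w_h$, so computing the gradient exactly requires walking back through every earlier time step. This cost and instability are what the next two sections address.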

 

 

🏉 Truncated BPTT
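
Computing $\partial h_t / \partial w_h$ exactly requires unrolling all the way back to $t = 1$, which is slow and numerically unstable for long sequences. Truncated BPTT cuts the recursion after a fixed number of steps: the sequence is processed in chunks, and the hidden state is detached from the computation graph at chunk boundaries so gradients do not flow further back. Below is a minimal sketch, assuming PyTorch, a single-layer `nn.RNN`, and randomly generated data; it only illustrates where the detach happens, not any particular implementation.

```python
import torch
from torch import nn

vocab_size, hidden_size, batch, chunk_len = 28, 64, 32, 35
rnn = nn.RNN(vocab_size, hidden_size, batch_first=True)
decoder = nn.Linear(hidden_size, vocab_size)
optimizer = torch.optim.SGD(list(rnn.parameters()) + list(decoder.parameters()), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

state = torch.zeros(1, batch, hidden_size)  # initial hidden state h_0

for step in range(5):  # each iteration processes one chunk of a long sequence
    # Synthetic chunk: one-hot inputs X and next-token targets Y.
    tokens = torch.randint(0, vocab_size, (batch, chunk_len))
    X = torch.nn.functional.one_hot(tokens, vocab_size).float()
    Y = torch.randint(0, vocab_size, (batch, chunk_len))

    # Truncation: keep the hidden state's value but cut the computation graph,
    # so backpropagation never reaches beyond the current chunk.
    state = state.detach()

    output, state = rnn(X, state)                      # (batch, chunk_len, hidden)
    logits = decoder(output.reshape(-1, hidden_size))  # (batch*chunk_len, vocab)
    loss = loss_fn(logits, Y.reshape(-1))

    optimizer.zero_grad()
    loss.backward()   # BPTT only within this chunk
    optimizer.step()
```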

 

🏉 Gradient Clipping / Vanishing Gradient Problem
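
The gradient with respect to $w_h$ involves products of many Jacobians, one per time step, so its norm can blow up (exploding gradients) or shrink toward zero (vanishing gradients) as $T$ grows. A common remedy for the exploding case is to clip the gradient to a maximum norm $\theta$ before the parameter update:

$$g \leftarrow \min\left(1, \frac{\theta}{\lVert g \rVert}\right) g$$

A minimal sketch of such a helper, assuming PyTorch and the parameters from the training loop above (clipping mitigates exploding gradients but does not fix vanishing ones):

```python
import torch

def grad_clipping(params, theta):
    """Rescale gradients in place so their global L2 norm is at most theta."""
    params = [p for p in params if p.grad is not None]
    norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))
    if norm > theta:
        for p in params:
            p.grad[:] *= theta / norm

# Usage inside the training loop, between loss.backward() and optimizer.step():
#     grad_clipping(list(rnn.parameters()) + list(decoder.parameters()), theta=1.0)
# torch.nn.utils.clip_grad_norm_ provides equivalent behaviour as a built-in.
```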