Paper: https://arxiv.org/abs/1906.08237 Github: https://github.com/zihangdai/xlnet
AR Language Model & AE
Autoregressive Language Model
AR LM (autoregressive language model) 은 순차적으로 다음 토큰을 예측하는 language model입니다. 수식으로 표현하면 텍스트 시퀀스 \(\mathbf{x} = (x_1, \cdots, x_T)\) 에 대해 정방향 혹은 영방향의 likelihood를 계산하는 모델입니다.
\[\begin{align} p(\mathbf{x}) = \prod_{t=1}^{T}p(x_T | \mathbf{x}_{\lt t}) \\ p(\mathbf{x}) = \prod_{t=T}^{1}p(x_T | \mathbf{x}_{\gt t}) \end{align}\]AR LM은 정방향 혹은 역방향의 단방향 context만 인코딩할 수 있어 깊은 양방향 context를 인코딩하기에 부족합니다. (ELMo에서 양방향 LM을 사용하기는 했지만, 정방향 LM과 역방향 LM의 얕은 결합이기 때문에, 깊은 context를 인코딩하기에는 한계가 있습니다.) 따라서 downstream language tasks에서 필요한 양방향 context (bidirectional context information)을 학습하기에 한계가 있고, 이는 AR LM으로 효과적인 pretrain을 기대하기 어렵게 합니다.
Autoencoding
AE (autoencoding) 모델은 어떤 explicit density estimation 대신 노이즈가 섞인 입력에서 원 데이터를 복원하는 과정을 거칩니다. 최근의 가장 유명한 SOTA 모델로는 BERT가 있습니다. BERT는 입력 텍스트 시퀀스에서 일정 확률로 특정 토큰을 [MASK]
토큰으로 치환한 뒤 원래 토큰을 예측하는 식으로 언어를 학습합니다 (Masked LM). 이를 통해 BERT는 AR의 한계인 양방향 context를 잘 학습할 수 있습니다. 하지만[MASK]
토큰은 오직 pretraining 과정에만 등장하기 때문에 pretraining과 finetuning 과정에서 괴리가 발생합니다 (pretrain-finetune discrepancy). 또한 예측 대상 토큰이 입력 중간 중간에 마스킹되어 있기 때문에 BERT는 AR에서의 product rule을 이용한 결합 확률 (joint probability)을 모델링할 수 없습니다. 즉 BERT에서는 예측 대상 토큰 (마스킹된 토큰)이 서로 독립적이라고 가정하기 때문에 이는 자연어에서 필요한 고차, 장기 의존성을 모델링하기에는 과단순화된 것이라고 할 수 있습니다.
XLNet
논문에서는 AR LM과 AE의 장점을 취하는 generalized language model, XLNet을 소개합니다. 주요 특징으로는 아래와 같습니다.
- AR에서의 고정된 factorization 순서와 달리, 가능한 모든 factorization 순서의 permutation의 log likelihood를 최대화합니다.
- XLNet은 입력 데이터에 노이즈를 추가하지 않습니다. 따라서 pretrain-finetune 괴리를 겪지 않고, BERT에서와 같이 예측 토큰끼리의 독립 가정을 하지 않습니다.
Background (AR LM & BERT)
텍스트 시퀀스 $\mathbf{x} = (x_1, \cdots, x_T)$ 에 대해 AR LM은 autoregressive factorization 으로 아래의 likelihood를 최대화하기 위해 pretrain합니다.
\[\underset{\theta}{\text{max}}\quad\log\;p_{\theta}(\mathbf{x}) = \sum_{t=1}^{T}\text{log}\;p_{\theta}(x_t | \mathbf{x}_{\lt t}) = \sum_{t=1}^{T}\text{log}\frac{\text{exp}\bigl(h_{\theta}(\mathbf{x}_{1:t-1})^{\top}e(x_{t}) \bigr)}{\sum{_{x^{\prime}} \text{exp} \bigl( \mathbf{x}_{1:t-1})^{\top}e(x^{\prime}) \bigr)}}\]위 수식에서 $h_{\theta}(\mathbf{x}_{1:t-1})$ 은 context representation, $e(x)$ 는 $x$ 의 임베딩을 가리킵니다. 이에 비해 BERT는 기본적으로 denoising autoencoding인데요. 구체적으로는 시퀀스 $\mathbf{x}$ 의 15%만큼을 [MASK]
로 치환한 $\hat{\mathbf{x}}$ 를 만듭니다. 마스킹된 토큰을 $\bar{\mathbf{x}}$ 라고 했을 때 BERT의 $\bar{\mathbf{x}}$ 로 $\hat{\mathbf{x}}$ 를 재구성하는 목적 함수로 학습합니다.
위 수식에서 $m_{t} = 1 $ 은 마스킹 여부를 나타내는 것으로, $x_{t}$ 가 마스킹된 토큰이라는 뜻입니다. $H_{\theta}$ 는 $T$ 길이의 시퀀스 $\mathbf{x}$ 를 hidden vector 시퀀스 \(H_{\theta}(\mathbf{x}) = [ H_{\theta}(\mathbf{x})_{1}, H_{\theta}(\mathbf{x})_2, \cdots, H_{\theta}(\mathbf{x})_T ]\) 로 매핑하는 Transformer입니다. AR과 BERT의 목적 함수의 장단점은 아래와 같은 관점에서 비교할 수 있습니다.
- 독립 가정 (Independence Assumption): BERT의 목적 함수에서의 근사 표시에서 알 수 있듯이, BERT는 결합 조건부 확률에 대해 마스킹된 모든 토큰 $\hat{\mathbf{x}}$ 서로 독립이라고 (실제로는 독립이 아니지만) 가정하고 있습니다. 반면 AR LM은 이러한 가정 없이 product rule을 따릅니다.
- 입력 노이즈 (Input Noise): BERT에는 pretraining 단계에서 finetuning 과정에서는 등장하지 않는
[MAKS]
토큰이 등장합니다.[MAKS]
토큰을 원래 토큰으로 바꾸긴 하지만 이는 작은 확률이기 때문에 pretrain-finetune 괴리 문제를 해결하지 못합니다. 반면 AR LM은 입력에 노이즈를 가하지 않아 BERT의 이러한 문제는 겪지 않습니다. - Context 의존성 (Context dependency): AR LM의 representation \(h_{\theta}(\mathbf{x}_{1:t-1})\) 은 $t$ 시점까지의 토큰들로만 구성됩니다. 반면 BERT의 $H_{\theta}(\mathbf{x})_t$ 는 양방향에서 context 정보를 얻습니다. 결국 BERT는 양방향 context를 좀 더 잘 포착할 수 있도록 학습됩니다.
Permutation Language Model
논문에서는 AR 모델의 이점을 그대로 가지면서 양방향 context를 학습할 수 있는 Permutation LM을 제안합니다. 길이가 $T$ 인 시퀀스 $\mathbf{x}$ 는 총 $T!$ 개의 permutation을 가집니다. 이 때 모든 permutation에 대해 모델 파라미터가 공유되면 모델은 모든 위치에서 정보를 얻을 수 있을 거라고 기대할 수 있습니다.
term \(\mathcal{Z}_{T}\) 를 \(T\) 길이의 시퀀스 $[1, 2, \ldots, T]$ 의 모든 permutation 집합, $z_t$ 를 t 번째 원소, \(\mathbf{z}_{\lt t}\) 를 permutation \(\mathbf{z} \in \mathcal{Z}_{T}\) 의 t-1 원소들이라고 할 때 permutation LM 의 objective 는 다음과 같이 나타낼 수 있습니다.
\[\begin{align} \underset{\theta}{\text{max}} && \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_{T}} \Biggl[ \sum_{t=1}^{T} \log p_{\theta} (x_{z_{t}} | \mathbf{x}_{\mathbf{z}_{\lt t}}) \Biggr] \end{align}\]token $x_3$ 을 예측하는 예를 살펴보면 다음 그림과 같습니다. $x_{3}$ 을 예측하는 데 $[x_2, x_4]$, $[x_1, x_2, x_4]$, $[x_4]$ 를 이용합니다. 한편 permutation은 시퀀스 순서 자체가 아니라 factorization order에 대해서만 이루어지기 때문에 positonal encoding을 이용해 각 토큰 간의 위치 정보를 구분할 수 있습니다.
즉 모든 permutation에서의 log likelihood를 최대화하기 위한 문제가 됩니다. 실제로는 계산량 때문에 permutation 중 sampling을 합니다. 공식 구현체의 해당 코드 를 보면 다음과 같습니다.
1 |
|
perm_size
는 일종의 윈도우라고 생각하면 되는데, perm_size
만큼 셔플을 합니다. 가령 seq_len
이 10이고 perm_size
가 2이면 index = tf.transpose(tf.reshape(index, [-1, perm_size]))
와 index = tf.random_shuffle(index)
를 거친 index
는 다음과 같습니다.
1 |
|
최종적으로는 위 결과물을 tf.reshape(tf.transpose(index), [-1])
으로 원래의 inputs
과 같은 형상으로 맞춰줍니다.
Two-Stream Self-Attention for Target-Aware Representations
그러나 standard Transformer를 위의 permutation LM에 그대로 적용하면 안됩니다. 그 이유는 같은 representation으로 다른 토큰을 예측해야 하는 문제가 발생하기 때문입니다. 위에서 살펴본 permutation LM의 objective의 likelihood에 softmax를 적용하면 다음과 같습니다.
\[p_{\theta}(X_{z_{t}} = x | \mathbf{x}_{\mathbf{z}_{\lt t}}) = \frac{\exp \bigl( e(x)^{\top} h_{\theta}(\mathbf{x_{\mathbf{z}_{\lt t}}}) \bigr)} {\sum{x^{\prime}} \bigl( e(x^{\prime})^{\top}h_{\theta}(\mathbf{x}_{\mathbf{z}_{\lt t}}) \bigr) }\]여기서의 문제점은 representation \(h_{\theta}(\mathbf{x}_{\mathbf{z}_{\lt t}})\) 가 예측해야 하는 타겟 포지션 \(z_t\) 에 의존적이지 않다는 점입니다. 예를 들어 두 개의 permutation [1, 2, 4, 3], [1, 2, 3, 4] 가 있을 때 $x_4$ 를 예측하는 데에도 \(h_{\theta}(x_1, x_2)\) 를 이용하고 $x_3$ 을 예측하는 데에도 $h_{\theta}(x_1, x_2)$ 를 이용해야 합니다. 이 경우 제대로 된 representation을 학습하기 어렵습니다. 이 문제를 해결하기 위해 논문에서는 이전 context의 정보 외에도 \(g_{\theta}(\mathbf{x}_{\mathbf{z}_{\lt t}}, z_t)\) 라는 타겟 인덱스의 포지션 정보를 함께 이용하는 Two-Stream Self-Attention 구조를 사용합니다. 이 때 위의 수식은 아래처럼 변합니다.
\[p_{\theta}(X_{z_{t}} = x | \mathbf{x}_{\mathbf{z}_{\lt t}}) = \frac{\exp \bigl( e(x)^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{\lt t}}, z_t) \bigr)} {\sum{x^{\prime}} \bigl( e(x^{\prime})^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{\lt t}}, z_t) \bigr) }\]이 구조가 제대로 동작하려면 \(g_{\theta}(\mathbf{x}_{\mathbf{z}_{\lt t}}, z_t)\) 는 두 가지를 만족해야 합니다.
- 토큰 $x_{z_{t}}$ 를 예측할 때는 $x_{z_t}$ 의 포지션만 이용하고 context는 이용하면 안됩니다.
- 다른 토큰 $x_{z_j}$ ($j \gt t$) 을 예측할 때는 $x_t$ 의 context를 인코딩해야 합니다.
이는 standard transformer 구조와 모순되는데 이를 해결하기 위해 논문에서는 2개의 hidden state를 이용합니다.
- context representation \(h_{\theta}(\mathbf{x}_{\mathbf{z}_{\le t}})\) (\(h_{z_t}\)) 은 Transformer의 hidden state와 비슷한 역할을 합니다. 이 representation은 context와 \(x_{z_t}\) 모두 인코딩합니다.
- query representation \(g_{\theta}(\mathbf{x}_{\mathbf{z}_{\lt t}}, z_t)\)$ (\($g_{z_t}\)$) 은 contextual information \($\mathbf{x}_{\mathbf{z}_{\lt t}}\) 와 포지션 \(z_t\) 에만 접근할 수 있습니다.
각각은 다음과 같이 계산됩니다.
\[\begin{align} g_{z_t}^{(m)} \leftarrow \text{Attention}(\text{Q} = g_{z_t}^{(m-1)}, \text{KV}=\mathbf{h}_{z \lt t}^{(m-1)};\theta)\\ h_{z_t}^{(m)} \leftarrow \text{Attention}(\text{Q} = h_{z_t}^{(m-1)}, \text{KV}=\mathbf{h}_{z \le t}^{(m-1)};\theta) \end{align}\]$h_{z_t}^{(m)}$ 은 content stream으로 standard transformer에서의 representation과 같은 역할을 합니다. ($z_t$ 와 $x_{z_t}$ 를 모두 이용합니다.) $g_{z_t}^{(m)}$ 은 queay stream으로 $z_t$ 는 이용하지만 $x_{z_t}$ 는 이용할 수 없습니다. 첫 번째 query stream layer는 $g_i^{(0)} = w$ 으로 초기화되고 content stream은 단어 임베딩 $h_{i}^{(0)} = e(x_i)$ 으로 초기화됩니다. finetuning 단계에서는 content stream만 이용합니다. Pretrain 단계에서는 마지막 layer의 query representation을 이용합니다.
코드로 살펴보면 두 representation이 초기화 되는 과정은 아래와 같습니다.
1 |
|
Partial Prediction
permutation LM의 또 다른 문제는 permutation이 느린 수렴을 일으킨다는 점입니다. 이 최적화 문제를 덜기 위해 XLNet에서는 마지막 몇 개의 토큰만 예측합니다. 구체적으로는 $\mathbf{z}$ 를 non-target subsequence \(\mathbf{z}_{\le c}\) 와 target subsequence \(\mathbf{z}_{\gt c}\) 로 나눕니다. 이 때 objective는 non-target subsequence에 대한 target subsequence의 log likelihood를 최대화 하게 됩니다.
\[\underset{\theta}{\max} \;\;\; \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_{T}} \Biggl[ \log p_{\theta}(\mathbf{x}_{\mathbf{z}_{\gt c}} | \mathbf{x}_{\mathbf{z}_{\le c}}) \Biggr] = \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_{T}} \Biggl[ \sum_{t=c+1}^{|\mathbf{z}|} \log p_{\theta} (x_{z_{t}} | \mathbf{x}_{\mathbf{z}_{\lt t}})]\]해당 코드는 여기서 볼 수 있는데요. 예측할 토큰 개수만큼 마스킹을 합니다. 이 때 선택되지 않는 토큰은 query representation이 계산되지 않습니다.
1 |
|
Incorporating Ideas from Transformer-XL
XLNet은 저자들의 이전 논문인 Transofmer-XL에서 두 가지 아이디어를 차용합니다. 하나는 relative positional encoding이고 다른 하나는 sement recurrence mechanism입니다. (Transformer-XL은 긴 시퀀스를 처리하기 위해 데이터를 세그먼트 단위로 처리합니다.)
Relative Positional Encoding
시퀀스를 세그먼트 단위로 처리할 때 standard transformer에서 문제가 되는 부분이 positional encoding입니다. $T+1$ 번째 세그먼트는 $T$ 번째 세그먼트보다 뒤에 위치해 있지만 기존의 고정된 positional encoding을 사용하면 각각의 세그먼트가 같은 포지션 정보를 가지게 됩니다. 따라서 시퀀스 내부의 토큰 간 포지션 정보를 제대로 학습하기 어렵습니다. Relative positinoal encoding에 대한 자세한 설명은 Transformer-XL 논문에 나와있습니다. Standard transformer와 달리 attend 하는 위치와의 상대적인 거리만 고려합니다.
Segment Recurrence Mechanism
XLNet에서 사용되는 segment recurrence mechanism을 수식으로 나타내면 다음과 같습니다. 한 시퀀스를 두 개의 세그먼트 \(\tilde{\mathbf{x}} = \mathbf{s}_{1:T}\), $\mathbf{x} = \mathbf{s}_{T+1: 2T}$ 로 나눌 수 있습니다. 이 때 $\tilde{x}$ 의 permutation $\tilde{\mathbf{z}}$ 에 따라 첫 번째 세그먼트를 처리하고 각 레이어 \(m\) 마다 얻어진 content representation \(\tilde{\mathbf{h}}^{(m)}\) 을 캐싱합니다. 이 때 다음 세그먼트 \(\mathbf{x}\) 에 대한 attention은 아래와 같습니다. Query representation도 같은 방법으로 계산됩니다.
\[h_{z_t}^{(m)} \leftarrow \text{Attention}(\text{Q} = h_{z_t}^{(m-1)}, \text{KV} = \Bigl[ \tilde{\mathbf{h}}^{(m-1)}, \mathbf{h}_{z_{\le t}}^{(m-1)} \Bigr]; \theta)\]코드 상으로는 다음과 같습니다.
1 |
|
Modeling Multiple Segments
여러 downstream은 question, context (QA) 처럼 여러 세그먼트가 입력으로 이루어져 있습니다.
pretraining
pretraining 단계에서 XLNet은 BERT와 유사하게 [A, SEP, B, SEP, CLS] 로 이용합니다. 이 때 같은 context든 아니든 두 개의 세그먼트를 랜덤하게 선택해 두 개의 세그먼트로 이루어진 하나의 시퀀스로 취급합니다. 한편 XLNet-Large에서는 별다른 성능 향상이 없어 BERT에서 사용했던 NSP (Next Sentence Prediction)는 사용하지 않았습니다.
Relative Segment Encodings
XLNet에서는 세그먼트 또한 인코딩합니다. 이는 두 포지션 $i$ 와 $j$ 가 같은 세그먼트에 위치하는 지를 나타내기 위함입니다. 구체적으로는 $i$ 와 $j$ 가 같은 세그먼트에 위치해 있으면 \(\mathbf{s}_{ij} = \mathbf{s}_{+}\), 그렇지 않으면 \(\mathbf{s}_{ij} = \mathbf{s}_{-}\) 를 사용합니다.
이렇게 relative encoding이 결합된 attention의 구현은 다음과 같습니다. ac
와 bd
는 Transformer-XL 에서 등장하는 term인데 attention을 분해한 것입니다.
1 |
|
Discussion adn Analysis
Comparison with BERT
논문에서는 예시를 들어 BERT와 XLNet의 차이점을 설명하고 있습니다. [New, York, is, a, city] 라는 문장이 있을 때 BERT와 XLNet에서 모두 [New, York] 을 타겟 토큰으로 선택했다고 가정합니다. 이 때 XLNet에서 샘플링한 factorization order는 [is, a, city, New, York] 이라고 가정합니다. 이 경우 BERT와 XLNet은 아래와 같은 objective를 따라 학습합니다.
\[\mathcal{J}_{\text{BERT}} = \log p (\text{New} \; | \; \text{is a city}) + \log p (\text{York} \; | \; \text{is a city}) \\ \;\\ \mathcal{J}_{\text{XLNet}} = \log p (\text{New} \; | \; \text{is a city}) + \log p (\text{York} \; | \; \color{red}{\text{New}}, \text{is a city}) \\\]BERT와 달리 XLNet은 (New, York) 쌍 간의 의존성을 포착할 수 있습니다. AR의 장점을 살린 permutation LM으로 이전 토큰으로 이후 토큰에 대해 예측하는 데 도움을 줄 수 있는 것입니다.
Comparision with Language Model
GPT같은 standard AR LM은 (New $\rightarrow$ York) 의 의존성은 잡아내지만 (York $\rightarrow$ New) 의 의존성은 잡아내지 못합니다. 반면 XLNet은 permutation으로 이러한 의존성도 학습할 수 있을 것이라 기대할 수 있습니다. 이 부분은 현실 문제에서 중요한데 예를 들어 “Thom Yorke is the singer of Radiohead” 라는 context와 “Who is the singer of Radiohead” 라는 질문에 대해 기존의 AR로는 “Thom Yorke” 와 “Radiohead” 간의 의존성을 잡아낼 수 없지만 XLNet은 가능합니다.
Experiments
몇 가지 task에 대해 BERT와의 비교 성능은 아래와 같습니다.
https://github.com/zihangdai/xlnet/edit/master/README.md
Reading comprehension
Model | RACE accuracy | SQuAD1.1 EM | SQuAD2.0 EM |
---|---|---|---|
BERT | 72.0 | 84.1 | 78.98 |
XLNet | 81.75 | 88.95 | 86.12 |
Text classification
Model | IMDB | Yelp-2 | Yelp-5 | DBpedia | Amazon-2 | Amazon-5 |
---|---|---|---|---|---|---|
BERT | 4.51 | 1.89 | 29.32 | 0.64 | 2.63 | 34.17 |
XLNet | 3.79 | 1.55 | 27.80 | 0.62 | 2.40 | 32.26 |
GLUE
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B |
---|---|---|---|---|---|---|---|---|
BERT | 86.6 | 92.3 | 91.3 | 70.4 | 93.2 | 88.0 | 60.6 | 90.0 |
XLNet | 89.8 | 93.9 | 91.8 | 83.8 | 95.6 | 89.2 | 63.6 | 91.8 |
Conclusions
XLNet은 AR LM과 autuencoder의 장점을 모두 취해 BERT보다 좋은 성능을 이끌어냈습니다. 이는 특히 긴 시퀀스를 처리해야 하는 reading comprehension task에서 두드러지는데 애초에 긴 시퀀스를 처리하기 위해 등장한 Transformer-XL에서 발전한 개념이기 때문이라고 생각합니다.
그러나 짧은 시퀀스를 다루는 문제에서는 그만큼 높은 성능을 기대하기 어려울 것으로 보입니다. (pretraining에 소요되는 자원을 고려하면 더 그렇습니다. BERT-Large: 64 TPU chips, XLNet-Largs: 512 chips)