Transformer を理解したいので。

公開:
深層学習 #Transformer #Attention

Transformer モデル構造

Transformer は大きく Encoder と Decoder の二種類に分けることができます。各ブロックは

  • Attention
  • Feed-forward network(FFN)
  • Residual + Layer norm

から構成されており、Encoder と Decoder とも基本的な処理は同じなのですが使用する情報に差分があります。本記事では、

図の説明
https://arxiv.org/pdf/2009.06732

で表されている Transformer の構成についてを解説していきたいと思います。

Encoder / Decoder について

Transformer では、使用する情報や想定するタスクによって Encoder と Decoder の二種類の処理フローがあります。これらはモデル構造自体に大きな差分があるわけではなく、

  • Encoder では、attention 部分に self attention を使用しており、双方向(未来)の情報を用いる
  • Decoder では、attention 部分に masked self attention を使用しており、未来の情報を予測する

というように大別することができます。より漠然と捉えると

  • トークン自体の埋め込み表現を計算したい場合は encoder
  • トークン生成が目的なら decoder

として捉えても良いと思います。Transformer 元論文におけ左側か右側どちらを拝借しているかということです。

Encoder only

Encoder のみを使用しているモデルとして、BERT や ViT が挙げられます。self attention を使用していて、双方向(未来のトークンも見える)のスコアを計算します。BERT のように文章自体を分類することに特化している場合や、ViT のように画像全体の認識を行う場合には、情報全体を扱うために Encoder を使用するのが一般的です。

Encoder ブロックについて下図を引用しておきます。

図の説明
ViT 元論文より(https://arxiv.org/pdf/2010.11929)

Decoder only

Decoder のみを使用しているモデルとして、 GPT が挙げられます(Improving Language Understanding by Generative Pre-Training)。 デコーダーでは masked self attention を使用しており、未来のトークンとのスコアを計算しないようにマスクを掛けています。こうすることで、次トークン予測を行うモデルを作成することができます。

Decoder ブロックについて下図を引用しておきます。

図の説明
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf

Encoder と比較してみると基本的には大差ないというのがわかると思います。

Encoder Decoder

Encoder と Decoder を使用しているモデルとして、元論文の Transformer や T5 が挙げられます。機械翻訳や文章要約など、入力情報の全体を把握しつつ、文章自体も生成するようなタスクで用いられています。


各種計算

ここからは、Transformer の計算が具体的にどうなっているかについてまとめていきたいと思います。

トークナイズ

Transformer は文字列をそのまま扱えないので、まず文字列からトークン列(整数のID列)へと変換します。OpenAIのモデルであればどのようにトークン化されるのかは https://platform.openai.com/tokenizer で確認することができます。 トークン化では、空白や記号も含めたトークンになることや日本語ではトークン数が増えやすい傾向にあることが見れると思います。

図の説明
英語のトークナイズの例。7トークンとして変換されます。
図の説明
日本語のトークナイズの例。10 トークンとして変換されます。

トークン分割・トークンIDはエンコーダーのモデル依存で、どういったモデルを使用するかによって変わってきます。 たとえば今回の場合だと、May the force be with you.[13561, 290, 9578, 413, 483, 481, 13] というトークンIDに変換されます。このトークンIDは埋め込み行列からベクトルを取得するためのインデックスとなっていて、語彙サイズをVV、モデル次元数を dmodeld_{\rm{model}} とすると埋め込み行列は

ERV×dmodelE \in \mathbb{R}^{V\times d_{\rm{model}}}

となり、ここからトークンIDと対応する行番号のベクトルを取ってくることで入力情報を作成することができます。文章をトークナイズして得られたトークン列を埋め込み行列でベクトル化したものを、行方向に並べた行列を

  • 系列長:TT
  • モデル次元:dmodeld_{\text{model}}

として

X=[x1x2xT]RT×dmodelX = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_T \end{bmatrix} \in \mathbb{R}^{T \times d_{\text{model}}}

が得られます。以降は 1トークン = 1行ベクトルxix_iXXii 行目)という約束で進めます。行ベクトル / 列ベクトルの流儀は文献によって異なりますが、どちらでも数学的には同じ内容です。行ベクトルでの定義

XWQXW_Q

は、列ベクトルでの定義に書き換えると

XWQ=(WQTXT)TXW_Q = \left(W_Q^T X^T\right)^T

のように転置で対応づけられます。

Self-attention 機構

Transformer の中心的アイデアが attention であり、まずは self-attention について分解して理解していきたいと思います。 ここであえて self attention と明記したのは、cross attention と区別するためです。これらは

  • self attention:Query / Key / Value を同じ系列 XX から作成
  • cross attention:Query と Key / Value を作成する系列が別

という違いがあります。計算方法・考え方に大差はないので理論的には統一的に扱えますが、まずは self attention に絞ったほうが理解しやすいと思います。 以下では、

図の説明
Transformer ブロックの計算概要

の計算の流れを追い、図と式が一対一で対応して理解できるようになることを目標とします。

(1) QQ

Transformer は文字列そのものではなく、埋め込み済みのトークン行列 XX を入力として受け取ります。先程の例の

May the force be with you.

をトークン化して並べると、系列長 T=7T=7 であるので

XR7×dmodelX \in \mathbb{R}^{7 \times d_{\text{model}}}

の形で表せます。

図の説明
系列の行列例。行ベクトルとして感じ取りやすいように縦線を破線としてみました。各行が1トークンに対応しています。

self attention では、各トークン xix_i を役割の異なる3種類の表現に線形変換して分けます。

  • qi=xiWQq_i = x_i W_Q
  • ki=xiWKk_i = x_i W_K
  • vi=xiWVv_i = x_i W_V

ここで WQ,WK,WVRdmodel×dmodelW_Q, W_K, W_V \in \mathbb{R}^{d_{{\rm{model}}} \times d_{{\rm{model}}}} は学習される重み行列です。このとき各ベクトルの次元は

qi,ki,viRdmodelq_i, k_i, v_i \in \mathbb{R}^{d_{\text{model}}}

になります。

クエリベクトル qiq_i の作成を例に取り、具体的な計算を以下に示します。

図の説明
クエリベクトルを作成する場合の計算例

トークン xix_i のベクトル自体は定数であり、WQW_Q の重みをかけ合わせた行列計算を行うことで

qi,j=k=1dmodelxi,kwj,kq_{i,j} = \sum_{k=1}^{d_{{\rm{model}}}} x_{i,k} w_{j,k}

としてクエリベクトルの各成分が計算されます。

全体的な計算を俯瞰した後で「クエリ、キー、バリュー」という名前について考えるのが良いと思いますが、ひとまず直感的には以下の対応で理解しておくとよいです。

  • Query(qiq_i): 「いま自分(トークン xix_i)が注目したい特徴は何か」を表す検索条件
  • Key(kjk_j): 「各トークン xjx_j がどんな特徴を持つか」を表すインデックス
  • Value(vjv_j): 「実際に集約して受け渡す情報の中身」

1トークンではなく系列全体に拡張すると、行列演算としてまとめて書けます。

Q=XWQ,K=XWK,V=XWVQ = XW_Q,\quad K = XW_K,\quad V = XW_V

それぞれ

Q,K,VRT×dmodelQ, K, V \in \mathbb{R}^{T \times d_{\text{model}}}

で、行方向に qi,ki,viq_i, k_i, v_i が並びます。

図の説明
クエリベクトルを作成する場合の計算例を行列に拡張したもの。

(2) QKTQK^T

トークン同士の “関係性” を数値化するために、ベクトル同士の内積を計算します。クエリベクトル qiq_i とキーベクトル kjk_j との内積

qikjq_i \cdot k_j

を計算して、それらまとめます。QQ は行方向にクエリベクトルが並んでいること、KTK^T は転置しているので列方向にキーベクトルが並んでいることに注意すると、

QKTRT×TQK^T \in \mathbb{R}^{T \times T}

の計算をすることで、全トークン同士の内積を計算した行列を作ることができます[1]。成分 (i,j)(i, j)

(QKT)ij=qikj(QK^T)_{ij} = q_i \cdot k_j

です。例えば q4q_4k5k_5 との内積値 (QKT)45(QK^T)_{45} は以下のような形になります。

図の説明
内積の計算過程について

計算した QKTRT×TQK^T\in\mathbb{R}^{T\times T} は後段の処理で、トークン iijj をどれくらい参照したいかを表すスコアとして使用されます。

  • 行(ii): 「どのトークンが」(参照する側 = Query 側)
  • 列(jj): 「どのトークンを」(参照される側 = Key 側)

つまり QKTQK^T は系列内の全トークン間の相関を一つの行列にしたものです。 この計算によりトークン同士をクロス集計した情報を表現することができます。

(3) softmax (QKT/dk)\rm{softmax}~(QK^T / \sqrt{d_k})

前節で得たスコア行列

QKTRT×TQK^T \in \mathbb{R}^{T \times T}

は、各成分が内積

sij=qikjs_{ij} = q_i \cdot k_j

になっていて、「トークン ii(Query)がトークン jj(Key)をどれくらい参照したいか」を表す “生のスコア” です。ただしこのままだと値のスケールが大きくなりやすいので、正規化したうえで softmax を適用することで [0,1][0, 1] の確率値に変換します。

パラメーター dkd_k で正規化し行方向にソフトマックス関数を適用した

A=softmax(QKTdk)A = {\rm{softmax}} \left(\frac{QK^T}{\sqrt{d_k}} \right)

が、attention matrix と呼ばれる情報です。計算した結果の QKTQK^T に改めて着目すると下図のようになります。例えば3行目を見てみると、

図の説明
クエリベクトルを作成する場合の計算例

クエリ q3q_3 に対するキー k1,k2,k3,k4,k5,k6,k7k_1,k_2,k_3,k_4,k_5,k_6,k_7 との内積の値(s3js_{3j})にソフトマックス関数を適用した結果

a3j=softmax (s3j)=softmax(q3kjdk)a_{3j} = {\rm{softmax}}~ (s_{3j}) = {\rm{softmax}} \left( \frac{q_3k_j}{\sqrt{d_k}} \right)

が入っています。すると行方向は、クエリ q3q_3 に対しての全てのキー k1,...,kMk_1,...,k_M に対する “スコア” を計算する事ができ、クエリ q3q_3 がどのキーをどれだけの重みで参照したいかを表すことができます[2]

また列方向では、列 jj はキー kjk_j を固定して全てのクエリ q1,...,qMq_1,...,q_M がそれにつけたスコアが並んでいます。そのため、各クエリ qjq_j からどれくらい重要と見なされているかの情報があります。 ただしソフトマックス関数は行方向に計算されているので、列方向の値を足しても1にはならず、単に大小だけが意味をなします。

(4) AVA V

前節 attention matrix までの計算で、

A=softmax(QKTdk)A = {\rm{softmax}} \left(\frac{QK^T}{\sqrt{d_k}} \right)

どこを重点的に見るかを表す情報を計算することができました。次にその重みを使って実際の特徴ベクトルを作成するかを計算します。attention matrix AA は softmax の演算が掛かっているので、トークンの情報は QKTQK^T で内積値として混ざっているものの確率値になっているので特徴ベクトルとしてはそのまま使用できません。

そこで、バリュー VV を用いて

AVA V

を計算します。

  • ART×TA \in \mathbb{R}^{T \times T}
  • VRT×dvV \in \mathbb{R}^{T \times d_v}(単純化して dv=dkd_v=d_k としてもよい)

なので

ORT×dvO \in \mathbb{R}^{T \times d_v}

となり、各トークン位置ごとに新しい表現(混ぜた結果)が得られます。各行ごとにトークン由来のベクトルが入っているので、行ベクトルを viv_i として

V=[v1v2vT](Rdv)T×1V= \begin{bmatrix} v_1\\ v_2\\ \vdots\\ v_T \end{bmatrix} \in (\mathbb{R}^{d_v})^{T\times 1}

のように T×1T\times 1 の行列とみなせばより簡単に

oi=j=1Taijvjo_i=\sum_{j=1}^{T} a_{ij} v_j

と理解できるかと思います。

図の説明
attention matrix とバリューとの行列積

クエリがどのキーを参照するかの重みを計算できているので、その割合に従って value を混ぜる処理になります。ここでも例えば3行目に着目すると、それは VV の1行目を a31a_{31} で重み付けしたものと、VV の2行目を a32a_{32} で重み付けしたものと、… との和になっている事がわかります。

そのため、例えば a311a_{31} \sim 1 のような値(その他 a32,a33,...,0a_{32},a_{33},..., \sim 0 のような場合)であれば計算したとの3行目はほとんど VV の3行目に等しいということになります。均等に a3ja_{3j} が分布していればバリューを満遍なく取り込んだような値になります。このようにして、クエリ、キー、バリューを用いることで、系列同士の関係性を計算に落とし込むことができるのが attention 機構のロジックです。

ここまでの流れを全体で

ここまでの流れを全体的に示しておきたいと思います。トークン化された系列 XRT×dmodelX \in \mathbb{R}^{T\times d_{\rm{model}}} は attention の計算を経て RT×dmodel\mathbb{R}^{T\times d_{\rm{model}}} の行列が計算されます。元々の入力 XX と行列の形が同じであるので、それぞれはじめのトークン xix_i が他のどの情報を重要と感じているかを表す出力となっていることがわかります。

図の説明
クエリベクトルを作成する場合の計算例

ここまで整理できれば

図の説明
クエリベクトルを作成する場合の計算例

Q,K,VQ,K,V に関する計算についてもできると思います。ただこの図の途中に Mask (opt.) という計算が挟まっていますが、ここについては一旦説明を後回しにして、作成した行列をどのように最終出力につなげるのかを先に説明します。

(5) Position-wise FFN

attention の計算でトークン間の情報のができたので、Transformer ブロックでは次に Position-wise FFN(Feed-forward networks) と呼ばれる計算を実行します。attention 処理では主にどのトークンに着目するかという発想で情報を収集してきましたが、次に作成した情報をどのように処理するかをこの Position-wise FFN で実行します。attention での出力を

XA=softmax(QKTdk)VX_A = {\rm{softmax}} \left(\frac{QK^T}{\sqrt{d_k}} \right) V

とすると、Position-wise FFN は

FFN(xAi)=F2(ReLU(F1(XA)))FFN(x_{Ai}) = F_2(ReLU(F_1(X_A)))

の単なる feed-forward layers の処理です。ただし、行ごとに独立での処理で、各トークンの情報をそれぞれ何らかの特徴量として変換する処理であることに注意してください。

(6) スキップ接続

各 attention、Feed Forward 処理の前後を見ると skip connection があることがわかります。これはそのままの意味で、

XA=LayerNorm(softmax(QKTdk)V)+XXB=LayerNorm(PositionFFN(XA))+XA\begin{align*} X_A &= \rm{LayerNorm} \left( {\rm{softmax}} \left(\frac{QK^T}{\sqrt{d_k}} \right) V \right) + X \\ X_B &= \rm{LayerNorm} \left( {\rm{PositionFFN}} (X_A) \right) + X_A \end{align*}

となるようにそれぞれの処理の前後で入力情報を ResNet の要領で足していることを意味しています。

Masked self-attention

Transformer の self attention では、入力トークン行列 XX から

Q=XWQ,  K=XWK,  V=XWVQ = XW_Q,~~ K=XW_K,~~ V = XW_V

と作り、どのトークンがどのトークンをどれだけ参照するかを attention matrix AA として表現して

A=softmax(QKTdk)A = {\rm{softmax}} \left( \frac{QK^T}{\sqrt{d_k}} \right)

最終的な出力

O=AVO = AV

を計算していました。成分で書けば、各位置 ii の出力は

oi=j=1Taijvjo_i = \sum_{j=1}^T a_{ij} v_j

であり、トークンを重み付きで混ぜ合わせる演算になっています。

なぜマスクが必要か

次トークン予測(自己回帰言語モデル)では、時刻(位置)tt の予測は1,,t11,\dots,t-1 の情報だけから行うことが前提です。例えば

  • 入力:May the force
  • 予測:次に来るトークン(例:be など)

という問題設定では、そのまま attention を計算すると

oi=j=1Taijvjo_i = \sum_{j=1}^T a_{ij} v_j

として be 由来の情報も混じってしまい(v4v_4)、モデルが訓練中に「正解トークン be(未来の情報)」を参照できてしまい、答えを見たうえで学習するために意味のある学習ができません[3]。 このリークを防ぐために、decoder-only Transformer では self attention を masked self-attention(causal self-attention) に置き換えます。

マスクの基本アイデア

masked self attention では、softmax の計算前にマスク行列 MM を加えます。

A=softmax(QKTdk+M)A = {\rm{softmax}} \left( \frac{QK^T}{\sqrt{d_k}} + M\right)

として落とします。ここで MRT×TM\in\mathbb{R}^{T\times T} は「参照を許可しない場所を -\infty にする」行列です。典型的な因果マスク(causal mask)は

Mij={0(ji)(j>i)M_{ij}= \begin{cases} 0 & (j\le i)\\ -\infty & (j> i) \end{cases}

です。直感的には

  • jij\le i:過去・現在は参照してよい
  • j>ij> i:未来は参照禁止(スコアを -\infty に落とす)

という意味になります。softmax の定義より、行 ii の重みは

aij=exp ⁣(sij+Mij)=1Texp ⁣(si+Mi)a_{ij}= \frac{\exp\!\left(s_{ij}+M_{ij}\right)} {\sum_{\ell=1}^{T}\exp\!\left(s_{i\ell}+M_{i\ell}\right)}

です。もし j>ij>i なら Mij=M_{ij}=-\infty なので

exp(sij+Mij)=exp()=0\exp(s_{ij}+M_{ij})=\exp(-\infty)=0

となり、結果として

aij=0(j>i)a_{ij}=0\quad (j>i)

になります。つまり未来トークンには重みが割り当てられないことが保証され、マスク部分はsoftmax 計算後にゼロとなり

図の説明
マスクされた attention matrix の例

のように右上部分がマスクされた行列が出来上がります。これにより自分より未来の情報を含まず過去の情報だけを含む表現を作ることができます。

attention ブロック以降

Decoder-only Transformer ブロックは

  • masked self-attention
  • position-wise FFN
  • residual / LayerNorm(順序は流派による)

です。このうち、トークン間の情報が直接的に混ざるのは self attention の部分だけで、 AVAV のバリューとの計算では係数でその混ざり具合が調整されています。

  • FFN は position-wise:各位置 ii に同じ MLP を独立に適用するだけ
  • LayerNorm は各トークンの特徴次元内の正規化であり、位置間を混ぜない
  • Residual は同じ位置同士を足すだけで、位置間を混ぜない

したがって、自己回帰性(未来を見ない制約)を満たすためにはattention のスコアに因果マスクを入れるだけで十分です。

Cross-attention

ここまでの議論で、self attention は同じ系列 XX から Q,K,VQ,K,V を作り、系列内で情報を混ぜるものでした。

Q=XWQ,K=XWK,V=XWVQ=XW_Q,\quad K=XW_K,\quad V=XW_V A=softmax ⁣(QKTdk+M)A=\mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}+M\right) O=AVO=AV
  • 通常の self-attention:M=0M=0(マスクなし)
  • masked self-attention:MM が因果マスク(未来参照禁止)

一方で cross-attention は、Query を作る系列とKey/Value を作る系列が である点だけが本質的な違いです。特に典型例は encoder-decoder Transformer(翻訳など)で、デコーダが「生成中のトークン(自分の状態)」を Query にし、エンコーダ出力(入力文の表現)を Key/Value として参照します。

cross attention の定義

Query 用の系列 X(q)X^{(q)} と、Key/Value 用の系列 X(kv)X^{(kv)} が別。

Q=X(q)WQ,K=X(kv)WK,V=X(kv)WVQ=X^{(q)}W_Q,\quad K=X^{(kv)}W_K,\quad V=X^{(kv)}W_V

このとき attention は

A=softmax ⁣(QKTdk+M),O=AVA=\mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}+M\right),\quad O=AV

で、式の形は self-attention と同じですが、QQK,VK,V の出所が違うことが決定的です。よくよく図を見直してみると、cross attention の部分だけ K,VK,VQQ とが別々のところからやってきているのがわかります。

図の説明
Cross attention について

cross-attention では 2 系列の長さが一致する必要はなく、

  • Query 系列長:TqT_q
  • Key/Value 系列長:TkvT_{kv}

とすると、

X(q)RTq×dmodel,X(kv)RTkv×dmodelX^{(q)}\in\mathbb{R}^{T_q\times d_{\text{model}}},\quad X^{(kv)}\in\mathbb{R}^{T_{kv}\times d_{\text{model}}}

から

QRTq×dk,KRTkv×dk,VRTkv×dvQ\in\mathbb{R}^{T_q\times d_k},\quad K\in\mathbb{R}^{T_{kv}\times d_k},\quad V\in\mathbb{R}^{T_{kv}\times d_v}

となり、スコアと attention は

S=QKTdkRTq×TkvS=\frac{QK^T}{\sqrt{d_k}}\in\mathbb{R}^{T_q\times T_{kv}} A=softmax(S+M)RTq×TkvA=\mathrm{softmax}(S+M)\in\mathbb{R}^{T_q\times T_{kv}}

出力は

O=AVRTq×dvO=AV\in\mathbb{R}^{T_q\times d_v}

です。

成分での理解

成分で書けば、Query 側の位置 ii の出力は

oi=j=1Tkvaijvjo_i=\sum_{j=1}^{T_{kv}} a_{ij} v_j

です。ここで

  • ii は Query 側(例:デコーダの生成位置)
  • jj は Key/Value 側(例:エンコーダの入力文側)

になっています。

したがって cross-attention の直感は

  • Query:いま生成している位置が「入力文のどこが必要か」を問い合わせる
  • Key:入力文側の各位置が「自分はどんな情報を持つか」を示す索引
  • Value:入力文側の実際の内容(混ぜて渡す中身)

です。

マスクについて

self-attention(自己回帰)では因果マスクが重要でしたが、cross-attention では基本的には未来の情報も使用します。翻訳タスクのように入力文はすでに与えられた系列であり、時系列的に禁止すべき“未来”が存在しないからです。一方でよく使うのは padding mask です。入力側系列 X(kv)X^{(kv)} に PAD トークンが含まれる場合、その位置は参照しないように

  • PAD 位置のスコアに -\infty を足す

というマスクを入れます。


タスクについて

Transformer ブロックの出力は、以下の図のように入力と同じ次元数です。

図の説明
クエリベクトルを作成する場合の計算例

そのため、最終的なタスク(文章の感情分類や、次トークン予測)のためには何らかの方法で ORT×dO \in \mathbb{R}^{T\times d} を処理して必要な次元数に加工する必要があります。

文章の分類など

BERT などで学習されている文章の分類タスクでは、一文が与えられた状態でそれについてのクラス分類を解きます。一文全体を入力として良いため、Transformer の encoder(self-attention ブロック)を使います。 BERT で用いられていた手法は、入力トークンに追加で [CLS] トークンと呼ばれる特殊トークンを追加するというものです。

[CLS] トークンを追加して Transformer ブロックに出力して、[CLS] トークンの位置に対応する行を取り出して最終的なクラス分類タスクを解くというものです。入力時点での [CLS] トークンのベクトルは全文章共通ですが、Transformer ブロックを通過するときにそれぞれの文章固有のベクトル表現になることが期待でき、それを用いた分類を行うという発想です。

図の説明
BERTなどで用いられる [CLS] トークンを用いた分類タスクについて

次トークン生成など

GPT などで学習されている次に来る単語(トークン)の予測タスクでは、それまでのトークン列から次に来る確率の高いトークンを予測します。このときに May the force の次を予測するときに force 以降の情報を混ぜてはだめなので masked self-attention を使用する必要があり、そのため次トークン生成では Transformer decoder を使用します。

図の説明
GPTなどで用いられる次トークン予測について

最終的には、PRT×VP\in\mathbb{R}^{T\times V} の出力になり、それぞれを次トークンに来ると予測するトークンとして解釈することができます。


脚注

  1. 転置の TT と、トークン数の TT とがややこしかったですがよしなに理解してくださいませ。 ↩︎
  2. マスクしない self-attention では未来の情報も混ざっているので、これを「双方向に」読むとか表現したりもします。 ↩︎
  3. be 由来の重みが混じってリークしているということです。 ↩︎