mirror of
https://github.com/HackTricks-wiki/hacktricks.git
synced 2025-10-10 18:36:50 +00:00
417 lines
20 KiB
Markdown
417 lines
20 KiB
Markdown
# 4. Aufmerksamkeitsmechanismen
|
||
|
||
## Aufmerksamkeitsmechanismen und Selbstaufmerksamkeit in neuronalen Netzwerken
|
||
|
||
Aufmerksamkeitsmechanismen ermöglichen es neuronalen Netzwerken, sich **auf spezifische Teile der Eingabe zu konzentrieren, wenn sie jeden Teil der Ausgabe generieren**. Sie weisen unterschiedlichen Eingaben unterschiedliche Gewichte zu, was dem Modell hilft zu entscheiden, welche Eingaben für die jeweilige Aufgabe am relevantesten sind. Dies ist entscheidend bei Aufgaben wie maschineller Übersetzung, bei denen das Verständnis des Kontexts des gesamten Satzes für eine genaue Übersetzung notwendig ist.
|
||
|
||
> [!TIP]
|
||
> Das Ziel dieser vierten Phase ist sehr einfach: **Wenden Sie einige Aufmerksamkeitsmechanismen an**. Diese werden viele **wiederholte Schichten** sein, die die **Beziehung eines Wortes im Vokabular zu seinen Nachbarn im aktuellen Satz, der zum Trainieren des LLM verwendet wird, erfassen**.\
|
||
> Es werden viele Schichten dafür verwendet, sodass viele trainierbare Parameter diese Informationen erfassen werden.
|
||
|
||
### Verständnis der Aufmerksamkeitsmechanismen
|
||
|
||
In traditionellen Sequenz-zu-Sequenz-Modellen, die für die Sprachübersetzung verwendet werden, kodiert das Modell eine Eingabesequenz in einen kontextuellen Vektor fester Größe. Dieses Vorgehen hat jedoch Schwierigkeiten mit langen Sätzen, da der kontextuelle Vektor fester Größe möglicherweise nicht alle notwendigen Informationen erfasst. Aufmerksamkeitsmechanismen beheben diese Einschränkung, indem sie dem Modell erlauben, alle Eingabetoken zu berücksichtigen, wenn es jedes Ausgabetoken generiert.
|
||
|
||
#### Beispiel: Maschinelle Übersetzung
|
||
|
||
Betrachten Sie die Übersetzung des deutschen Satzes "Kannst du mir helfen diesen Satz zu übersetzen" ins Englische. Eine wortwörtliche Übersetzung würde keinen grammatikalisch korrekten englischen Satz ergeben, da es Unterschiede in den grammatikalischen Strukturen zwischen den Sprachen gibt. Ein Aufmerksamkeitsmechanismus ermöglicht es dem Modell, sich auf relevante Teile des Eingabesatzes zu konzentrieren, wenn es jedes Wort des Ausgabesatzes generiert, was zu einer genaueren und kohärenteren Übersetzung führt.
|
||
|
||
### Einführung in die Selbstaufmerksamkeit
|
||
|
||
Selbstaufmerksamkeit, oder Intra-Aufmerksamkeit, ist ein Mechanismus, bei dem Aufmerksamkeit innerhalb einer einzelnen Sequenz angewendet wird, um eine Darstellung dieser Sequenz zu berechnen. Sie ermöglicht es jedem Token in der Sequenz, auf alle anderen Tokens zu achten, was dem Modell hilft, Abhängigkeiten zwischen Tokens unabhängig von ihrer Entfernung in der Sequenz zu erfassen.
|
||
|
||
#### Schlüsselkonzepte
|
||
|
||
- **Tokens**: Einzelne Elemente der Eingabesequenz (z. B. Wörter in einem Satz).
|
||
- **Embeddings**: Vektorielle Darstellungen von Tokens, die semantische Informationen erfassen.
|
||
- **Aufmerksamkeitsgewichte**: Werte, die die Bedeutung jedes Tokens im Verhältnis zu anderen bestimmen.
|
||
|
||
### Berechnung der Aufmerksamkeitsgewichte: Ein Schritt-für-Schritt-Beispiel
|
||
|
||
Betrachten wir den Satz **"Hello shiny sun!"** und repräsentieren jedes Wort mit einem 3-dimensionalen Embedding:
|
||
|
||
- **Hello**: `[0.34, 0.22, 0.54]`
|
||
- **shiny**: `[0.53, 0.34, 0.98]`
|
||
- **sun**: `[0.29, 0.54, 0.93]`
|
||
|
||
Unser Ziel ist es, den **Kontextvektor** für das Wort **"shiny"** mithilfe von Selbstaufmerksamkeit zu berechnen.
|
||
|
||
#### Schritt 1: Berechnung der Aufmerksamkeitswerte
|
||
|
||
> [!TIP]
|
||
> Multiplizieren Sie einfach jeden Dimensionswert der Abfrage mit dem entsprechenden Wert jedes Tokens und addieren Sie die Ergebnisse. Sie erhalten 1 Wert pro Token-Paar.
|
||
|
||
Für jedes Wort im Satz berechnen Sie den **Aufmerksamkeitswert** in Bezug auf "shiny", indem Sie das Skalarprodukt ihrer Embeddings berechnen.
|
||
|
||
**Aufmerksamkeitswert zwischen "Hello" und "shiny"**
|
||
|
||
<figure><img src="../../images/image (4) (1) (1).png" alt="" width="563"><figcaption></figcaption></figure>
|
||
|
||
**Aufmerksamkeitswert zwischen "shiny" und "shiny"**
|
||
|
||
<figure><img src="../../images/image (1) (1) (1) (1) (1) (1) (1) (1).png" alt="" width="563"><figcaption></figcaption></figure>
|
||
|
||
**Aufmerksamkeitswert zwischen "sun" und "shiny"**
|
||
|
||
<figure><img src="../../images/image (2) (1) (1) (1) (1).png" alt="" width="563"><figcaption></figcaption></figure>
|
||
|
||
#### Schritt 2: Normalisieren der Aufmerksamkeitswerte zur Ermittlung der Aufmerksamkeitsgewichte
|
||
|
||
> [!TIP]
|
||
> Lassen Sie sich nicht von den mathematischen Begriffen verwirren, das Ziel dieser Funktion ist einfach, normalisieren Sie alle Gewichte, sodass **sie insgesamt 1 ergeben**.
|
||
>
|
||
> Darüber hinaus wird die **Softmax**-Funktion verwendet, da sie Unterschiede aufgrund des exponentiellen Teils verstärkt, was es einfacher macht, nützliche Werte zu erkennen.
|
||
|
||
Wenden Sie die **Softmax-Funktion** auf die Aufmerksamkeitswerte an, um sie in Aufmerksamkeitsgewichte umzuwandeln, die sich zu 1 summieren.
|
||
|
||
<figure><img src="../../images/image (3) (1) (1) (1) (1).png" alt="" width="293"><figcaption></figcaption></figure>
|
||
|
||
Berechnung der Exponentialwerte:
|
||
|
||
<figure><img src="../../images/image (4) (1) (1).png" alt="" width="249"><figcaption></figcaption></figure>
|
||
|
||
Berechnung der Summe:
|
||
|
||
<figure><img src="../../images/image (5) (1) (1).png" alt="" width="563"><figcaption></figcaption></figure>
|
||
|
||
Berechnung der Aufmerksamkeitsgewichte:
|
||
|
||
<figure><img src="../../images/image (6) (1) (1).png" alt="" width="404"><figcaption></figcaption></figure>
|
||
|
||
#### Schritt 3: Berechnung des Kontextvektors
|
||
|
||
> [!TIP]
|
||
> Nehmen Sie einfach jedes Aufmerksamkeitsgewicht, multiplizieren Sie es mit den entsprechenden Token-Dimensionen und summieren Sie dann alle Dimensionen, um nur 1 Vektor (den Kontextvektor) zu erhalten.
|
||
|
||
Der **Kontextvektor** wird als gewichtete Summe der Embeddings aller Wörter unter Verwendung der Aufmerksamkeitsgewichte berechnet.
|
||
|
||
<figure><img src="../../images/image (16).png" alt="" width="369"><figcaption></figcaption></figure>
|
||
|
||
Berechnung jeder Komponente:
|
||
|
||
- **Gewichtetes Embedding von "Hello"**:
|
||
|
||
<figure><img src="../../images/image (7) (1) (1).png" alt=""><figcaption></figcaption></figure>
|
||
|
||
- **Gewichtetes Embedding von "shiny"**:
|
||
|
||
<figure><img src="../../images/image (8) (1) (1).png" alt=""><figcaption></figcaption></figure>
|
||
|
||
- **Gewichtetes Embedding von "sun"**:
|
||
|
||
<figure><img src="../../images/image (9) (1) (1).png" alt=""><figcaption></figcaption></figure>
|
||
|
||
Summierung der gewichteten Embeddings:
|
||
|
||
`context vector=[0.0779+0.2156+0.1057, 0.0504+0.1382+0.1972, 0.1237+0.3983+0.3390]=[0.3992,0.3858,0.8610]`
|
||
|
||
**Dieser Kontextvektor repräsentiert das angereicherte Embedding für das Wort "shiny", das Informationen aus allen Wörtern im Satz integriert.**
|
||
|
||
### Zusammenfassung des Prozesses
|
||
|
||
1. **Berechnung der Aufmerksamkeitswerte**: Verwenden Sie das Skalarprodukt zwischen dem Embedding des Zielworts und den Embeddings aller Wörter in der Sequenz.
|
||
2. **Normalisieren der Werte zur Ermittlung der Aufmerksamkeitsgewichte**: Wenden Sie die Softmax-Funktion auf die Aufmerksamkeitswerte an, um Gewichte zu erhalten, die sich zu 1 summieren.
|
||
3. **Berechnung des Kontextvektors**: Multiplizieren Sie das Embedding jedes Wortes mit seinem Aufmerksamkeitsgewicht und summieren Sie die Ergebnisse.
|
||
|
||
## Selbstaufmerksamkeit mit trainierbaren Gewichten
|
||
|
||
In der Praxis verwenden Selbstaufmerksamkeitsmechanismen **trainierbare Gewichte**, um die besten Darstellungen für Abfragen, Schlüssel und Werte zu lernen. Dies beinhaltet die Einführung von drei Gewichtsmatrizen:
|
||
|
||
<figure><img src="../../images/image (10) (1) (1).png" alt="" width="239"><figcaption></figcaption></figure>
|
||
|
||
Die Abfrage ist die zu verwendende Daten wie zuvor, während die Schlüssel- und Wertematrizen einfach zufällige, trainierbare Matrizen sind.
|
||
|
||
#### Schritt 1: Berechnung von Abfragen, Schlüsseln und Werten
|
||
|
||
Jedes Token hat seine eigene Abfrage-, Schlüssel- und Wertematrix, indem es seine Dimensionswerte mit den definierten Matrizen multipliziert:
|
||
|
||
<figure><img src="../../images/image (11).png" alt="" width="253"><figcaption></figcaption></figure>
|
||
|
||
Diese Matrizen transformieren die ursprünglichen Embeddings in einen neuen Raum, der für die Berechnung der Aufmerksamkeit geeignet ist.
|
||
|
||
**Beispiel**
|
||
|
||
Angenommen:
|
||
|
||
- Eingabedimension `din=3` (Embedding-Größe)
|
||
- Ausgabedimension `dout=2` (gewünschte Dimension für Abfragen, Schlüssel und Werte)
|
||
|
||
Initialisieren Sie die Gewichtsmatrizen:
|
||
```python
|
||
import torch.nn as nn
|
||
|
||
d_in = 3
|
||
d_out = 2
|
||
|
||
W_query = nn.Parameter(torch.rand(d_in, d_out))
|
||
W_key = nn.Parameter(torch.rand(d_in, d_out))
|
||
W_value = nn.Parameter(torch.rand(d_in, d_out))
|
||
```
|
||
Berechne Abfragen, Schlüssel und Werte:
|
||
```python
|
||
queries = torch.matmul(inputs, W_query)
|
||
keys = torch.matmul(inputs, W_key)
|
||
values = torch.matmul(inputs, W_value)
|
||
```
|
||
#### Schritt 2: Berechnung der skalierten Dot-Produkt-Attention
|
||
|
||
**Berechnung der Attention-Werte**
|
||
|
||
Ähnlich wie im vorherigen Beispiel, aber diesmal verwenden wir anstelle der Werte der Dimensionen der Tokens die Schlüsselmatrix des Tokens (bereits unter Verwendung der Dimensionen berechnet):. Für jede Abfrage `qi` und Schlüssel `kj`:
|
||
|
||
<figure><img src="../../images/image (12).png" alt=""><figcaption></figcaption></figure>
|
||
|
||
**Skalierung der Werte**
|
||
|
||
Um zu verhindern, dass die Dot-Produkte zu groß werden, skalieren Sie sie durch die Quadratwurzel der Schlüssel-Dimension `dk`:
|
||
|
||
<figure><img src="../../images/image (13).png" alt="" width="295"><figcaption></figcaption></figure>
|
||
|
||
> [!TIPP]
|
||
> Der Wert wird durch die Quadratwurzel der Dimensionen geteilt, da Dot-Produkte sehr groß werden können und dies hilft, sie zu regulieren.
|
||
|
||
**Anwenden von Softmax zur Ermittlung der Attention-Gewichte:** Wie im ursprünglichen Beispiel, normalisieren Sie alle Werte, sodass sie 1 ergeben.
|
||
|
||
<figure><img src="../../images/image (14).png" alt="" width="295"><figcaption></figcaption></figure>
|
||
|
||
#### Schritt 3: Berechnung der Kontextvektoren
|
||
|
||
Wie im ursprünglichen Beispiel, summieren Sie einfach alle Wertematrizen und multiplizieren jede mit ihrem Attention-Gewicht:
|
||
|
||
<figure><img src="../../images/image (15).png" alt="" width="328"><figcaption></figcaption></figure>
|
||
|
||
### Codebeispiel
|
||
|
||
Ein Beispiel von [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb) zeigt diese Klasse, die die Selbst-Attention-Funktionalität implementiert, über die wir gesprochen haben:
|
||
```python
|
||
import torch
|
||
|
||
inputs = torch.tensor(
|
||
[[0.43, 0.15, 0.89], # Your (x^1)
|
||
[0.55, 0.87, 0.66], # journey (x^2)
|
||
[0.57, 0.85, 0.64], # starts (x^3)
|
||
[0.22, 0.58, 0.33], # with (x^4)
|
||
[0.77, 0.25, 0.10], # one (x^5)
|
||
[0.05, 0.80, 0.55]] # step (x^6)
|
||
)
|
||
|
||
import torch.nn as nn
|
||
class SelfAttention_v2(nn.Module):
|
||
|
||
def __init__(self, d_in, d_out, qkv_bias=False):
|
||
super().__init__()
|
||
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
|
||
def forward(self, x):
|
||
keys = self.W_key(x)
|
||
queries = self.W_query(x)
|
||
values = self.W_value(x)
|
||
|
||
attn_scores = queries @ keys.T
|
||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||
|
||
context_vec = attn_weights @ values
|
||
return context_vec
|
||
|
||
d_in=3
|
||
d_out=2
|
||
torch.manual_seed(789)
|
||
sa_v2 = SelfAttention_v2(d_in, d_out)
|
||
print(sa_v2(inputs))
|
||
```
|
||
> [!TIP]
|
||
> Beachten Sie, dass anstelle der Initialisierung der Matrizen mit zufälligen Werten `nn.Linear` verwendet wird, um alle Gewichte als Parameter zum Trainieren zu kennzeichnen.
|
||
|
||
## Kausale Aufmerksamkeit: Zukünftige Wörter verbergen
|
||
|
||
Für LLMs möchten wir, dass das Modell nur die Tokens berücksichtigt, die vor der aktuellen Position erscheinen, um das **nächste Token vorherzusagen**. **Kausale Aufmerksamkeit**, auch bekannt als **maskierte Aufmerksamkeit**, erreicht dies, indem der Aufmerksamkeitsmechanismus modifiziert wird, um den Zugriff auf zukünftige Tokens zu verhindern.
|
||
|
||
### Anwendung einer kausalen Aufmerksamkeitsmaske
|
||
|
||
Um kausale Aufmerksamkeit zu implementieren, wenden wir eine Maske auf die Aufmerksamkeitswerte **vor der Softmax-Operation** an, sodass die verbleibenden Werte immer noch 1 ergeben. Diese Maske setzt die Aufmerksamkeitswerte zukünftiger Tokens auf negative Unendlichkeit, wodurch sichergestellt wird, dass nach der Softmax ihre Aufmerksamkeitsgewichte null sind.
|
||
|
||
**Schritte**
|
||
|
||
1. **Berechnung der Aufmerksamkeitswerte**: Wie zuvor.
|
||
2. **Maske anwenden**: Verwenden Sie eine obere Dreiecksmatrix, die über der Diagonalen mit negativer Unendlichkeit gefüllt ist.
|
||
|
||
```python
|
||
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * float('-inf')
|
||
masked_scores = attention_scores + mask
|
||
```
|
||
|
||
3. **Softmax anwenden**: Berechnen Sie die Aufmerksamkeitsgewichte mit den maskierten Werten.
|
||
|
||
```python
|
||
attention_weights = torch.softmax(masked_scores, dim=-1)
|
||
```
|
||
|
||
### Maskierung zusätzlicher Aufmerksamkeitsgewichte mit Dropout
|
||
|
||
Um **Überanpassung zu verhindern**, können wir **Dropout** auf die Aufmerksamkeitsgewichte nach der Softmax-Operation anwenden. Dropout **setzt zufällig einige der Aufmerksamkeitsgewichte während des Trainings auf null.**
|
||
```python
|
||
dropout = nn.Dropout(p=0.5)
|
||
attention_weights = dropout(attention_weights)
|
||
```
|
||
Ein regulärer Dropout liegt bei etwa 10-20%.
|
||
|
||
### Codebeispiel
|
||
|
||
Codebeispiel von [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb):
|
||
```python
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
inputs = torch.tensor(
|
||
[[0.43, 0.15, 0.89], # Your (x^1)
|
||
[0.55, 0.87, 0.66], # journey (x^2)
|
||
[0.57, 0.85, 0.64], # starts (x^3)
|
||
[0.22, 0.58, 0.33], # with (x^4)
|
||
[0.77, 0.25, 0.10], # one (x^5)
|
||
[0.05, 0.80, 0.55]] # step (x^6)
|
||
)
|
||
|
||
batch = torch.stack((inputs, inputs), dim=0)
|
||
print(batch.shape)
|
||
|
||
class CausalAttention(nn.Module):
|
||
|
||
def __init__(self, d_in, d_out, context_length,
|
||
dropout, qkv_bias=False):
|
||
super().__init__()
|
||
self.d_out = d_out
|
||
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
self.dropout = nn.Dropout(dropout)
|
||
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New
|
||
|
||
def forward(self, x):
|
||
b, num_tokens, d_in = x.shape
|
||
# b is the num of batches
|
||
# num_tokens is the number of tokens per batch
|
||
# d_in is the dimensions er token
|
||
|
||
keys = self.W_key(x) # This generates the keys of the tokens
|
||
queries = self.W_query(x)
|
||
values = self.W_value(x)
|
||
|
||
attn_scores = queries @ keys.transpose(1, 2) # Moves the third dimension to the second one and the second one to the third one to be able to multiply
|
||
attn_scores.masked_fill_( # New, _ ops are in-place
|
||
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
|
||
attn_weights = torch.softmax(
|
||
attn_scores / keys.shape[-1]**0.5, dim=-1
|
||
)
|
||
attn_weights = self.dropout(attn_weights)
|
||
|
||
context_vec = attn_weights @ values
|
||
return context_vec
|
||
|
||
torch.manual_seed(123)
|
||
|
||
context_length = batch.shape[1]
|
||
d_in = 3
|
||
d_out = 2
|
||
ca = CausalAttention(d_in, d_out, context_length, 0.0)
|
||
|
||
context_vecs = ca(batch)
|
||
|
||
print(context_vecs)
|
||
print("context_vecs.shape:", context_vecs.shape)
|
||
```
|
||
## Erweiterung der Ein-Kopf-Attention zur Mehr-Kopf-Attention
|
||
|
||
**Mehr-Kopf-Attention** besteht in der Praxis darin, **mehrere Instanzen** der Selbst-Attention-Funktion auszuführen, wobei jede von ihnen **ihre eigenen Gewichte** hat, sodass unterschiedliche finale Vektoren berechnet werden.
|
||
|
||
### Codebeispiel
|
||
|
||
Es wäre möglich, den vorherigen Code wiederzuverwenden und einfach einen Wrapper hinzuzufügen, der ihn mehrere Male ausführt, aber dies ist eine optimierte Version von [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb), die alle Köpfe gleichzeitig verarbeitet (was die Anzahl der teuren for-Schleifen reduziert). Wie im Code zu sehen ist, werden die Dimensionen jedes Tokens in verschiedene Dimensionen entsprechend der Anzahl der Köpfe aufgeteilt. Auf diese Weise, wenn ein Token 8 Dimensionen hat und wir 3 Köpfe verwenden möchten, werden die Dimensionen in 2 Arrays mit 4 Dimensionen aufgeteilt, und jeder Kopf verwendet eines davon:
|
||
```python
|
||
class MultiHeadAttention(nn.Module):
|
||
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
|
||
super().__init__()
|
||
assert (d_out % num_heads == 0), \
|
||
"d_out must be divisible by num_heads"
|
||
|
||
self.d_out = d_out
|
||
self.num_heads = num_heads
|
||
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
|
||
|
||
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
|
||
self.dropout = nn.Dropout(dropout)
|
||
self.register_buffer(
|
||
"mask",
|
||
torch.triu(torch.ones(context_length, context_length),
|
||
diagonal=1)
|
||
)
|
||
|
||
def forward(self, x):
|
||
b, num_tokens, d_in = x.shape
|
||
# b is the num of batches
|
||
# num_tokens is the number of tokens per batch
|
||
# d_in is the dimensions er token
|
||
|
||
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
|
||
queries = self.W_query(x)
|
||
values = self.W_value(x)
|
||
|
||
# We implicitly split the matrix by adding a `num_heads` dimension
|
||
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
|
||
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
|
||
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
|
||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
|
||
|
||
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
|
||
keys = keys.transpose(1, 2)
|
||
queries = queries.transpose(1, 2)
|
||
values = values.transpose(1, 2)
|
||
|
||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||
|
||
# Original mask truncated to the number of tokens and converted to boolean
|
||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||
|
||
# Use the mask to fill attention scores
|
||
attn_scores.masked_fill_(mask_bool, -torch.inf)
|
||
|
||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||
attn_weights = self.dropout(attn_weights)
|
||
|
||
# Shape: (b, num_tokens, num_heads, head_dim)
|
||
context_vec = (attn_weights @ values).transpose(1, 2)
|
||
|
||
# Combine heads, where self.d_out = self.num_heads * self.head_dim
|
||
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
|
||
context_vec = self.out_proj(context_vec) # optional projection
|
||
|
||
return context_vec
|
||
|
||
torch.manual_seed(123)
|
||
|
||
batch_size, context_length, d_in = batch.shape
|
||
d_out = 2
|
||
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
|
||
|
||
context_vecs = mha(batch)
|
||
|
||
print(context_vecs)
|
||
print("context_vecs.shape:", context_vecs.shape)
|
||
|
||
```
|
||
Für eine weitere kompakte und effiziente Implementierung könnten Sie die [`torch.nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) Klasse in PyTorch verwenden.
|
||
|
||
> [!TIP]
|
||
> Kurze Antwort von ChatGPT, warum es besser ist, die Dimensionen der Tokens unter den Köpfen zu teilen, anstatt dass jeder Kopf alle Dimensionen aller Tokens überprüft:
|
||
>
|
||
> Während es vorteilhaft erscheinen mag, jedem Kopf den Zugriff auf alle Einbettungsdimensionen zu ermöglichen, da jeder Kopf auf die vollständigen Informationen zugreifen könnte, ist die gängige Praxis, die **Einbettungsdimensionen unter den Köpfen zu teilen**. Dieser Ansatz balanciert die rechnerische Effizienz mit der Modellleistung und fördert, dass jeder Kopf unterschiedliche Darstellungen lernt. Daher wird das Teilen der Einbettungsdimensionen im Allgemeinen bevorzugt, anstatt dass jeder Kopf alle Dimensionen überprüft.
|
||
|
||
## References
|
||
|
||
- [https://www.manning.com/books/build-a-large-language-model-from-scratch](https://www.manning.com/books/build-a-large-language-model-from-scratch)
|