hacktricks/src/AI/AI-llm-architecture/7.1.-fine-tuning-for-classification.md

115 lines
7.8 KiB
Markdown

# 7.1. Fine-Tuning for Classification
{{#include ../../banners/hacktricks-training.md}}
## Che cos'è
Il fine-tuning è il processo di prendere un **modello pre-addestrato** che ha appreso **modelli linguistici generali** da enormi quantità di dati e **adattarlo** per eseguire un **compito specifico** o per comprendere il linguaggio specifico di un dominio. Questo si ottiene continuando l'addestramento del modello su un dataset più piccolo e specifico per il compito, permettendogli di regolare i suoi parametri per adattarsi meglio alle sfumature dei nuovi dati, sfruttando al contempo la vasta conoscenza che ha già acquisito. Il fine-tuning consente al modello di fornire risultati più accurati e pertinenti in applicazioni specializzate senza la necessità di addestrare un nuovo modello da zero.
> [!TIP]
> Poiché il pre-addestramento di un LLM che "comprende" il testo è piuttosto costoso, di solito è più facile ed economico fare fine-tuning su modelli pre-addestrati open source per eseguire un compito specifico che vogliamo che esegua.
> [!TIP]
> L'obiettivo di questa sezione è mostrare come fare fine-tuning su un modello già pre-addestrato, in modo che invece di generare nuovo testo, l'LLM selezionerà e fornirà le **probabilità che il testo fornito venga categorizzato in ciascuna delle categorie date** (come se un testo fosse spam o meno).
## Preparazione del dataset
### Dimensione del dataset
Certo, per fare fine-tuning su un modello hai bisogno di dati strutturati da utilizzare per specializzare il tuo LLM. Nell'esempio proposto in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb), GPT2 è fine-tuned per rilevare se un'email è spam o meno utilizzando i dati di [https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip](https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip)_._
Questo dataset contiene molti più esempi di "non spam" che di "spam", quindi il libro suggerisce di **utilizzare solo tanti esempi di "non spam" quanti di "spam"** (rimuovendo quindi dal set di addestramento tutti gli esempi extra). In questo caso, erano 747 esempi di ciascuno.
Poi, il **70%** del dataset è utilizzato per il **training**, il **10%** per la **validazione** e il **20%** per il **testing**.
- Il **set di validazione** è utilizzato durante la fase di addestramento per fare fine-tuning degli **iperparametri** del modello e prendere decisioni sull'architettura del modello, aiutando effettivamente a prevenire l'overfitting fornendo feedback su come il modello si comporta su dati non visti. Consente miglioramenti iterativi senza pregiudicare la valutazione finale.
- Questo significa che, sebbene i dati inclusi in questo dataset non siano utilizzati direttamente per l'addestramento, vengono utilizzati per ottimizzare i migliori **iperparametri**, quindi questo set non può essere utilizzato per valutare le prestazioni del modello come quello di testing.
- Al contrario, il **set di test** è utilizzato **solo dopo** che il modello è stato completamente addestrato e tutti gli aggiustamenti sono stati completati; fornisce una valutazione imparziale della capacità del modello di generalizzare a nuovi dati non visti. Questa valutazione finale sul set di test fornisce un'indicazione realistica di come ci si aspetta che il modello si comporti nelle applicazioni del mondo reale.
### Lunghezza delle voci
Poiché l'esempio di addestramento si aspetta voci (testo delle email in questo caso) della stessa lunghezza, è stato deciso di rendere ogni voce grande quanto la più grande aggiungendo gli id di `<|endoftext|>` come padding.
### Inizializzare il modello
Utilizzando i pesi pre-addestrati open-source, inizializza il modello per l'addestramento. Abbiamo già fatto questo prima e seguendo le istruzioni di [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb) puoi farlo facilmente.
## Testa di classificazione
In questo esempio specifico (predire se un testo è spam o meno), non siamo interessati a fare fine-tuning secondo il vocabolario completo di GPT2, ma vogliamo solo che il nuovo modello dica se l'email è spam (1) o meno (0). Pertanto, andremo a **modificare l'ultimo strato che** fornisce le probabilità per token del vocabolario in uno che fornisce solo le probabilità di essere spam o meno (quindi come un vocabolario di 2 parole).
```python
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
```
## Parametri da ottimizzare
Per ottimizzare rapidamente, è più facile non ottimizzare tutti i parametri ma solo alcuni finali. Questo perché è noto che i livelli inferiori catturano generalmente strutture linguistiche di base e semantiche applicabili. Quindi, **ottimizzare solo gli ultimi livelli è di solito sufficiente e più veloce**.
```python
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
```
## Entries to use for training
Nelle sezioni precedenti, il LLM è stato addestrato riducendo la perdita di ogni token previsto, anche se quasi tutti i token previsti erano nella frase di input (solo 1 alla fine era realmente previsto) affinché il modello comprendesse meglio la lingua.
In questo caso ci interessa solo che il modello sia in grado di prevedere se il modello è spam o meno, quindi ci interessa solo l'ultimo token previsto. Pertanto, è necessario modificare le nostre precedenti funzioni di perdita di addestramento per tenere conto solo di quel token.
Questo è implementato in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb) come:
```python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
```
Nota come per ogni batch siamo interessati solo ai **logits dell'ultimo token previsto**.
## Codice completo per la classificazione fine-tune di GPT2
Puoi trovare tutto il codice per fine-tunare GPT2 come classificatore di spam in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb)
## Riferimenti
- [https://www.manning.com/books/build-a-large-language-model-from-scratch](https://www.manning.com/books/build-a-large-language-model-from-scratch)
{{#include ../../banners/hacktricks-training.md}}