mirror of
https://github.com/HackTricks-wiki/hacktricks.git
synced 2025-10-10 18:36:50 +00:00
115 lines
7.8 KiB
Markdown
115 lines
7.8 KiB
Markdown
# 7.1. Fine-Tuning for Classification
|
|
|
|
{{#include ../../banners/hacktricks-training.md}}
|
|
|
|
## Che cos'è
|
|
|
|
Il fine-tuning è il processo di prendere un **modello pre-addestrato** che ha appreso **modelli linguistici generali** da enormi quantità di dati e **adattarlo** per eseguire un **compito specifico** o per comprendere il linguaggio specifico di un dominio. Questo si ottiene continuando l'addestramento del modello su un dataset più piccolo e specifico per il compito, permettendogli di regolare i suoi parametri per adattarsi meglio alle sfumature dei nuovi dati, sfruttando al contempo la vasta conoscenza che ha già acquisito. Il fine-tuning consente al modello di fornire risultati più accurati e pertinenti in applicazioni specializzate senza la necessità di addestrare un nuovo modello da zero.
|
|
|
|
> [!TIP]
|
|
> Poiché il pre-addestramento di un LLM che "comprende" il testo è piuttosto costoso, di solito è più facile ed economico fare fine-tuning su modelli pre-addestrati open source per eseguire un compito specifico che vogliamo che esegua.
|
|
|
|
> [!TIP]
|
|
> L'obiettivo di questa sezione è mostrare come fare fine-tuning su un modello già pre-addestrato, in modo che invece di generare nuovo testo, l'LLM selezionerà e fornirà le **probabilità che il testo fornito venga categorizzato in ciascuna delle categorie date** (come se un testo fosse spam o meno).
|
|
|
|
## Preparazione del dataset
|
|
|
|
### Dimensione del dataset
|
|
|
|
Certo, per fare fine-tuning su un modello hai bisogno di dati strutturati da utilizzare per specializzare il tuo LLM. Nell'esempio proposto in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb), GPT2 è fine-tuned per rilevare se un'email è spam o meno utilizzando i dati di [https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip](https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip)_._
|
|
|
|
Questo dataset contiene molti più esempi di "non spam" che di "spam", quindi il libro suggerisce di **utilizzare solo tanti esempi di "non spam" quanti di "spam"** (rimuovendo quindi dal set di addestramento tutti gli esempi extra). In questo caso, erano 747 esempi di ciascuno.
|
|
|
|
Poi, il **70%** del dataset è utilizzato per il **training**, il **10%** per la **validazione** e il **20%** per il **testing**.
|
|
|
|
- Il **set di validazione** è utilizzato durante la fase di addestramento per fare fine-tuning degli **iperparametri** del modello e prendere decisioni sull'architettura del modello, aiutando effettivamente a prevenire l'overfitting fornendo feedback su come il modello si comporta su dati non visti. Consente miglioramenti iterativi senza pregiudicare la valutazione finale.
|
|
- Questo significa che, sebbene i dati inclusi in questo dataset non siano utilizzati direttamente per l'addestramento, vengono utilizzati per ottimizzare i migliori **iperparametri**, quindi questo set non può essere utilizzato per valutare le prestazioni del modello come quello di testing.
|
|
- Al contrario, il **set di test** è utilizzato **solo dopo** che il modello è stato completamente addestrato e tutti gli aggiustamenti sono stati completati; fornisce una valutazione imparziale della capacità del modello di generalizzare a nuovi dati non visti. Questa valutazione finale sul set di test fornisce un'indicazione realistica di come ci si aspetta che il modello si comporti nelle applicazioni del mondo reale.
|
|
|
|
### Lunghezza delle voci
|
|
|
|
Poiché l'esempio di addestramento si aspetta voci (testo delle email in questo caso) della stessa lunghezza, è stato deciso di rendere ogni voce grande quanto la più grande aggiungendo gli id di `<|endoftext|>` come padding.
|
|
|
|
### Inizializzare il modello
|
|
|
|
Utilizzando i pesi pre-addestrati open-source, inizializza il modello per l'addestramento. Abbiamo già fatto questo prima e seguendo le istruzioni di [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb) puoi farlo facilmente.
|
|
|
|
## Testa di classificazione
|
|
|
|
In questo esempio specifico (predire se un testo è spam o meno), non siamo interessati a fare fine-tuning secondo il vocabolario completo di GPT2, ma vogliamo solo che il nuovo modello dica se l'email è spam (1) o meno (0). Pertanto, andremo a **modificare l'ultimo strato che** fornisce le probabilità per token del vocabolario in uno che fornisce solo le probabilità di essere spam o meno (quindi come un vocabolario di 2 parole).
|
|
```python
|
|
# This code modified the final layer with a Linear one with 2 outs
|
|
num_classes = 2
|
|
model.out_head = torch.nn.Linear(
|
|
|
|
in_features=BASE_CONFIG["emb_dim"],
|
|
|
|
out_features=num_classes
|
|
)
|
|
```
|
|
## Parametri da ottimizzare
|
|
|
|
Per ottimizzare rapidamente, è più facile non ottimizzare tutti i parametri ma solo alcuni finali. Questo perché è noto che i livelli inferiori catturano generalmente strutture linguistiche di base e semantiche applicabili. Quindi, **ottimizzare solo gli ultimi livelli è di solito sufficiente e più veloce**.
|
|
```python
|
|
# This code makes all the parameters of the model unrtainable
|
|
for param in model.parameters():
|
|
param.requires_grad = False
|
|
|
|
# Allow to fine tune the last layer in the transformer block
|
|
for param in model.trf_blocks[-1].parameters():
|
|
param.requires_grad = True
|
|
|
|
# Allow to fine tune the final layer norm
|
|
for param in model.final_norm.parameters():
|
|
|
|
param.requires_grad = True
|
|
```
|
|
## Entries to use for training
|
|
|
|
Nelle sezioni precedenti, il LLM è stato addestrato riducendo la perdita di ogni token previsto, anche se quasi tutti i token previsti erano nella frase di input (solo 1 alla fine era realmente previsto) affinché il modello comprendesse meglio la lingua.
|
|
|
|
In questo caso ci interessa solo che il modello sia in grado di prevedere se il modello è spam o meno, quindi ci interessa solo l'ultimo token previsto. Pertanto, è necessario modificare le nostre precedenti funzioni di perdita di addestramento per tenere conto solo di quel token.
|
|
|
|
Questo è implementato in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb) come:
|
|
```python
|
|
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
|
model.eval()
|
|
correct_predictions, num_examples = 0, 0
|
|
|
|
if num_batches is None:
|
|
num_batches = len(data_loader)
|
|
else:
|
|
num_batches = min(num_batches, len(data_loader))
|
|
for i, (input_batch, target_batch) in enumerate(data_loader):
|
|
if i < num_batches:
|
|
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
|
|
|
with torch.no_grad():
|
|
logits = model(input_batch)[:, -1, :] # Logits of last output token
|
|
predicted_labels = torch.argmax(logits, dim=-1)
|
|
|
|
num_examples += predicted_labels.shape[0]
|
|
correct_predictions += (predicted_labels == target_batch).sum().item()
|
|
else:
|
|
break
|
|
return correct_predictions / num_examples
|
|
|
|
|
|
def calc_loss_batch(input_batch, target_batch, model, device):
|
|
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
|
logits = model(input_batch)[:, -1, :] # Logits of last output token
|
|
loss = torch.nn.functional.cross_entropy(logits, target_batch)
|
|
return loss
|
|
```
|
|
Nota come per ogni batch siamo interessati solo ai **logits dell'ultimo token previsto**.
|
|
|
|
## Codice completo per la classificazione fine-tune di GPT2
|
|
|
|
Puoi trovare tutto il codice per fine-tunare GPT2 come classificatore di spam in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb)
|
|
|
|
## Riferimenti
|
|
|
|
- [https://www.manning.com/books/build-a-large-language-model-from-scratch](https://www.manning.com/books/build-a-large-language-model-from-scratch)
|
|
|
|
{{#include ../../banners/hacktricks-training.md}}
|