# 7.1. Fine-Tuning for Classification {{#include ../../banners/hacktricks-training.md}} ## Che cos'è Il fine-tuning è il processo di prendere un **modello pre-addestrato** che ha appreso **modelli linguistici generali** da enormi quantità di dati e **adattarlo** per eseguire un **compito specifico** o per comprendere il linguaggio specifico di un dominio. Questo si ottiene continuando l'addestramento del modello su un dataset più piccolo e specifico per il compito, permettendogli di regolare i suoi parametri per adattarsi meglio alle sfumature dei nuovi dati, sfruttando al contempo la vasta conoscenza che ha già acquisito. Il fine-tuning consente al modello di fornire risultati più accurati e pertinenti in applicazioni specializzate senza la necessità di addestrare un nuovo modello da zero. > [!TIP] > Poiché il pre-addestramento di un LLM che "comprende" il testo è piuttosto costoso, di solito è più facile ed economico fare fine-tuning su modelli pre-addestrati open source per eseguire un compito specifico che vogliamo che esegua. > [!TIP] > L'obiettivo di questa sezione è mostrare come fare fine-tuning su un modello già pre-addestrato, in modo che invece di generare nuovo testo, l'LLM selezionerà e fornirà le **probabilità che il testo fornito venga categorizzato in ciascuna delle categorie date** (come se un testo fosse spam o meno). ## Preparazione del dataset ### Dimensione del dataset Certo, per fare fine-tuning su un modello hai bisogno di dati strutturati da utilizzare per specializzare il tuo LLM. Nell'esempio proposto in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb), GPT2 è fine-tuned per rilevare se un'email è spam o meno utilizzando i dati di [https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip](https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip)_._ Questo dataset contiene molti più esempi di "non spam" che di "spam", quindi il libro suggerisce di **utilizzare solo tanti esempi di "non spam" quanti di "spam"** (rimuovendo quindi dal set di addestramento tutti gli esempi extra). In questo caso, erano 747 esempi di ciascuno. Poi, il **70%** del dataset è utilizzato per il **training**, il **10%** per la **validazione** e il **20%** per il **testing**. - Il **set di validazione** è utilizzato durante la fase di addestramento per fare fine-tuning degli **iperparametri** del modello e prendere decisioni sull'architettura del modello, aiutando effettivamente a prevenire l'overfitting fornendo feedback su come il modello si comporta su dati non visti. Consente miglioramenti iterativi senza pregiudicare la valutazione finale. - Questo significa che, sebbene i dati inclusi in questo dataset non siano utilizzati direttamente per l'addestramento, vengono utilizzati per ottimizzare i migliori **iperparametri**, quindi questo set non può essere utilizzato per valutare le prestazioni del modello come quello di testing. - Al contrario, il **set di test** è utilizzato **solo dopo** che il modello è stato completamente addestrato e tutti gli aggiustamenti sono stati completati; fornisce una valutazione imparziale della capacità del modello di generalizzare a nuovi dati non visti. Questa valutazione finale sul set di test fornisce un'indicazione realistica di come ci si aspetta che il modello si comporti nelle applicazioni del mondo reale. ### Lunghezza delle voci Poiché l'esempio di addestramento si aspetta voci (testo delle email in questo caso) della stessa lunghezza, è stato deciso di rendere ogni voce grande quanto la più grande aggiungendo gli id di `<|endoftext|>` come padding. ### Inizializzare il modello Utilizzando i pesi pre-addestrati open-source, inizializza il modello per l'addestramento. Abbiamo già fatto questo prima e seguendo le istruzioni di [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb) puoi farlo facilmente. ## Testa di classificazione In questo esempio specifico (predire se un testo è spam o meno), non siamo interessati a fare fine-tuning secondo il vocabolario completo di GPT2, ma vogliamo solo che il nuovo modello dica se l'email è spam (1) o meno (0). Pertanto, andremo a **modificare l'ultimo strato che** fornisce le probabilità per token del vocabolario in uno che fornisce solo le probabilità di essere spam o meno (quindi come un vocabolario di 2 parole). ```python # This code modified the final layer with a Linear one with 2 outs num_classes = 2 model.out_head = torch.nn.Linear( in_features=BASE_CONFIG["emb_dim"], out_features=num_classes ) ``` ## Parametri da ottimizzare Per ottimizzare rapidamente, è più facile non ottimizzare tutti i parametri ma solo alcuni finali. Questo perché è noto che i livelli inferiori catturano generalmente strutture linguistiche di base e semantiche applicabili. Quindi, **ottimizzare solo gli ultimi livelli è di solito sufficiente e più veloce**. ```python # This code makes all the parameters of the model unrtainable for param in model.parameters(): param.requires_grad = False # Allow to fine tune the last layer in the transformer block for param in model.trf_blocks[-1].parameters(): param.requires_grad = True # Allow to fine tune the final layer norm for param in model.final_norm.parameters(): param.requires_grad = True ``` ## Entries to use for training Nelle sezioni precedenti, il LLM è stato addestrato riducendo la perdita di ogni token previsto, anche se quasi tutti i token previsti erano nella frase di input (solo 1 alla fine era realmente previsto) affinché il modello comprendesse meglio la lingua. In questo caso ci interessa solo che il modello sia in grado di prevedere se il modello è spam o meno, quindi ci interessa solo l'ultimo token previsto. Pertanto, è necessario modificare le nostre precedenti funzioni di perdita di addestramento per tenere conto solo di quel token. Questo è implementato in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb) come: ```python def calc_accuracy_loader(data_loader, model, device, num_batches=None): model.eval() correct_predictions, num_examples = 0, 0 if num_batches is None: num_batches = len(data_loader) else: num_batches = min(num_batches, len(data_loader)) for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: input_batch, target_batch = input_batch.to(device), target_batch.to(device) with torch.no_grad(): logits = model(input_batch)[:, -1, :] # Logits of last output token predicted_labels = torch.argmax(logits, dim=-1) num_examples += predicted_labels.shape[0] correct_predictions += (predicted_labels == target_batch).sum().item() else: break return correct_predictions / num_examples def calc_loss_batch(input_batch, target_batch, model, device): input_batch, target_batch = input_batch.to(device), target_batch.to(device) logits = model(input_batch)[:, -1, :] # Logits of last output token loss = torch.nn.functional.cross_entropy(logits, target_batch) return loss ``` Nota come per ogni batch siamo interessati solo ai **logits dell'ultimo token previsto**. ## Codice completo per la classificazione fine-tune di GPT2 Puoi trovare tutto il codice per fine-tunare GPT2 come classificatore di spam in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb) ## Riferimenti - [https://www.manning.com/books/build-a-large-language-model-from-scratch](https://www.manning.com/books/build-a-large-language-model-from-scratch) {{#include ../../banners/hacktricks-training.md}}