hacktricks/src/AI/AI-llm-architecture/7.1.-fine-tuning-for-classification.md

115 lines
6.6 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 7.1. Fine-Tuning for Classification
{{#include ../../banners/hacktricks-training.md}}
## 什么是
微调是将一个**预训练模型**的过程,该模型已经从大量数据中学习了**通用语言模式**,并**调整**它以执行**特定任务**或理解特定领域的语言。这是通过在一个较小的、特定任务的数据集上继续训练模型来实现的,使其能够调整参数以更好地适应新数据的细微差别,同时利用其已经获得的广泛知识。微调使模型能够在专业应用中提供更准确和相关的结果,而无需从头开始训练一个新模型。
> [!TIP]
> 由于预训练一个“理解”文本的LLM相当昂贵因此通常更容易和便宜地微调开源的预训练模型以执行我们希望其执行的特定任务。
> [!TIP]
> 本节的目标是展示如何微调一个已经预训练的模型因此LLM将选择给出**给定文本被分类到每个给定类别的概率**(例如,文本是否为垃圾邮件)。
## 准备数据集
### 数据集大小
当然为了微调模型您需要一些结构化数据来专门化您的LLM。在[https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb)中提出的示例中GPT2被微调以检测电子邮件是否为垃圾邮件使用的数据来自[https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip](https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip)_。
该数据集包含的“非垃圾邮件”示例远多于“垃圾邮件”示例,因此本书建议**仅使用与“垃圾邮件”相同数量的“非垃圾邮件”示例**因此从训练数据中删除所有额外示例。在这种情况下每种情况有747个示例。
然后,**70%**的数据集用于**训练****10%**用于**验证****20%**用于**测试**。
- **验证集**在训练阶段用于微调模型的**超参数**并做出关于模型架构的决策,有效地通过提供模型在未见数据上的表现反馈来帮助防止过拟合。它允许在不偏倚最终评估的情况下进行迭代改进。
- 这意味着尽管此数据集中包含的数据不直接用于训练,但它用于调整最佳**超参数**,因此该集不能像测试集那样用于评估模型的性能。
- 相比之下,**测试集**仅在模型完全训练并完成所有调整后使用;它提供了对模型在新未见数据上泛化能力的无偏评估。对测试集的最终评估提供了模型在实际应用中预期表现的现实指示。
### 条目长度
由于训练示例期望条目(在这种情况下为电子邮件文本)具有相同的长度,因此决定通过添加`<|endoftext|>`的ID作为填充使每个条目与最长的条目一样大。
### 初始化模型
使用开源的预训练权重初始化模型进行训练。我们之前已经做过这个,并按照[https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb)的说明,您可以轻松做到这一点。
## 分类头
在这个特定示例中预测文本是否为垃圾邮件我们并不关心根据GPT2的完整词汇进行微调而只希望新模型能够判断电子邮件是否为垃圾邮件1或不是0。因此我们将**修改最终层**使其提供每个词汇的概率改为仅提供是否为垃圾邮件的概率就像一个包含2个单词的词汇
```python
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
```
## 参数调整
为了快速微调,最好只微调一些最终参数,而不是所有参数。这是因为已知较低层通常捕捉基本的语言结构和适用的语义。因此,**通常只微调最后几层就足够且更快**。
```python
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
```
## Entries to use for training
在之前的部分LLM通过减少每个预测标记的损失进行训练尽管几乎所有预测的标记都在输入句子中只有最后一个是真正预测的以便模型更好地理解语言。
在这种情况下,我们只关心模型是否能够预测该模型是垃圾邮件,因此我们只关心最后一个预测的标记。因此,需要修改我们之前的训练损失函数,仅考虑该标记。
这在 [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb) 中实现为:
```python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
```
注意,对于每个批次,我们只对**最后一个预测的标记的logits**感兴趣。
## 完整的GPT2微调分类代码
您可以在[https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb)找到所有微调GPT2以成为垃圾邮件分类器的代码。
## 参考文献
- [https://www.manning.com/books/build-a-large-language-model-from-scratch](https://www.manning.com/books/build-a-large-language-model-from-scratch)
{{#include ../../banners/hacktricks-training.md}}