跳转至

微调方案设计文档

创建日期: 2026-04-08 设计人: 灵研 目标: 基于训练数据建立高效的微调方案


执行摘要

本文档为三个下游任务设计微调方案: 1. 意图分类器 - 5分类任务,7,491训练样本 2. 嵌入模型 - 对比学习,2,189训练对+100困难负样本 3. QA评测基准 - 检索增强生成,3,451训练样本

关键约束: - 数据污染严重:60%+样本含\r字符,必须先清洗 - 类别不平衡:多个数据集存在类别不平衡问题 - 算力限制:采用渐进式训练策略,避免一次性资源消耗过大


1. 意图分类器微调方案

1.1 任务定义

输入:用户自然语言查询 输出:5个意图类别之一 - practice_method - 练习方法询问 - scientific_basis - 科学原理询问 - theory_explanation - 理论解释 - book_search - 书籍搜索 - comparison - 比较分析

数据规模: - 训练集:7,491 样本 - 测试集:1,873 样本 - 类别分布不均衡(comparison 类仅 6.1%)

1.2 模型选择

方案A:轻量级模型(推荐)

模型:hfl/chinese-roberta-wwm-ext-tiny - 参数量:约 24M - 训练速度:快 - 显存需求:低(~2GB) - 适合快速迭代

方案B:标准模型(备选)

模型:bert-base-chinese - 参数量:约 110M - 训练速度:中等 - 显存需求:中(~8GB) - 性能可能更好

1.3 数据预处理

步骤1:清洗

import json
import re

def clean_query(text):
    """清洗查询文本"""
    # 移除 \r 字符
    text = text.replace('\r', '')
    # 移除多余空格
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def clean_dataset(input_path, output_path):
    """清洗整个数据集"""
    with open(input_path, 'r', encoding='utf-8') as f_in, \
         open(output_path, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            data = json.loads(line)
            data['query'] = clean_query(data['query'])
            f_out.write(json.dumps(data, ensure_ascii=False) + '\n')

# 执行清洗
clean_dataset('/path/to/train.jsonl', '/path/to/train_clean.jsonl')
clean_dataset('/path/to/test.jsonl', '/path/to/test_clean.jsonl')

步骤2:类别平衡

comparison 类进行数据增强:

from nltk.translate import Translator

def augment_comparison(samples, target_count=1500):
    """数据增强:回译法"""
    translator = Translator()
    augmented = []

    current_count = sum(1 for s in samples if s['intent'] == 'comparison')

    if current_count >= target_count:
        return samples

    # 获取 comparison 类样本
    comparison_samples = [s for s in samples if s['intent'] == 'comparison']
    other_samples = [s for s in samples if s['intent'] != 'comparison']

    # 回译增强
    for sample in comparison_samples:
        if len(augmented) >= target_count - current_count:
            break

        # 中→英→中
        english = translator.translate(sample['query'], src='zh-CN', dest='en').text
        back_translated = translator.translate(english, src='en', dest='zh-CN').text

        augmented.append({
            'query': back_translated,
            'intent': 'comparison'
        })

    # 合并
    return other_samples + comparison_samples + augmented

1.4 训练配置

超参数

参数 说明
batch_size 32 根据显存调整
learning_rate 2e-5 BERT标准学习率
epochs 5-10 早停机制
max_seq_length 128 中文查询长度适中
warmup_steps 500 学习率预热
weight_decay 0.01 正则化

优化器:AdamW 损失函数:CrossEntropyLoss(可加权处理类别不平衡) 评估指标:Accuracy, F1-score (macro)

1.5 训练流程

from transformers import (
    BertTokenizer, BertForSequenceClassification,
    TrainingArguments, Trainer
)
from sklearn.metrics import accuracy_score, f1_score

# 加载模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext-tiny')
model = BertForSequenceClassification.from_pretrained(
    'hfl/chinese-roberta-wwm-ext-tiny',
    num_labels=5
)

# 数据预处理
def preprocess_function(examples):
    return tokenizer(
        examples['query'],
        truncation=True,
        padding='max_length',
        max_length=128
    )

# 训练参数
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    greater_is_better=True
)

# 评估函数
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='macro')

    return {'accuracy': acc, 'f1': f1}

# 训练器
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics
)

# 训练
trainer.train()

# 保存
trainer.save_model('./intent_classifier_model')

1.6 评估标准

目标指标: - 测试集准确率:> 85% - 测试集 F1 (macro):> 0.80 - comparison 类召回率:> 0.75

预期挑战: - comparison 类样本少,可能表现不佳 - 需要关注类别间的混淆矩阵 - 可能需要调整类别权重


2. 嵌入模型微调方案

2.1 任务定义

输入:文本对(anchor, positive/negative) 输出:文本嵌入向量 目标:相似文本在嵌入空间中距离近,不相似文本距离远

数据规模: - 训练正样本对:2,189 - 验证正样本对:244 - 困难负样本:100(不足,需要扩充)

数据类型: - title_content - 标题+正文段落(主要类型) - cross_domain_hard_negative - 跨域困难负样本

2.2 模型选择

方案A:Sentence-BERT(推荐)

模型:sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 - 参数量:约 118M - 输出维度:384 - 多语言支持 - 训练速度快

方案B:中文专用模型(备选)

模型:sentence-transformers/paraphrase-multilingual-mpnet-base-v2 - 参数量:约 278M - 输出维度:768 - 性能更好,但训练较慢

2.3 数据预处理

步骤1:清洗

def clean_text(text):
    """清洗文本"""
    # 移除 \r 字符
    text = text.replace('\r', '')
    # 移除多余空白
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def clean_embedding_pairs(input_path, output_path):
    """清洗嵌入对数据"""
    with open(input_path, 'r', encoding='utf-8') as f_in, \
         open(output_path, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            data = json.loads(line)
            data['anchor'] = clean_text(data['anchor'])
            data['positive'] = clean_text(data['positive'])
            if 'negative' in data:
                data['negative'] = clean_text(data['negative'])
            f_out.write(json.dumps(data, ensure_ascii=False) + '\n')

步骤2:困难负样本扩充

def generate_hard_negatives(train_pairs, target_count=500):
    """生成困难负样本"""

    # 策略1:基于关键词相似度
    def keyword_similarity(text1, text2):
        words1 = set(jieba.cut(text1))
        words2 = set(jieba.cut(text2))
        return len(words1 & words2) / len(words1 | words2)

    hard_negatives = []

    for pair in train_pairs:
        anchor = pair['anchor']
        positive = pair['positive']

        # 在所有样本中找与 anchor 关键词相似但语义不同的
        for other_pair in train_pairs:
            if other_pair == pair:
                continue

            candidate = other_pair['positive']
            sim = keyword_similarity(anchor, candidate)

            # 相似度高但类别不同
            if sim > 0.3 and sim < 0.7:
                hard_negatives.append({
                    'anchor': anchor,
                    'negative': candidate,
                    'category_anchor': pair.get('category_anchor', ''),
                    'category_negative': other_pair.get('category', ''),
                    'pair_type': 'semantic_hard_negative'
                })

                if len(hard_negatives) >= target_count:
                    break

        if len(hard_negatives) >= target_count:
            break

    return hard_negatives

步骤3:负采样策略

def build_training_examples(train_pairs, hard_negatives, easy_neg_ratio=1.0):
    """构建训练样本"""

    examples = []

    # 正样本
    for pair in train_pairs:
        examples.append({
            'anchor': pair['anchor'],
            'positive': pair['positive'],
            'negative': None,  # Triplet loss会在训练时采样
            'label': 1
        })

    # 困难负样本
    for neg in hard_negatives:
        examples.append({
            'anchor': neg['anchor'],
            'positive': neg['negative'],
            'negative': None,
            'label': 0
        })

    # 简单负样本(随机采样)
    import random
    easy_count = int(len(hard_negatives) * easy_neg_ratio)

    for _ in range(easy_count):
        anchor_pair = random.choice(train_pairs)
        positive_pair = random.choice(train_pairs)

        examples.append({
            'anchor': anchor_pair['anchor'],
            'positive': positive_pair['positive'],
            'negative': None,
            'label': 0
        })

    return examples

2.4 训练配置

损失函数:MultipleNegativesRankingLoss + InfoNCE

超参数

参数 说明
batch_size 16 每个batch包含多个正负对
learning_rate 2e-5 稳定学习率
epochs 3-5 过拟合风险高
max_seq_length 256 文本较长
warmup_steps 200 预热
weight_decay 0.01 正则化

2.5 训练流程

from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# 加载模型
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')

# 构建训练数据
train_examples = []
for pair in train_pairs:
    train_examples.append(InputExample(
        texts=[pair['anchor'], pair['positive']],
        label=1.0
    ))

for neg in hard_negatives:
    train_examples.append(InputExample(
        texts=[neg['anchor'], neg['negative']],
        label=0.0
    ))

# DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# 损失函数
train_loss = losses.MultipleNegativesRankingLoss(model=model)

# 训练
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=200,
    output_path='./embedding_model',
    show_progress_bar=True
)

2.6 评估标准

检索任务评估: - Recall@1:第一检索结果正确的比例 - Recall@5:前5个结果中包含正确答案的比例 - Recall@10:前10个结果中包含正确答案的比例 - MRR:平均倒数排名

目标指标: - Recall@1:> 0.50 - Recall@5:> 0.75 - MRR:> 0.60


3. QA评测基准微调方案

3.1 任务定义

输入:查询 输出:答案 数据:3,451训练样本,863测试样本 约束:古籍占比82.6%,严重不平衡

3.2 架构设计

方案:检索增强生成(RAG)

查询 → 检索器 → 相关文档 → 生成器 → 答案
       ↑                 ↑
   嵌入模型(微调)   生成模型(微调)

组件1:检索器 - 使用2.2节微调的嵌入模型 - 向量数据库:FAISS 或 Milvus - Top-K检索:K=10

组件2:生成器

模型:Qwen/Qwen2-7B-Instruct 或更小版本

3.3 数据预处理

步骤1:清洗

def clean_qa_sample(sample):
    """清洗QA样本"""
    sample['query'] = sample['query'].replace('\r', '').strip()
    sample['answer'] = sample['answer'].replace('\r', '').strip()
    return sample

步骤2:类别平衡

教材 类进行扩充: - 当前:49训练样本,13测试样本 - 目标:至少500训练样本,100测试样本 - 方法:回译法 + 合成数据

3.4 检索器配置

from sentence_transformers import SentenceTransformer
import faiss

# 加载微调后的嵌入模型
retriever = SentenceTransformer('./embedding_model')

# 构建文档库
documents = []
for doc in all_documents:
    embedding = retriever.encode(doc['content'])
    documents.append({
        'id': doc['doc_id'],
        'content': doc['content'],
        'embedding': embedding
    })

# FAISS索引
dimension = 384
index = faiss.IndexFlatL2(dimension)
index.add(np.array([d['embedding'] for d in documents]))

# 检索函数
def retrieve(query, k=10):
    query_embedding = retriever.encode([query])
    distances, indices = index.search(query_embedding, k)

    results = [documents[i] for i in indices[0]]
    return results

3.5 生成器微调

训练数据格式

def format_for_training(query, retrieved_docs, answer):
    """格式化为RAG训练数据"""

    context = '\n\n'.join([
        f"文档{i+1}{doc['content']}"
        for i, doc in enumerate(retrieved_docs[:5])
    ])

    prompt = f"""根据以下文档回答问题:

{context}

问题:{query}

答案:"""

    return {
        'prompt': prompt,
        'completion': answer
    }

训练配置

参数 说明
batch_size 4 大模型,小batch
learning_rate 1e-5 小学习率
epochs 2-3 避免过拟合
max_length 512 输入长度
generation_max_length 128 输出长度

微调流程

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer

# 加载模型
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-7B-Instruct')
model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2-7B-Instruct')

# 数据预处理
train_dataset = [... ]  # 格式化后的数据

# 训练参数
training_args = TrainingArguments(
    output_dir='./qa_model',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
    fp16=True,  # 混合精度
)

# 训练器
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# 训练
trainer.train()

# 保存
trainer.save_model('./qa_model')

3.6 评估标准

检索评估: - Recall@K(K=1, 5, 10) - MRR

生成评估: - Exact Match(EM):答案完全匹配 - F1 Score:词级别F1 - BLEU:n-gram相似度

目标指标: - Recall@5:> 0.80 - EM:> 0.40 - F1:> 0.60


4. 渐进式训练策略

4.1 阶段划分

阶段1:数据清洗(1天) - 清洗所有数据集(移除\r字符) - 验证清洗结果

阶段2:类别平衡(2天) - 扩充 comparison 类 - 扩充 教材 类 - 扩充困难负样本

阶段3:意图分类器训练(2天) - 训练基线模型 - 调试超参数 - 评估测试集

阶段4:嵌入模型训练(3天) - 训练基线模型 - 生成困难负样本 - 评估检索性能

阶段5:QA系统训练(5天) - 部署检索器 - 训练生成器 - 端到端评估

阶段6:集成优化(2天) - 联合调优 - 性能测试 - 文档完善

4.2 资源规划

计算资源: - GPU:1x A100 (40GB) 或 2x V100 (32GB) - CPU:16核以上 - 内存:64GB以上 - 存储:500GB以上

时间规划: - 总时长:15个工作日 - 缓冲时间:3个工作日 - 实际预留:18个工作日

4.3 风险控制

风险1:数据清洗不彻底 - 应对:增加自动化验证脚本 - 影响:中等

风险2:类别平衡效果不佳 - 应对:尝试多种增强策略(回译、合成、标注) - 影响:高

风险3:模型过拟合 - 应对:早停、dropout、数据增强 - 影响:高

风险4:算力不足 - 应对:使用云服务、降低模型规模 - 影响:中


5. 监控与迭代

5.1 训练监控

指标监控: - Loss曲线 - 验证集性能 - 各类别性能分布

工具: - TensorBoard - Weights & Biases - MLflow

5.2 迭代策略

迭代条件: - 测试集F1 < 0.75(意图分类) - Recall@5 < 0.70(检索) - F1 < 0.55(QA)

迭代方向: 1. 调整学习率 2. 增加数据增强 3. 调整模型架构 4. 调整损失函数权重


6. 验收标准

6.1 意图分类器

  • [ ] 测试集准确率 > 85%
  • [ ] 测试集 F1 (macro) > 0.80
  • [ ] comparison 类召回率 > 0.75
  • [ ] 模型推理时间 < 50ms/sample

6.2 嵌入模型

  • [ ] Recall@1 > 0.50
  • [ ] Recall@5 > 0.75
  • [ ] MRR > 0.60
  • [ ] 嵌入维度 = 384
  • [ ] 推理时间 < 100ms/sample

6.3 QA系统

  • [ ] Recall@5 > 0.80
  • [ ] EM > 0.40
  • [ ] F1 > 0.60
  • [ ] 端到端响应时间 < 2s

7. 交付物

  1. 模型文件
  2. intent_classifier_model/
  3. embedding_model/
  4. qa_model/

  5. 数据文件

  6. train_clean.jsonl
  7. test_clean.jsonl
  8. augmented_train.jsonl

  9. 代码文件

  10. train_intent_classifier.py
  11. train_embedding_model.py
  12. train_qa_system.py
  13. inference.py

  14. 文档文件

  15. MODEL_REPORT.md - 各模型详细报告
  16. EVALUATION_REPORT.md - 评估报告
  17. API_DOCUMENTATION.md - API文档

  18. 演示文件

  19. demo.ipynb - Jupyter演示
  20. README.md - 使用说明

文档结束