llama-factory训练RLHF-PPO模型

理论上RLHF(强化学习)效果比sft好,也更难训练。ppo有采用阶段,步骤比较多,训练速度很慢.
记录下工作中使用llama-factory调试rlhf-ppo算法流程及参数配置,希望对大家有所帮助.

llama-factory版本: 0.8.2

一 rlhf流程

ppo训练流程图如下, 会用到多个模型, 但初始化阶段, 只需提供sft和reward模型就行.
在这里插入图片描述

四个子模型用途:

  • Actor Model:演员模型,这就是我们想要训练的目标语言模型
  • Reference Model:参考模型,它的作用是在RLHF阶段给语言模型增加一些“约束”,防止语言模型训歪。我们希望训练出来的Actor模型既能达到符合人类喜好的目的,又尽量让它和SFT模型不要差异太大。即希望两个模型的输出分布尽量相似,通过与Actor Model之间的KL散度控制。
  • Critic Model:评论家模型,它的作用是预估总收益V->(t),在RLHF中,我们不仅要训练模型生成符合人类喜好的内容的能力(Actor),也要提升模型对人类喜好量化判断的能力(Critic)。这就是Critic模型存在的意义。
  • Reward Model:奖励模型,它的作用是计算即时收益R->(t) Actor/Critic Model. 在RLHF阶段是需要训练的;而Reward/Reference Model是参数冻结的。

整体算法流程如下:

  1. 训练sft模型

  2. 训练reward奖励模型

  3. 以sft模型初始化Reference和Actor模型,以奖励模型初始化Critic模型。其中,Actor与Critic模型权重可训练,Reference与Reward冻结权重,全程不更新

  4. rlhf-ppo执行过程分析(对应上图的step 3):
    在这里插入图片描述

  • 第一步,我们准备一个batch的prompts

  • 第二步,我们将这个batch的prompts喂给Actor模型,让它生成对应的responses

  • 第三步,我们把prompt+responses喂给我们的Critic/Reward/Reference模型,让它生成用于计算actor/critic loss的数据,按照强化学习的术语,我们称这些数据为经验(experiences)。

  • 第四步,我们根据这些经验,实际计算出actor/critic loss,然后更新Actor和Critic模型。

    涉及的损失函数:
    在这里插入图片描述

至此, 我们对RLHF-PPO工作原理已经有了清晰的认知. 若觉得上述文字不过瘾, 可以看我列出的几篇参考文献, 对ppo复杂的原理有深刻解读.

二 代码实践

2.1 数据准备: 需要准备sft指令微调和reward奖励模型的数据.

  • sft数据格式
[
    {
        "instruction": "<question>:查看备案有效期在今天之后的委托信息\nCREATE TABLE 委托备案信息 (序号 FLOAT,\n委托企业名称 VARCHAR(255),\n公司地址 VARCHAR(255),\n被委托企业名称 VARCHAR(255),\n委托产品名称 VARCHAR(255),\n备案日期 VARCHAR(255),\n备案有效期 VARCHAR(255));",
        "input": "",
        "output": "\nSELECT * FROM 委托备案信息 WHERE TO_DATE(备案有效期, 'YYYY-MM-DD') > NOW();\n"
    },
    {
        "instruction": "<question>:哪些镇名拥有重点旅游村?\nCREATE TABLE 镇名休闲农业园区休闲农庄重点旅游村 (镇名 VARCHAR(255),\n休闲农业园区 VARCHAR(255),\n休闲农庄 VARCHAR(255),\n重点旅游村 VARCHAR(255));",
        "input": "",
        "output": "\nSELECT DISTINCT 镇名 FROM 镇名休闲农业园区休闲农庄重点旅游村 WHERE 重点旅游村 IS NOT NULL;\n"
    },
    ...
]
  • reward数据格式固定,不能随意更改, 经过断点调试发现, 所有模型的reward数据都遵循以下格式, 其中chosen期望偏好, rejected是负向偏好.
[
    {
        "conversations": [
            {
                "from": "human",
                "value": "<question>:查看备案有效期在今天之后的委托信息\nCREATE TABLE 委托备案信息 (序号 FLOAT,\n委托企业名称 VARCHAR(255),\n公司地址 VARCHAR(255),\n被委托企业名称 VARCHAR(255),\n委托产品名称 VARCHAR(255),\n备案日期 VARCHAR(255),\n备案有效期 VARCHAR(255));"
            }
        ],
        "chosen": {
            "from": "gpt",
            "value": "\nSELECT * FROM 委托备案信息 WHERE TO_DATE(备案有效期, 'YYYY-MM-DD') > NOW();\n"
        },
        "rejected": {
            "from": "gpt",
            "value": "SELECT * FROM 委托备案信息 WHERE 备案有效期 > NOW()"
        }
    },
    {
        "conversations": [
            {
                "from": "human",
                "value": "<question>:哪些镇名拥有重点旅游村?\nCREATE TABLE 镇名休闲农业园区休闲农庄重点旅游村 (镇名 VARCHAR(255),\n休闲农业园区 VARCHAR(255),\n休闲农庄 VARCHAR(255),\n重点旅游村 VARCHAR(255));"
            }
        ],
        "chosen": {
            "from": "gpt",
            "value": "\nSELECT DISTINCT 镇名 FROM 镇名休闲农业园区休闲农庄重点旅游村 WHERE 重点旅游村 IS NOT NULL;\n"
        },
        "rejected": {
            "from": "gpt",
            "value": "SELECT DISTINCT 镇名 FROM PG库 WHERE 重点旅游村 IS NOT NULL;"
        }
    },
    ...
]

2.2 训练代码

新版llama-factory不再使用shell脚本传参, 而是通过yaml文件完成, 之后通过以下代码, 根据传入yaml文件不同执行对应的训练任务.

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.llamafactory.train.tuner import run_exp
import yaml


def main(yaml_path_):
    with open(yaml_path_, 'r', encoding='utf-8') as f:
        param = yaml.safe_load(f)
    run_exp(param)


if __name__ == "__main__":
    #1.sft指令微调
    # yaml_path = '../examples/yblir_configs/qwen2_lora_sft.yaml'
    # 2.奖励模型训练
    # yaml_path = '../examples/yblir_configs/qwen2_lora_reward.yaml'
    # 3.rlhf-ppo训练
    yaml_path = '../examples/yblir_configs/qwen2_lora_ppo.yaml'
	
    main(yaml_path)

sft 超参: qwen2_lora_sft.yaml

# model
model_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b
#model_name_or_path: /media/xk/D6B8A862B8A8433B/data/qwen2_05b
# method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

# dataset
dataset: train_clean
dataset_dir: ../data
template: qwen
cutoff_len: 1024
#max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 2

# output
output_dir: E:\PyCharm\PreTrainModel\qwen2_7b_sft
logging_steps: 10
save_steps: 100
plot_loss: true
overwrite_output_dir: true

# train
per_device_train_batch_size: 4
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1
fp16: true

# eval
val_size: 0.1
per_device_eval_batch_size: 4
evaluation_strategy: steps
eval_steps: 100

sft训练效果:
在这里插入图片描述

rm模型训练参数: qwen2_lora_reward.yaml

# 训练奖励模型
### model
model_name_or_path: /mnt/e/PyCharm/PreTrainModel/qwen2_7b

### method
stage: rm
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: rw_data
dataset_dir: ../data
template: qwen
cutoff_len: 1024
max_samples: 3000
overwrite_cache: true
preprocessing_num_workers: 1

### output
output_dir: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_rm
logging_steps: 10
save_steps: 100
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 2
eval_strategy: steps
eval_steps: 500

rm训练效果:

***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =        1.0
  eval_loss               =        0.0
  eval_runtime            = 0:00:16.73
  eval_samples_per_second =     17.923
  eval_steps_per_second   =      8.961
[INFO|modelcard.py:450] 2024-06-26 23:02:36,246 >> Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}, 'metrics': [{'name': 'Accuracy', 'type': 'accuracy', 'value': 1.0}]}

在这里插入图片描述

sft训练完成后,要先merge才能进行下一步ppo训练.
merge代码及配置文件:

# -*- coding: utf-8 -*-
# @Time    : 2024/5/17 23:21
# @Author  : yblir
# @File    : lyb_merge_model.py
# explain  :
# =======================================================
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import yaml

from src.llamafactory.train.tuner import export_model

if __name__ == "__main__":
    with open('../examples/yblir_configs/qwen2_lora_sft_merge.yaml', 'r', encoding='utf-8') as f:
        param = yaml.safe_load(f)

    export_model(param)

qwen2_lora_sft_merge.yaml

# Note: DO NOT use quantized model or quantization_bit when merging lora adapters

# model
model_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b
adapter_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b_sft
#model_name_or_path: /media/xk/D6B8A862B8A8433B/data/qwen2_05b
#adapter_name_or_path: /media/xk/D6B8A862B8A8433B/data/qwen2_15b_rw
template: qwen
finetuning_type: lora

# export
export_dir: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_sft_merge
export_size: 2
export_device: cpu
# 为true,保存为safetensors格式
export_legacy_format: true

ppo训练: 使用merge后的sft模型. reward_model参数是rm训练的lora参数, 这样做的好处是节约显存, 不然24G显存根本没法训练7B大小的模型. 而弊端就是, 四个子模型的基座是同一个模型. 只有全量的full训练才能选择不同的模型. 目前看, 都用同一个模型也没发现什么问题.

ppo涉及数据采样, 训练很慢, 4090显卡, 对于以下参数, 显存占用约18G, 耗时约4.5小时才训练完.

### model
model_name_or_path: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_sft_merge
reward_model: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_rm

### method
stage: ppo
do_train: true
finetuning_type: lora
lora_target: all

### dataset
# dataset: identity,alpaca_en_demo
dataset: train_clean
dataset_dir: ../data
template: qwen
cutoff_len: 1024
max_samples: 2000
overwrite_cache: true
preprocessing_num_workers: 1

### output
output_dir: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_sql_ppo_1_batch
logging_steps: 10
save_steps: 100
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
ddp_timeout: 180000000

### generate
max_new_tokens: 512
top_k: 0
top_p: 0.9

ppo训练效果
在这里插入图片描述

ppo训练后进行推理, 使用merge后的sft模型进行的ppo的推理的基座模型, ppo训练的finetuning_type是lora, 因此最终保存的也是lora参数,

lyb_qwen_sft_predict.yaml

# model
model_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b_sft_merge
adapter_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b_sql_ppo_1_batch

stage: sft
finetuning_type: lora
#lora_target: all
#quantization_bit: 8

#infer_backend: vllm

# dataset
template: qwen
#cutoff_len: 1024

一个简单的推理代码, 注意模型的输入数据, 与ppo训练时入参格式一样, 本文ppo训练使用的数据与sft是同一份.

# -*- coding: utf-8 -*-
# @Time    : 2024/6/16 20:50
# @Author  : yblir
# @File    : lyb_lora_inference.py
# explain  : 
# =======================================================
import yaml
import json
from loguru import logger
import time
import sys
from src.llamafactory.chat import ChatModel

if __name__ == '__main__':
    with open('../examples/yblir_configs/lyb_qwen_sft_predict.yaml', 'r', encoding='utf-8') as f:
        param = yaml.safe_load(f)

    chat_model = ChatModel(param)

    with open('../data/tuning_sample.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    # 预热
    messages = [{"role": "user", "content": data[0]['instruction']}]
    _ = chat_model.chat(messages)

    predict_1000 = []
    total_time = 0
    for i, item in enumerate(data):
        messages = [{"role": "user", "content": item['instruction']}]
        t1 = time.time()
        res = chat_model.chat(messages)
        total_time += time.time() - t1
        predict_1000.append(res[0].response_text)
        #print('-------------------------------------------------')
        print(i,'->',res[0].response_text)
        # sys.exit()
        if (i + 1) % 10 == 0:
            # logger.info(f'当前完成: {i + 1}')
            sys.exit()
        if i + 1 == 300:
            break

    # json_data = json.dumps(predict_1000, indent=4, ensure_ascii=False)
    # with open('saves2/qwen_7b_chat_lora_merge_vllm.json', 'w', encoding='utf-8') as f:
    #     f.write(json_data)

    logger.success(f'写入完成, 总耗时:{total_time},平均耗时: {round((total_time / 300), 5)} s')

sft与PPO部分推理结果比较, 具体指标要把sql放到数据库去跑一遍才知道, 结果在公司内网, 不再此列出了.
在这里插入图片描述

三 总结

除了ppo, dpo(Direct Preference Optimization:直接偏好优化)也是一种常见的调优手段, 不过多篇paper研究证明性能不如PPO, 在计算资源不足的情况下DPO也是个不过的选择,因为不需要训练奖励模型, 而且训练速度快,效果也比较稳定, 不像PPO那样很容易训崩.
其他LLM偏好对齐训练技术还有ORPO,IPO,CPO以及效果看起来很棒的KTO.
还有最新发表的RLOO,看起来比PPO更好更易训练.
在这里插入图片描述

这个领域发展太快, 脑子快不够用了.
在这里插入图片描述

四 参考文献

https://blog.csdn.net/sinat_37574187/article/details/138200789
https://blog.csdn.net/2301_78285120/article/details/134888984
https://blog.csdn.net/qq_27590277/article/details/132614226
https://blog.csdn.net/qq_35812205/article/details/133563158

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/766165.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】—Xshell、Xftp安装

文章目录 前言一、下载Xshell、Xftp二、安装Xshell三、使用XShell连接Linux服务器四、修改windows的主机映射文件&#xff08;hosts文件&#xff09;五、远程连接hadoop102/hadoop103/hadoop104服务器六、安装Xftp 前言 XShell远程管理工具&#xff0c;可以在Windows界面下来访…

Springboot整合RedisTemplate以及业务工具类示例

docker安装Redis参考我另一篇博客Docker安装Redis及持久化 一、Get-Started 依赖 <!-- https://mvnrepository.com/artifact/org.springframework.boot/spring-boot-starter-data-redis --> <dependency><groupId>org.springframework.boot</groupId>…

Java_多线程:线程池

1、线程池优点&#xff1a; 降低资源消耗&#xff1a;通过重复利用已创建的线程降低线程创建和销毁造成的消耗。提高响应速度&#xff1a;当任务到达时&#xff0c;任务可以不需要等到线程创建就能立即执行。提高线程的可管理性&#xff1a;线程是稀缺资源&#xff0c;如果无限…

Django 多对多关系

多对多关系作用 Django 中&#xff0c;多对多关系模型的作用主要是为了表示两个模型之间的多对多关系。具体来说&#xff0c;多对多关系允许一个模型的实例与另一个模型的多个实例相关联&#xff0c;反之亦然。这在很多实际应用场景中非常有用&#xff0c;比如&#xff1a; 博…

因版本冲突导致logback的debug日志不打印

因框架调整&#xff0c;降级了logback的版本号&#xff0c;由1.3.12降级为1.2.11&#xff08;因框架限制&#xff0c;只能采用1.2版本&#xff09;&#xff0c;降级后发现debug日志无法打印出来&#xff0c;logback.xml配置文件不生效。后排查发现是与slf4j的版本兼容问题 依赖…

搜维尔科技:数据手套为什么要选择SenseGlove

了解 SenseGlove SenseGlove 是一支由电子工程师、触觉研究人员和计算机视觉专家、XR 开发人员、UX 设计师和产品创新者组成的科幻爱好者团队&#xff0c;他们拥有丰富人类能力和赋予 Metaverse 意义的技能和热情。 推进触觉技术是我们实现这一目标的方式。 公司及产品背景 S…

基于Hadoop平台的电信客服数据的处理与分析③项目开发:搭建Kafka大数据运算环境---任务12:安装Kafka

任务描述 任务内容为安装和配置Kafka集群。 任务指导 Kafka是大数据生态圈中常用的消息队列框架 具体安装步骤如下&#xff1a; 1. 解压缩Kafka的压缩包 2. 配置Kafka的环境变量 3. 修改Kafka的配置文件&#xff0c;Kafka的配置文件存放在Kafka安装目录下的config中 4. 验证…

【融合ChatGPT等AI模型】Python-GEE遥感云大数据分析、管理与可视化及多领域案例应用

随着航空、航天、近地空间遥感平台的持续发展&#xff0c;遥感技术近年来取得显著进步。遥感数据的空间、时间、光谱分辨率及数据量均大幅提升&#xff0c;呈现出大数据特征。这为相关研究带来了新机遇&#xff0c;但同时也带来巨大挑战。传统的工作站和服务器已无法满足大区域…

JDK动态代理-AOP编程

AOPTest.java&#xff0c;相当于main函数&#xff0c;经过代理工厂出来的Hello类对象就不一样了&#xff0c;这是Proxy.newProxyInstance返回的对象&#xff0c;会hello.addUser会替换为invoke函数&#xff0c;比如这里的hello.addUser("sun", "13434");会…

【驱动篇】龙芯LS2K0300之红外驱动

实验目标 编写HX1838红外接收器驱动&#xff0c;根据接收的波形脉冲解码红外按键键值 模块连接 模块连接&#xff1a;VCC接Pin 2&#xff0c;GND接Pin1&#xff0c;DATA接Pin16 驱动代码 HX1838 GPIO初始化&#xff0c;申请中断&#xff0c;注意&#xff1a;GPIO48默认是给…

vscode语言模式

1.背景 写vue3ts项目的时候&#xff0c;用到了volar插件&#xff0c;在单文件使用的时候&#xff0c;鼠标悬浮在代码上面会有智能提示&#xff1b; 但是最近volar插件提示被弃用了&#xff0c;然后我按照它的官方提示&#xff0c;安装了Vue-official扩展插件&#xff0c;但是…

Vue3 特点以及优势-源码解剖

Vue3 特点以及优势-Vue3.4源码解剖 Vue3 特点以及优势 1.声明式框架 命令式和声明式区别 早在 JQ 的时代编写的代码都是命令式的&#xff0c;命令式框架重要特点就是关注过程声明式框架更加关注结果。命令式的代码封装到了 Vuejs 中&#xff0c;过程靠 vuejs 来实现 声明式代…

剑神诀_单机架设_无需虚拟机_小白专用

前言 今天给大家带来一款单机游戏的架设&#xff1a;剑神诀&#xff0c;一键端 无需虚拟机 如今市面上的资源参差不齐&#xff0c;大部分的都不能运行&#xff0c;本人亲自测试&#xff0c;运行视频如下&#xff1a; 剑神诀 搭建教程 此游戏架设不需要安装虚拟机&#xff0c;…

爬虫cookie是什么意思

“爬虫 cookie”指的是网络爬虫在访问网站时所使用的cookie&#xff0c;网络爬虫是一种自动化程序&#xff0c;用于在互联网上收集信息并进行索引&#xff0c;这些信息可以用于搜索引擎、数据分析或其他目的。 本教程操作系统&#xff1a;Windows10系统、Dell G3电脑。 “爬虫…

SpringBoot 项目整合 MyBatisPlus 框架,附带测试示例

文章目录 一、创建 SpringBoot 项目二、添加 MyBatisPlus 依赖三、项目结构和数据库表结构四、项目代码1、application.yml2、TestController3、TbUser4、TbUserMapper5、TestServiceImpl6、TestService7、TestApplication8、TbUserMapper.xml9、MyBatisPlusTest 五、浏览器测试…

新鲜出炉!恭喜这 5 位同学中选 NebulaGraph 社区 2024 开源之夏项目!

开源之夏是中国科学院软件研究所发起的“开源软件供应链点亮计划”系列暑期活动&#xff0c;旨在鼓励高校学生积极参与开源软件的开发维护&#xff0c;促进优秀开源软件社区的蓬勃发展。活动联合各大开源社区&#xff0c;针对重要开源软件的开发与维护提供项目开发任务&#xf…

stm32学习笔记---USART串口外设(理论部分)

目录 USART简介 USART的框图 串口的引脚 USART的基本结构 数据帧 起始位侦测 数据采样 波特率发生器 USD转串口模块的原理图 声明&#xff1a;本专栏是本人跟着B站江科大的视频的学习过程中记录下来的笔记&#xff0c;我之所以记录下来是为了方便自己日后复习。如果你…

个人微信二次开发

​ 由于自身在机器人方面滚爬多年&#xff0c;现在收藏几个宝藏机器人 推荐一下自己常用的机器人&#xff1a; 适合有技术开发的公司&#xff0c;可以自主开发所需要的功能&#xff01;十分齐全 测试问文档&#xff1a;https://www.wkteam.cn/ 有需要的兄弟可以看一下&#…

手写一个基于SpringBoot的MVC架构,默认实现CRUD和导入导出功能

文章目录 前言正文一、项目结构二、技术点三、部分核心代码3.1 core-tool 中的核心代码3.1.1 所有实体的通用父类 SuperEntity3.1.2 所有枚举的父接口 BaseEnum3.1.3 所有业务异常的父接口 BaseException 3.2 mvc-tool 中的核心代码3.2.1 CrudController 接口定义3.2.2 默认的C…

手写一个类似@RequestParam的注解(用来接收请求体的参数)

一、本文解决的痛点 按照大众认为的开发规范&#xff0c;一般post类型的请求参数应该传在请求body里面。但是我们有些post接口只需要传入一个字段&#xff0c;我们接受这种参数就得像下面这样单独创建一个类&#xff0c;类中再添加要传入的基本类型字段&#xff0c;配合Reques…
最新文章