Skip to content

Commit

Permalink
Merge pull request baidubce#62 from baidubce/data-cookbook
Browse files Browse the repository at this point in the history
update SFT cookbook with dataset operations
  • Loading branch information
stonekim authored Nov 9, 2023
2 parents 973ae85 + c5d4737 commit 3ad55a0
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 31 deletions.
242 changes: 214 additions & 28 deletions cookbook/console-finetune/console-finetune.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"source": [
"### 前言\n",
"\n",
"本篇主要介绍end-to-end的LLMops流程中的SFT微调->发布->推理流程,使用的SDK版本为0.1.0。建议提前熟悉预测服务相关SDK功能作为前置知识。"
"本篇主要介绍end-to-end的LLMops流程中的数据->SFT微调->发布->推理流程,使用的SDK版本为0.1.3。建议提前熟悉预测服务相关SDK功能作为前置知识。"
]
},
{
Expand All @@ -16,64 +16,250 @@
"metadata": {},
"outputs": [],
"source": [
"# 通过环境变量传递(作用于全局,优先级最低)\n",
"import os\n",
"os.environ[\"QIANFAN_ACCESS_KEY\"] = \"your_iam_ak\"\n",
"os.environ[\"QIANFAN_SECRET_KEY\"] = \"your_iam_sk\"\n",
"# 初始化百度智能云的IAM ak, sk用于bos和千帆平台的鉴权\n",
"bce_ak = \"your_iam_ak\"\n",
"bce_sk = \"your_iam_sk\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据上传\n",
"\n",
"# 通过内置函数传递(作用于全局,优先级大于环境变量)\n",
"# import qianfan\n",
"# qianfan.AccessKey(\"...\")\n",
"# qianfan.SecretKey(\"...\")\n",
"在进行SFT微调训练前,我们需要准备我们的训练数据;不同的训练任务需要准备不同类型的数据集,具体来说,对于LLM SFT训练任务,需要准备的是`已标注的、非排序的对话数据集`\n",
"推荐使用的数据格式为`jsonl`,即每一行文本都包含了一个json字符串,此json需要包含prompt,response两个字段,以下是一个示例,[下载](https://console.bce.baidu.com/api/qianfan/canghai/entity/static/sample-text-dialog-unsort-annotated.jsonl):\n",
"```\n",
"[{\"prompt\" : \"你好\", \"response\": [[\"你需要什么帮助\"]]}]\n",
"```\n",
"每一行表示一组数据,每组数据中的prompt和response加起来之和字符数不超过8000Token(包括中英文、数字、符号等),超出部分将被截断。\n",
"\n",
"# 调用相关接口时传递(仅作用于该请求,优先级最高)\n",
"# import qianfan\n",
"# task = qianfan.FineTune.create_task(ak=\"...\", sk=\"...\")"
"### Bos\n",
"\n",
"Bos是百度智能云提供的对象存储云服务,可以高效的存取数据。本篇教程基于Bos,实现本地的数据集到千帆平台数据集的导入:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 首先我们需要安装bce-python-sdk\n",
"!pip install bce-python-sdk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{metadata:{date:u'Thu, 09 Nov 2023 10:50:57 GMT',content_length:u'0',connection:u'keep-alive',content_md5:u'kbo1u82WYdCFGVLAbeqXbQ==',etag:u'91ba35bbcd9661d0851952c06dea976d',server:u'BceBos',bce_content_crc_32:u'86170999',bce_debug_id:u'JUrX2nUmpvcbaRPRMsY+uS3KUFDB1YjYIbZ9aaJtEgw16FpXFpCwVQG7+iVDt2rD4dVWAh+SmNZzCEUXGOXHiQ==',bce_flow_control_type:u'-1',bce_is_transition:u'false',bce_request_id:u'b65583f2-c7fb-4fa6-ad52-c07569270120'}}"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from baidubce.bce_client_configuration import BceClientConfiguration\n",
"from baidubce.auth.bce_credentials import BceCredentials\n",
"from baidubce.services.bos.bos_client import BosClient\n",
"\n",
"# 初始化bos配置\n",
"BosEndpoint = \"bj.bcebos.com\"\n",
"bucket_name = \"your_bucketname\"\n",
"\n",
"bos_config = BceClientConfiguration(credentials=BceCredentials(bce_ak, bce_sk), endpoint=BosEndpoint)\n",
"\n",
"file_name = \"./data/sample-text-dialog-unsort-annotated.jsonl\"\n",
"key = \"/dataset/dialog01/sample-text-dialog-unsort-annotated.jsonl\"\n",
"prefix = \"/dataset/dialog01/\"\n",
"\n",
"bos_client = BosClient(bos_config)\n",
"bos_client.put_object_from_file(bucket_name, key, file_name)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 大模型调优\n",
"千帆平台支持SFT/RLHF两种方法进行模型优化,当前SDK已支持对SFT训练微调任务的创建和管理。\n",
"SFT 相关操作使用“安全认证/Access Key”中的 Access Key ID 和 Secret Access Key 进行鉴权,无法使用获取Access Token的方式鉴权,相关 key 可以在百度智能云控制台中安全认证获取,详细流程可以参见文档。\n",
"鉴权方式除`命名`外,使用方法与预测功能使用的AK 与 SK 方式相同,提供如下三种方式:\n",
"## 大模型平台鉴权介绍:\n",
"\n",
"- 通过`环境变量`传递(作用于全局,优先级最低)\n",
"- 通过`内置函数`传递(作用于全局,优先级大于环境变量)\n",
"- 通过`调用接口`时传递(仅作用于该请求,优先级最高)"
"大模型平台和Bos同处于百度智能云下,所以可以使用同一个AK,SK来通过权限校验:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"QIANFAN_ACCESS_KEY\"] = bce_ak\n",
"os.environ[\"QIANFAN_SECRET_KEY\"] = bce_sk"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据导入\n",
"\n",
"在完成了以上从本地到bos的上传过程后,我们就开始着手创建数据集并导入之前上传到bos的数据"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.1.0'"
"QfResponse(code=200, headers={'Content-Length': '1110', 'Content-Type': 'application/json; charset=utf-8', 'Date': 'Thu, 09 Nov 2023 08:41:06 GMT', 'X-Bce-Gateway-Region': 'BJ', 'X-Bce-Request-Id': '8aef6c3b-8630-49db-823d-55a0115203d5'}, body={'log_id': 'qnxrdigwje6aiyyf', 'result': {'id': 32518, 'groupId': 26707, 'groupName': 'hi_sft_ds', 'displayName': '', 'createFrom': 0, 'bmlDatasetId': 'ds-nu54erbqtvfpgpr9', 'isBmlLocking': 0, 'easyDLProId': 0, 'versionId': 1, 'userId': 1493592, 'projectId': '', 'organizationId': '', 'visibility': 'Project', 'productId': 3, 'dataType': 4, 'projectType': 20, 'templateType': 2001, 'scene': 0, 'remark': '', 'storageType': 'usrBos', 'storageInfo': {'storageId': 'qianfanhj', 'storagePath': '/qianfanhj/dataset/dialog01/_system_/dataset/ds-nu54erbqtvfpgpr9/texts', 'storageName': 'qianfanhj', 'rawStoragePath': '/dataset/dialog01/', 'region': 'bj'}, 'importStatus': -1, 'importProgress': 0, 'importScheduledJobId': 0, 'importJobId': 0, 'exportStatus': -1, 'releaseStatus': 0, 'publishPublicStatus': '', 'publishPublicErrCode': 0, 'statsJobId': 0, 'statisticStatus': 0, 'statisticProgress': 0, 'ShouldHide': 0, 'status': 0, 'isUnique': 0, 'isConfirm': 0, 'publishStatus': 0, 'errCode': None, 'hasTitle': 0, 'displayFeatures': '', 'latestDeltaIndex': 0, 'adversarialStatus': 0, 'createTime': '2023-11-09T16:41:06.600928653+08:00', 'modifyTime': '2023-11-09T16:41:06.600940755+08:00'}, 'status': 200, 'success': True})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
"output_type": "display_data"
}
],
"source": [
"from qianfan.resources.console.data import Data\n",
"from qianfan.resources.console.consts import DataSetType, DataProjectType, DataTemplateType, DataStorageType\n",
"\n",
"# 创建数据集\n",
"ds = Data.create_bare_dataset(name=\"hi_sft_ds\", \n",
" data_set_type=DataSetType.TextOnly,\n",
" project_type=DataProjectType.Conversation,\n",
" template_type=DataTemplateType.AnnotatedConversation,\n",
" storage_type=DataStorageType.PrivateBos,\n",
" storage_id=bucket_name,\n",
" storage_path=prefix)\n",
"ds\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 使用bos进行数据导入\n",
"from qianfan.resources.console.consts import DataSourceType\n",
"\n",
"ds_id=ds[\"result\"][\"id\"]\n",
"import_resp = Data.create_data_import_task(dataset_id=ds_id,\n",
" is_annotated=True,\n",
" import_source=DataSourceType.PrivateBos,\n",
" file_url=\"bos:/{}{}\".format(bucket_name, key))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 获取数据集详情\n",
"ds_info = Data.get_dataset_info(ds_id)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 监听导入状态\n",
"\n",
"由于数据集导入是一个耗时任务,所以我们需要等待其完成才能进行下一步的动作,这里我们通过轮询的方式简单的监听任务状态直到数据完成导入成功。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"from qianfan.resources.console.consts import DataImportStatus\n",
"while True:\n",
" # 获取数据集详情\n",
" ds_info = Data.get_dataset_info(ds_id)\n",
" import_status = ds_info[\"result\"][\"versionInfo\"][\"importStatus\"]\n",
" if import_status == DataImportStatus.Finished.value:\n",
" print(\"dataset import finish, ready to release\")\n",
" break\n",
" print(\"current_import_status\", import_status)\n",
" time.sleep(10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 发布数据集\n",
"\n",
"恭喜你到达了进行SFT训练的最后一步,我们已经完成了数据集的准备,现在需要发布数据集。\n",
"> Note:\n",
"> 发布数据集后后无法再进行数据集的处理,导入或者修改!\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"current_release_status 1\n",
"current_release_status 1\n",
"current_release_status 1\n",
"current_release_status 1\n",
"dataset release finish, ready to train\n"
]
}
],
"source": [
"import qianfan\n",
"from qianfan.resources.console.consts import DataReleaseStatus\n",
"\n",
"# 发布 并监听数据集发布状态\n",
"resp = Data.release_dataset(ds_id)\n",
"\n",
"# 模型调优,模型管理与发布能力仅在qianfan>=0.1.0支持\n",
"qianfan.__version__"
"while True:\n",
" # 获取数据集详情\n",
" ds_info = Data.get_dataset_info(ds_id)\n",
" release_status = ds_info[\"result\"][\"versionInfo\"][\"releaseStatus\"]\n",
" if release_status == DataReleaseStatus.Finished.value:\n",
" print(\"dataset release finish, ready to train\")\n",
" break\n",
" print(\"current_release_status\", release_status)\n",
" time.sleep(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"至此,数据部分的准备已经完成!我们话不多说赶紧开始LLM的Finetune:"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Finetune\n",
"\n",
"目前千帆平台支持如下 SFT 相关操作:\n",
"* 创建训练任务\n",
"* 创建任务运行\n",
Expand Down Expand Up @@ -129,7 +315,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"req QfRequest(method='POST', url='/wenxinworkshop/finetune/createJob', query={}, headers={}, json_body={'taskId': 12765, 'baseTrainType': 'ERNIE-Bot-turbo', 'trainType': 'ERNIE-Bot-turbo-0725', 'trainMode': 'SFT', 'peftType': 'ALL', 'trainConfig': {'epoch': 1, 'learningRate': 2e-05, 'maxSeqLen': 4096}, 'trainset': [{'type': 1, 'id': 12563}], 'trainsetRate': 20}, retry_config=RetryConfig(retry_count=1, timeout=10, backoff_factor=0))\n"
"req QfRequest(method='POST', url='/wenxinworkshop/finetune/createJob', query={}, headers={}, json_body={'taskId': 12765, 'baseTrainType': 'ERNIE-Bot-turbo', 'trainType': 'ERNIE-Bot-turbo-0725', 'trainMode': 'SFT', 'peftType': 'ALL', 'trainConfig': {'epoch': 1, 'learningRate': 2e-05, 'maxSeqLen': 4096}, 'trainset': [{'type': 1, 'id': 32518}], 'trainsetRate': 20}, retry_config=RetryConfig(retry_count=1, timeout=10, backoff_factor=0))\n"
]
},
{
Expand Down Expand Up @@ -160,7 +346,7 @@
" \"trainset\": [\n",
" {\n",
" \"type\": 1,\n",
" \"id\": 12563\n",
" \"id\": ds_id\n",
" }\n",
" ],\n",
" \"trainsetRate\": 20\n",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:新华社受权于18日全文播发修改后的《中华人民共和国立法法》,修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章,共计105条。\n生成摘要如下:", "response": [["修改后的立法法全文公布"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:一辆小轿车,一名女司机,竟造成9死24伤。日前,深圳市交警局对事故进行通报:从目前证据看,事故系司机超速行驶且操作不当导致。目前24名伤员已有6名治愈出院,其余正接受治疗,预计事故赔偿费或超一千万元。\n生成摘要如下:", "response": [["深圳机场9死24伤续:司机全责赔偿或超千万"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:1月18日,习近平总书记对政法工作作出重要指示:2014年,政法战线各项工作特别是改革工作取得新成效。新形势下,希望全国政法机关主动适应新形势,为公正司法和提高执法司法公信力提供有力制度保障。\n生成摘要如下:", "response": [["孟建柱:主动适应形势新变化提高政法机关服务大局的能力"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示,将严查央视3·15晚会曝光通信违规违法行为。工信部称,已约谈三大运营商有关负责人,并连夜责成三大运营商和所在省通信管理局进行调查,依法依规严肃处理。\n生成摘要如下:", "response": [["工信部约谈三大运营商严查通信违规"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:国家食药监管总局近日发布《食品召回管理办法》,明确:食用后已经或可能导致严重健康损害甚至死亡的,属一级召回,食品生产者应在知悉食品安全风险后24小时内启动召回,且自公告发布之日起10个工作日内完成召回。\n生成摘要如下:", "response": [["食品一级召回限24小时内启动10工作日完成"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:人民检察院刑事诉讼涉案财物管理规定明确,不得查封、扣押、冻结与案件无关的财物,严禁在立案前查封、扣押、冻结财物,对查明确实与案件无关的,应当在三日内予以解除、退还。\n生成摘要如下:", "response": [["最高检:诉讼未终结涉案财物不得上缴国库"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:聂树斌案的复查有了新进展,山东高院已正式通知聂树斌案申诉代理律师阅卷。这也是该案律师10年来首次获准阅卷。此前山东省高院复查聂树斌案合议庭成员提讯了王书金,其仍坚称石家庄西郊玉米地强奸杀人案是他所为。\n生成摘要如下:", "response": [["聂树斌案律师10年来首获准阅卷"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:近日北京市司法局、北京市检察院联合召开人民监督员选任管理方式改革工作会议,宣布了北京市第一届113名人民监督员的任命决定,并组织监督员向宪法集体宣誓。人民监督员以“第三方”去监督检察院办案,机制上做到了相互制衡。\n生成摘要如下:", "response": [["北京市第一届人民监督员向宪法集体宣誓"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:据315晚会报道,公共免费WIFI存在隐患。黑客可利用轻易盗取用户个人信息,如账号、密码等。为了保证您个人信息安全,在公共场所尽量不要使用那些不需要密码免费wifi。\n生成摘要如下:", "response": [["免费公共wifi存隐患黑客可轻易获取用户信息"]]}]
[{"prompt": "请根据下面的新闻生成摘要, 内容如下:3月12日,最高人民法院院长周强作最高人民法院工作报告。周强表示,去年各级法院再审改判刑事案件1317件,其中纠正一批重大冤假错案。对错案的发生,我们深感自责,要求各级法院深刻汲取教训……更多工作报告要点详见↓\n生成摘要如下:", "response": [["一张图看懂最高法2014年工作报告"]]}]
4 changes: 4 additions & 0 deletions docs/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
- [获取数据集导出记录](#获取数据集导出记录)
- [获取数据集导入错误详情](#获取数据集导入错误详情)

使用前需要引用入Data类
```python
from qianfan.resources.console.data import Data
```

#### **创建数据集**
可以创建数据集,需要提供数据集名称 `name` 、数据集类型 `data_set_type` 、数据集模板 `DataTemplateType` 等信息。详细方法和返回参数字段参见 [API 文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/qloic44vr)
Expand Down
2 changes: 0 additions & 2 deletions src/qianfan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@


from qianfan.config import AK, SK, AccessKey, AccessToken, SecretKey, get_config
from qianfan.resources.console.data import Data
from qianfan.resources.console.finetune import FineTune
from qianfan.resources.console.model import Model
from qianfan.resources.console.service import Service
Expand All @@ -37,7 +36,6 @@
Response = QfResponse

__all__ = [
"Data",
"ChatCompletion",
"Embedding",
"Completion",
Expand Down
39 changes: 39 additions & 0 deletions src/qianfan/resources/console/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,42 @@ class DataSourceType(int, Enum):
class DataExportDestinationType(int, Enum):
PlatformBos: int = 0
PrivateBos: int = 1


class DataImportStatus(int, Enum):
NotStarted: int = -1
"""未发起导入"""
Initialized: int = 0
"""导入初始化"""
Running: int = 1
"""导入进行中"""
Finished: int = 2
"""导入完成"""
Failed: int = 3
"""导入失败"""


class DataExportStatus(int, Enum):
NotStarted: int = -1
"""未发起导出"""
Initialized: int = 0
"""导出初始化"""
Running: int = 1
"""导出进行中"""
Finished: int = 2
"""导出完成"""
Failed: int = 3
"""导出失败"""


class DataReleaseStatus(int, Enum):
NotStarted: int = -1
"""未发起发布"""
Initialized: int = 0
"""发布初始化"""
Running: int = 1
"""发布进行中"""
Finished: int = 2
"""发布完成"""
Failed: int = 3
"""发布失败"""
Loading

0 comments on commit 3ad55a0

Please sign in to comment.