
Commit

update code
liucongg committed Apr 7, 2023
1 parent bbbd2ba commit a15b8eb
Showing 2 changed files with 70 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -118,7 +118,7 @@ prompt_text: You are now an information-extraction model; please help me extract the relation
- After fine-tuning on a given task, the model does not lose its original abilities; for example, prompted with "write me a quicksort algorithm", it can still generate quicksort code.
- Since large-model fine-tuning trains on a large number of instructions, fine-tuning on a single instruction has little effect on the other instructions, so the model's original abilities are not lost.

- Many users have seen catastrophic forgetting after fine-tuning, but it did not appear with this project's training code. The "translation", "code", and "question-answering" tasks were tested with the freeze-tuned model; the results are as follows:
+ Many users have seen catastrophic forgetting after fine-tuning, but it did not appear with this project's training code. The "translation", "code", and "question-answering" tasks were tested with the freeze-tuned model and can be reproduced with test_forgetting.py; the results are as follows:
<details><summary><b>Translation task</b></summary>

![](images/ft_fanyi.png)
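To reproduce the check, the script added below can be run against a freeze-tuned checkpoint, e.g. `python test_forgetting.py --model_dir <checkpoint_dir> --device 0` (the flags match `set_args` in the new file).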
69 changes: 69 additions & 0 deletions test_forgetting.py
@@ -0,0 +1,69 @@
# -*- coding:utf-8 -*-
# @project: ChatGLM-Finetuning
# @filename: test_forgetting
# @author: 刘聪NLP
# @zhihu: https://www.zhihu.com/people/LiuCongNLP
# @contact: logcongcong@gmail.com
# @time: 2023/4/7 15:00
"""
文件说明:
"""
import torch
from modeling_chatglm import ChatGLMForConditionalGeneration
from tokenization_chatglm import ChatGLMTokenizer
import argparse


def set_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, help='GPU device id')
    parser.add_argument('--model_dir', default="/data/work/lcong/ChatGPT/LLMFTProj/output_dir_freeze/global_step-2160/",
                        type=str, help='path to the fine-tuned checkpoint')
    parser.add_argument('--max_len', type=int, default=2048, help='maximum total sequence length')
    parser.add_argument('--max_src_len', type=int, default=450, help='maximum source (prompt) length')
    parser.add_argument('--top_p', type=float, default=0.7, help='nucleus-sampling probability')
    # Note: argparse's type=bool treats any non-empty string as True, so parse explicitly.
    parser.add_argument('--do_sample', type=lambda s: s.lower() == 'true', default=True,
                        help='whether to sample during generation')
    parser.add_argument('--num_return_sequences', type=int, default=1, help='number of sequences to return')
    return parser.parse_args()


def predict_one_sample(model, tokenizer, args, text):
    # Budget for generated tokens: total length minus the prompt budget,
    # minus a few positions reserved for special tokens.
    max_tgt_len = args.max_len - args.max_src_len - 3
    with torch.no_grad():
        input_ids = tokenizer.encode(text, max_length=args.max_src_len, truncation=True)
        input_ids = torch.tensor([input_ids]).to("cuda:{}".format(args.device))
        generation_kwargs = {
            "min_length": 5,
            "max_new_tokens": max_tgt_len,
            "top_p": args.top_p,
            "temperature": 0.95,
            "do_sample": args.do_sample,
            "num_return_sequences": args.num_return_sequences,
        }
        response = model.generate(input_ids, **generation_kwargs)

        res = []
        for i_r in range(generation_kwargs["num_return_sequences"]):
            # generate() returns prompt + completion; keep only the newly generated tokens.
            outputs = response.tolist()[i_r][input_ids.shape[1]:]
            # Drop ChatGLM's <eop> end marker from the decoded text.
            r = tokenizer.decode(outputs).replace("<eop>", "")
            res.append(r)
    return res[0]


def main():
    args = set_args()
    model = ChatGLMForConditionalGeneration.from_pretrained(args.model_dir)
    model.half().to("cuda:{}".format(args.device))
    model.eval()
    tokenizer = ChatGLMTokenizer.from_pretrained(args.model_dir)

    print('Starting Q&A; press CTRL + C to exit')
    while True:
        text = input("Q: ")
        pre_res = predict_one_sample(model, tokenizer, args, text)
        print("A: {}".format(pre_res))


if __name__ == '__main__':
    main()
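For a non-interactive variant of the same check, one probe per capability the README mentions could be scripted. A minimal sketch (illustrative, not part of this commit — the prompts and the direct import of test_forgetting are assumptions):

# Minimal sketch: run one probe per capability (translation / code / QA).
# Assumes test_forgetting.py and the ChatGLM modeling files are importable.
from test_forgetting import set_args, predict_one_sample
from modeling_chatglm import ChatGLMForConditionalGeneration
from tokenization_chatglm import ChatGLMTokenizer

args = set_args()
model = ChatGLMForConditionalGeneration.from_pretrained(args.model_dir)
model.half().to("cuda:{}".format(args.device))  # same setup as main()
model.eval()
tokenizer = ChatGLMTokenizer.from_pretrained(args.model_dir)

probes = ["Translate into English: 今天天气很好。",  # translation
          "Write a quicksort function for me.",      # code generation
          "What is the capital of France?"]          # question answering
for text in probes:
    print("Q:", text)
    print("A:", predict_one_sample(model, tokenizer, args, text))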
