-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 21c4ed1
Showing
11 changed files
with
426 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# TCES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from openai import OpenAI | ||
import csv | ||
import pandas as pd | ||
import time | ||
|
||
def load_results(csv_file):
    """Load continuation-test results from a CSV file into a DataFrame.

    Args:
        csv_file: Path to a UTF-8 encoded CSV file whose first row is the
            header.

    Returns:
        A pandas DataFrame with one row per CSV record. Values are kept as
        the strings produced by ``csv.DictReader``. A file that contains a
        header but no records yields an empty DataFrame that still carries
        the header's column names.
    """
    # newline='' leaves newline handling to the csv module, so line breaks
    # embedded in quoted fields (continuations are free text) are read back
    # exactly as they were written.
    with open(csv_file, 'r', newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        if not rows:
            # Keep the header so callers can still address columns by name.
            return pd.DataFrame(columns=list(reader.fieldnames or []))
    return pd.DataFrame(rows)
|
||
def _extract_score(reply, label):
    """Return the number that follows *label* and a colon in *reply*, or None."""
    import re

    # Accept both the full-width and the ASCII colon, tolerate extra text
    # between the label and the colon (e.g. an echoed "(0-10分)"), and allow
    # decimal scores.
    match = re.search(
        re.escape(label) + r"[^::\n]*[::]\s*(\d+(?:\.\d+)?)", reply
    )
    if match is None:
        return None
    value = float(match.group(1))
    return int(value) if value.is_integer() else value


def score_with_model(text, continuation, api_key,
                     base_url='http://localhost:8000/v1',
                     model="/mnt/af0931a4-88fe-4e4d-86ed-54f4a275dad4/gutai/vllm_test/glm-4-9b-chat"):
    """Ask a chat model to grade a continuation for accuracy and logic.

    Args:
        text: The original text shown to the grading model.
        continuation: The generated continuation to be graded.
        api_key: API key passed to the OpenAI client.
        base_url: Base URL of the OpenAI-compatible endpoint.
        model: Name (or served path) of the grading model.

    Returns:
        A tuple ``(accuracy_score, logic_score, feedback)``. On success the
        scores are numbers (``int`` when the model returned a whole number)
        and ``feedback`` is the raw model reply. If the request fails or a
        score cannot be parsed, both scores are ``None`` and ``feedback`` is
        a message starting with "评分出错".
    """
    prompt = f"原文:{text}\n\n续写:{continuation}\n\n请根据以下标准对续写内容进行评分:\n" \
             "1. 续写准确性(0-10分):内容是否延续了原文主题,保持了逻辑一致性。\n" \
             "2. 逻辑准确性(0-10分):续写的内容是否具备合理的逻辑性和连贯性。\n" \
             "请提供准确性和逻辑性的分数以及简单的评分理由。\n" \
             "输出格式如下:续写准确性:x分 \n 逻辑准确性:x分 \n 不要回复任何多余内容!"

    try:
        client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        # The message content may be None; treat that as an empty reply.
        reply = (response.choices[0].message.content or "").strip()
    except Exception as e:
        # Best-effort scoring: report the failure and let the caller go on
        # with the remaining rows.
        print(e)
        return None, None, f"评分出错:{e}"

    # Extract both scores from the model reply.
    accuracy_score = _extract_score(reply, "续写准确性")
    logic_score = _extract_score(reply, "逻辑准确性")
    if accuracy_score is None or logic_score is None:
        # Keep the raw reply in the feedback so the failure can be inspected.
        return None, None, f"评分出错:无法从模型回复中解析评分:{reply}"

    return accuracy_score, logic_score, reply
|
||
def main():
    """Score every continuation in the results CSV and save the scores.

    Reads 'continuation_test_results.csv', grades each row with
    ``score_with_model`` and writes the collected scores to
    'continuation_scores_with_model.csv'.
    """
    # API key handed to the scoring endpoint
    api_key = "YOUR_API_KEY"

    # Continuation results to be graded
    results_df = load_results('continuation_test_results.csv')

    scores = []
    for _, record in results_df.iterrows():
        source_text = record['原始文本']
        generated = record['续写结果']

        # Let the model grade this continuation
        accuracy, logic, feedback = score_with_model(source_text, generated, api_key)

        entry = {
            '文本编号': record['文本编号'],
            '模型': record['模型'],
            '截断百分比': record['截断百分比'],
            '续写准确性': accuracy,
            '逻辑准确性': logic,
            '评分反馈': feedback,
            '续写结果': generated
        }
        scores.append(entry)

        # Pause between requests so the API is not called too quickly
        time.sleep(1)

    # Persist all scores as CSV
    pd.DataFrame(scores).to_csv('continuation_scores_with_model.csv', index=False, encoding='utf-8')
    print("评分完成,结果已保存到 continuation_scores_with_model.csv")


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from openai import OpenAI | ||
|
||
# Client for the local OpenAI-compatible server.
client = OpenAI(
    api_key="your zhipuai api key",
    base_url="http://localhost:8000/v1"
)

# Court ruling used as the completion prompt. The literal starts directly
# after the opening quotes so that no stray newline is sent ahead of the text.
text = """深圳市罗湖区人民法院民 事 裁 定 书(2012)深罗法民二初字第353号原告广发银行股份有限公司深圳分行,住所地深圳市深南东路123号百货广场大厦西座19-22层。负责人杨小舟,行长。被告林镜鹏。上列原告诉被告信用卡欠款纠纷一案,本院于2012年1月9日受理后通知原告在七日内预交案件受理费,原告在规定期间内未预交又不提出缓交、减交、免交申请。依照最高人民法院《关于适用若干问题的意见》第143条以及《诉讼费用交纳办法》第二十二条的规定,裁定如下:本案按撤诉处理。审 判 长 饶 弢代理审判员 袁晶晶代理审判员 刘 娟二〇一二年一月十九日"""

# max_tokens is set explicitly: without it the completions endpoint falls
# back to a very small default and the continuation is cut off after a few
# tokens.
completion = client.completions.create(
    model="falv",
    prompt=text,
    temperature=0,
    max_tokens=200
)

print("Completion result:", completion.choices[0])
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from openai import OpenAI | ||
|
||
# Client bound to the local OpenAI-compatible endpoint.
client = OpenAI(
    api_key="your zhipuai api key",
    base_url="http://localhost:8000/v1"
)

# Single-turn question sent through the chat completions API.
question = {
    "role": "user",
    "content": "财产保全措施在民事诉讼中起到什么作用?",
}

chat_completion = client.chat.completions.create(
    model="falv",
    messages=[question],
)

# Show only the returned message of the first choice.
print("Completion result:", chat_completion.choices[0].message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
httpx==0.27.2 | ||
openai==1.52.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True vllm serve glm-4-9b-chat --trust_remote_code --gpu-memory-utilization 0.9999 | ||
|
||
|
||
|
||
~/LLaMA-Factory/saves/glm-4-9b/lora/pretrain | ||
|
||
model_name_or_path: /root/autodl-tmp/glm-4-9b | ||
adapter_name_or_path: saves/glm-4-9b/lora/pretrain | ||
template: glm4 | ||
finetuning_type: lora | ||
|
||
autodl-tmp/glm-4-9b | ||
|
||
|
||
|
||
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True vllm serve /mnt/af0931a4-88fe-4e4d-86ed-54f4a275dad4/gutai/vllm_test/glm-4-9b-chat \ | ||
--enable-lora \ | ||
--lora-modules falv=/mnt/af0931a4-88fe-4e4d-86ed-54f4a275dad4/gutai/vllm_test/pretrain \ | ||
--trust_remote_code \ | ||
--gpu-memory-utilization 0.8 \ | ||
--max-model-len 4096 \ | ||
--enforce-eager \ | ||
--max_num_seqs=4 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import csv | ||
import time | ||
from openai import OpenAI | ||
|
||
# 初始化客户端 | ||
# Shared client for the OpenAI-compatible server; used by get_continuation.
client = OpenAI(
    api_key="your zhipuai api key",  # replace with your API key
    base_url="http://localhost:8000/v1"
)

# Names of the two models whose continuations are compared.
model1 = "falv"  # replace with the name of model 1
model2 = "/mnt/af0931a4-88fe-4e4d-86ed-54f4a275dad4/gutai/vllm_test/glm-4-9b-chat"  # replace with the name of model 2

# Texts to be continued; each entry is sent unchanged as the prompt.
texts = [
    """深圳市罗湖区人民法院民 事 裁 定 书(2012)深罗法民二初字第353号原告广发银行股份有限公司深圳分行,住所地深圳市深南东路123号百货广场大厦西座19-22层。负责人杨小舟,行长。被告林镜鹏。上列原告诉被告信用卡欠款纠纷一案,本院于2012年1月9日受理后通知原告在七日内预交案件受理费,原告在规定期间内未预交又不提出缓交、减交、免交申请。依照最高人民法院《关于适用若干问题的意见》第143条以及《诉讼费用交纳办法》第二十二条的规定,裁定如下:本案按撤诉处理。审 判 长 饶 弢代理审判员 袁晶晶代理审判员 刘 娟二〇一二年一月十九日""",
    # More text fragments can be added here
    """河北省邯郸市中级人民法院民 事 裁 定 书(2012)邯市民监字第7号抗诉机关:河北省邯郸市人民检察院。申诉人(原审原告):永年县金谷粮食购销有限公司。住所地:永年县苗庄村。法定代表人:孔令勇,该公司经理。被申诉人(原审被告):周运才。申诉人永年县金谷粮食购销有限公司与被申诉人周运才劳动争议纠纷一案,河北省永年县人民法院于2011年9月5日作出(2010)永民初字第3070号民事裁定,已经发生法律效力。永年县金谷粮食购销有限公司不服,向检察机关申诉。2011年12月20日,河北省邯郸市人民检察院作出邯检民行抗(2011)96号民事抗诉书,以(2010)永民初字第3070号民事裁定适用法律确有错误为由,于2012年1月12日对本案向本院提出抗诉。依照《中华人民共和国民事诉讼法》第一百八十八条、第一百八十五条的规定,裁定如下:""",
    # Add more texts
]
|
||
# Helper: send a text to the given model and return its continuation
def get_continuation(model_name, text):
    """Return the model's continuation of *text*, or an error string.

    The request goes through the module-level ``client``. Any exception
    raised while calling the API or reading the response is turned into a
    string starting with "出错:" instead of being propagated.
    """
    request = {
        "model": model_name,
        "prompt": text,
        "temperature": 0.1,  # low temperature to limit randomness
        "max_tokens": 200,   # upper bound on the continuation length
        "top_p": 0.9,        # nucleus sampling parameter
        "n": 1,              # ask for a single continuation
        "stop": None,        # no stop sequence configured
    }
    try:
        completion = client.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.text.strip()
    except Exception as err:
        return f"出错:{err}"
|
||
# Create the CSV file and write the header row
with open('continuation_comparison.csv', 'w', newline='', encoding='utf-8') as out_file:
    rows = csv.writer(out_file)
    rows.writerow(['原始文本', '模型1续写', '模型2续写'])

    # Process every text in turn
    for text in texts:
        print("正在处理文本:")
        print(f"{text[:50]}...")  # first 50 characters as a preview

        # Query model 1 and then model 2, pausing after each request so the
        # server is not hit too quickly
        outputs = []
        for model in (model1, model2):
            outputs.append(get_continuation(model, text))
            time.sleep(1)

        # One CSV row per text: the original plus both continuations
        rows.writerow([text, *outputs])

print("续写对比完成,结果已保存到 continuation_comparison.csv 文件中。")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import json | ||
import csv | ||
import random | ||
import time | ||
import argparse | ||
from openai import OpenAI | ||
|
||
def load_data(json_file):
    """Read a JSON file holding a list of records and return their 'content' values.

    Args:
        json_file: Path to a UTF-8 encoded JSON file.

    Returns:
        A list with the ``'content'`` value of every record, in file order.
    """
    with open(json_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    texts = []
    for record in records:
        texts.append(record['content'])
    return texts
|
||
def select_data(contents, mode, N):
    """Pick the texts to test according to *mode*.

    Args:
        contents: List of candidate texts.
        mode: ``'random'`` (N texts sampled without replacement),
            ``'first'`` (the first N texts) or ``'all'`` (every text).
        N: Number of texts wanted; ignored in ``'all'`` mode. Values larger
            than ``len(contents)`` select everything.

    Returns:
        A new list with the selected texts.

    Raises:
        ValueError: If *mode* is not one of the three supported values, or
            if *N* is negative in ``'random'`` or ``'first'`` mode.
    """
    if mode == 'all':
        # Return a copy so every mode hands back a list the caller may modify.
        return list(contents)
    if mode not in ('random', 'first'):
        raise ValueError("模式应为 'random'、'first' 或 'all'")
    if N < 0:
        # A negative N would make contents[:N] silently drop trailing items.
        raise ValueError(f"N 不能为负数:{N}")
    if mode == 'random':
        return random.sample(contents, min(N, len(contents)))
    return contents[:N]
|
||
def truncate_text(text, percentages):
    """Cut *text* at each given fraction of its length.

    Args:
        text: The text to truncate.
        percentages: Iterable of fractions (e.g. ``0.2`` for 20%).

    Returns:
        A list of ``(fraction, prefix)`` tuples in the order of
        *percentages*, where ``prefix`` is the first
        ``int(len(text) * fraction)`` characters of *text*.
    """
    total = len(text)
    return [(ratio, text[:int(total * ratio)]) for ratio in percentages]
|
||
def get_continuation(client, model_name, prompt):
    """Request a continuation of *prompt* from *model_name* via *client*.

    Args:
        client: Object exposing ``completions.create(...)``.
        model_name: Name of the model to query.
        prompt: Text the model should continue.

    Returns:
        The stripped text of the first returned choice. If anything raises
        while calling the API or reading the response, a string starting
        with "出错:" is returned instead.
    """
    request = {
        "model": model_name,
        "prompt": prompt,
        "temperature": 0.7,
        "max_tokens": 200,
        "top_p": 0.9,
        "n": 1,
    }
    try:
        completion = client.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.text.strip()
    except Exception as err:
        return f"出错:{err}"
|
||
def main():
    """Run the continuation test from the command line.

    Parses the arguments, truncates each selected text at several fractions
    of its length, asks both models to continue every truncated prefix and
    writes one CSV row per (text, fraction, model) combination.
    """
    # Command-line interface
    parser = argparse.ArgumentParser(description='续写能力测试脚本')
    parser.add_argument('--json_file', type=str, required=True, help='输入的JSON文件路径')
    parser.add_argument('--mode', type=str, choices=['random', 'first', 'all'], default='all', help='数据选择模式')
    parser.add_argument('--num_samples', type=int, default=5, help='随机或前N条数据的N值')
    parser.add_argument('--output_file', type=str, default='continuation_test_results.csv', help='输出的CSV文件名')
    parser.add_argument('--model1', type=str, required=True, help='模型1的名称')
    parser.add_argument('--model2', type=str, required=True, help='模型2的名称')
    parser.add_argument('--api_key', type=str, required=True, help='API密钥')
    parser.add_argument('--api_base', type=str, default='http://localhost:8000/v1', help='API基础URL')
    args = parser.parse_args()

    # OpenAI client for the configured endpoint
    client = OpenAI(api_key=args.api_key, base_url=args.api_base)

    # Load the texts and keep only those chosen by the selection mode
    selected_contents = select_data(load_data(args.json_file), args.mode, args.num_samples)

    # Fractions of each text that are kept as the prompt
    percentages = [0.2, 0.4, 0.6, 0.8]
    models = [args.model1, args.model2]
    columns = ['文本编号', '原始文本', '截断百分比', '模型', '续写结果']

    # Create the CSV file and write the header row
    with open(args.output_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=columns)
        writer.writeheader()

        for number, text in enumerate(selected_contents, start=1):
            print(f"正在处理第 {number} 条文本")
            for p, truncated_text in truncate_text(text, percentages):
                # Both models continue the same truncated prefix
                for model_name in models:
                    continuation = get_continuation(client, model_name, truncated_text)
                    writer.writerow({
                        '文本编号': f'文本 {number}',
                        '原始文本': text,
                        '截断百分比': f'{int(p*100)}%',
                        '模型': model_name,
                        '续写结果': continuation
                    })
                    # Pause so requests are not sent too quickly
                    time.sleep(1)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
使用示例 | ||
|
||
假设您的JSON文件名为input.json,API密钥为your_api_key,两个模型分别为model_a和model_b。 | ||
|
||
|
||
全部测试 | ||
|
||
bash | ||
|
||
python testbasepro.py --json_file input.json --mode all --model1 model_a --model2 model_b --api_key your_api_key | ||
|
||
随机抽取5条数据测试 | ||
|
||
bash | ||
|
||
python testbasepro.py --json_file input.json --mode random --num_samples 5 --model1 model_a --model2 model_b --api_key your_api_key | ||
|
||
选择前10条数据测试 | ||
|
||
bash | ||
|
||
python testbasepro.py --json_file input.json --mode first --num_samples 10 --model1 model_a --model2 model_b --api_key your_api_key | ||
|
||
CSV文件格式:输出的CSV文件包含以下列:文本编号、原始文本、截断百分比、模型、续写结果。 | ||
|
||
|
||
|
||
|
||
python testbasepro.py --json_file pretrain_data.json --mode random --num_samples 5 --model1 falv --model2 /mnt/af0931a4-88fe-4e4d-86ed-54f4a275dad4/gutai/vllm_test/glm-4-9b-chat --api_key your_api_key |
Oops, something went wrong.