forked from SupritYoung/RLHF-Label-Tool
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
144 lines (119 loc) · 5.65 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# !/usr/bin/env python3
import pandas as pd
import streamlit as st
import json
st.set_page_config(page_title="Rank List Labeler", page_icon="📌", layout="wide")
CONFIGS = {
"input_path": "./data/input_file.json",
"output_path": "./data/output_result.tsv", # 标注数据集的存放文件
"rank_list_len": 4,
}
if "configs" not in st.session_state:
st.session_state["configs"] = CONFIGS
######################## 页面配置初始化 ###########################
RANK_COLOR = ["red", "green", "blue", "orange", "violet"]
######################### 页面定义区(侧边栏) ########################
st.sidebar.title("📌 RLHF Rank标注平台")
st.sidebar.markdown(
"""
```python
用于大模型在 RLHF 阶段的数据排序标注。
```
"""
)
st.sidebar.markdown("标注思路参考自 [InstructGPT](https://arxiv.org/pdf/2203.02155.pdf) 。")
st.sidebar.markdown(
"项目 [github地址](https://github.com/SupritYoung/RLHF-Label-Tool). I need your ⭐️."
)
st.sidebar.header("📢 注意事项")
st.sidebar.write("1. 需要预先构建好数据文件,格式参见 input_file.json。")
st.sidebar.write("2. 将构造好的数据地址替换配置中的 output_path。")
st.sidebar.write("3. 可以跳转标注,重复标注会覆盖,但建议按顺序标注。")
st.sidebar.header("⚙️ Model Config")
st.sidebar.write("当前标注配置(可在源码中修改):")
st.sidebar.write(st.session_state["configs"])
label_tab, dataset_tab = st.tabs(["Label", "Dataset"])
######################### 页面定义区(标注页面) ########################
with label_tab:
with st.expander("🔍 Setting Prompts", expanded=True):
with open(CONFIGS["input_path"], "r", encoding="utf-8") as f:
data = json.load(f)
query_ids = list(data.keys())
query_index_number = st.number_input(
"当前 query 编号(点击右边的➕➖前后跳转):",
min_value=0,
max_value=len(query_ids) - 1,
value=0,
)
current_query_id = query_ids[query_index_number]
current_query = data[current_query_id]["query"]
current_history = data[current_query_id]["history"]
st.markdown(f"**Query:** {current_query}")
st.markdown("**History:**")
for history_item in current_history:
st.write(f"- {history_item[0]}")
st.write(f" {history_item[1]}")
# 排序功能
with st.expander("💡 Generate Results", expanded=True):
rank_results = []
for i in range(CONFIGS["rank_list_len"]):
# st.write(f'**Response {i + 1}:**,请标注其排名')
response_text = data[current_query_id][f"response_{i}"]
rank = st.selectbox(
f"请标注回答 {i + 1} 的排名",
[-1, 1, 2, 3, 4],
help="为当前 Response 选择排名,回答质量越好,排名越靠前。(-1代表当前句子暂未设置排名)",
)
conflict_index = next(
(idx + 1 for idx, r in enumerate(rank_results) if r == rank), None
)
if conflict_index is not None and rank != -1:
st.info(
f"当前排名[{rank}]已经被句子[{conflict_index}]占用,请先将占用排名的句子置为-1再为当前句子分配该排名。"
)
else:
rank_results.append(rank)
st.markdown(
f"<span style='color:{RANK_COLOR[i]}'>{response_text}</span>",
unsafe_allow_html=True,
)
# st.write(f'当前排名:**{rank}**')
# st.write('---')
# 排序存储功能
if -1 not in rank_results:
save_button = st.button("存储当前排序")
if save_button:
dataset_file = CONFIGS["output_path"]
df = pd.read_csv(dataset_file, delimiter="\t", dtype=str)
# print(df)
existing_ids = df["id"].tolist()
rank_texts = [
data[current_query_id][f"response_{rank - 1}"]
for rank in rank_results
]
line = [current_query_id, current_query, current_history] + rank_texts
new_row = pd.DataFrame([line], columns=df.columns)
if current_query_id in existing_ids:
df = df[df["id"] != current_query_id] # 删除已存在的行
df = pd.concat([df, new_row], ignore_index=True) # 追加新行
df.to_csv(dataset_file, index=False, sep="\t") # 保存到文件
query_index_number += 1
if query_index_number >= len(query_ids):
st.write("已完成所有查询的标注")
st.stop()
st.success(f"{current_query_id} 数据保存完成")
else:
st.error("请完成排序后再存储!", icon="🚨")
# with st.expander('🥇 Rank Results', expanded=True):
# columns = st.columns([1] * CONFIGS['rank_list_len'])
# for i, c in enumerate(columns):
# with c:
# st.write(f'Rank {i+1}:')
# if i + 1 in rank_results:
# color = RANK_COLOR[rank_results.index(i+1)] if rank_results.index(i+1) < len(RANK_COLOR) else 'white'
# st.markdown(f":{color}[{st.session_state['current_results'][rank_results.index(i+1)]}]")
######################### 页面定义区(数据集页面) #######################
with dataset_tab:
dataset_file = CONFIGS["output_path"]
df = pd.read_csv(dataset_file, delimiter="\t", dtype=str)
st.dataframe(df)