# zeroshotcot.py
import tqdm
import retrying
import openai
from vllm import LLM, SamplingParams
from .model import AbstractModel
from .utils import create_batches, get_answer_cleansing_prompt
FULL_PROMPT = """Q: {{input}}
A: {{instruction}}
"""
class ZeroshotCoTModel(AbstractModel):
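    """Zero-shot chain-of-thought task model.

    Implements the standard two-stage zero-shot CoT scheme: a first request
    elicits free-form reasoning for each query, and a second request appends
    an answer-cleansing prompt to extract a short final answer from that
    reasoning.
    """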
def __init__(self, args, logger):
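        """Store args/logger via AbstractModel; locally hosted models are loaded into vLLM."""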
super().__init__(args, logger)
if self.args.task_model in ["mistralai/Mistral-7B-Instruct-v0.2", "mosaicml/mpt-7b-instruct", "01-ai/Yi-6B"]:
logger.info("Setting up `{}` model in vLLM".format(self.args.task_model))
self.llm = LLM(model=self.args.task_model)
def load_prompt(self):
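        """Use the fixed Q/A template; `{{instruction}}` and `{{input}}` are filled in per example."""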
        self.full_prompt = FULL_PROMPT
# @retrying.retry(
# stop_max_attempt_number=100, # Maximum number of attempts
# wait_exponential_multiplier=1000, wait_exponential_max=10000,
# # Wait 2^x * 1000 milliseconds between each retry, up to 10 seconds, then 10 seconds afterwards
# retry_on_exception=lambda e: isinstance(e, Exception)
# )
def _request_instruct(self, queries):
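        """Two-stage completion for legacy OpenAI instruct models.

        Call 1 batches all queries into a single Completion request to get
        the reasoning; call 2 re-sends each query with its reasoning and the
        answer-cleansing prompt to extract a short final answer.
        """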
OPENAI_INSTRUCT_MODELS = {"openai_td003": "text-davinci-003", "openai_gpt35_turbo_instruct": "gpt-3.5-turbo-instruct"}
engine = OPENAI_INSTRUCT_MODELS[self.args.task_model]
responses = openai.Completion.create(
engine=engine,
prompt=queries,
max_tokens=400,
temperature=0.0,
n=1,
top_p=1,
stream=False,
stop=None
)
assert len(responses.choices) == len(queries)
all_reasoning = [choice.text.strip() for choice in responses.choices]
new_queries = []
for query, choice in zip(queries, responses.choices):
new_query = query + choice.text.strip() + " \n" + get_answer_cleansing_prompt(self.args)
new_queries.append(new_query)
responses = openai.Completion.create(
engine=engine,
prompt=new_queries,
max_tokens=100 if self.args.subtask == "word_sorting" else 10, # word sorting answer is significantly longer
temperature=0.0,
n=1,
top_p=1,
stream=False,
stop=None
)
assert len(responses.choices) == len(queries)
all_raw_output = [choice.text.strip() for choice in responses.choices]
return all_reasoning, all_raw_output
@retrying.retry(
stop_max_attempt_number=100, # Maximum number of attempts
wait_exponential_multiplier=1000, wait_exponential_max=10000,
# Wait 2^x * 1000 milliseconds between each retry, up to 10 seconds, then 10 seconds afterwards
retry_on_exception=lambda e: isinstance(e, Exception)
)
def _request_chat(self, queries):
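        """Two-stage chat completion for OpenAI chat models.

        Queries are sent one at a time; the answer-cleansing prompt is
        delivered as a follow-up user turn after the assistant's reasoning.
        """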
OPENAI_CHAT_MODELS = {"openai_gpt35_turbo": "gpt-3.5-turbo-1106", "openai_gpt4_turbo": "gpt-4-0125-preview"}
engine = OPENAI_CHAT_MODELS[self.args.task_model]
# client = OpenAI()
all_reasoning, all_raw_output = [], []
for query in queries:
response = openai.ChatCompletion.create(
model=engine,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": query},
],
max_tokens=400,
temperature=0.0,
n=1,
top_p=1,
stream=False,
stop=None
)
reasoning = response.choices[0].message.content
all_reasoning.append(reasoning)
# Second API call to do answer cleansing
response = openai.ChatCompletion.create(
model=engine,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": query},
{"role": "assistant", "content": reasoning},
{"role": "user", "content": get_answer_cleansing_prompt(self.args)},
],
max_tokens=10,
temperature=0.0,
n=1,
top_p=1,
stream=False,
stop=None
)
final_answer = response.choices[0].message.content
all_raw_output.append(final_answer)
return all_reasoning, all_raw_output
def _request_vllm(self, queries):
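        """Two-stage generation (reasoning, then answer extraction) for locally hosted models served by vLLM."""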
all_reasoning, all_raw_output, new_queries = [], [], []
        # first call: reasoning extraction
sampling_params = SamplingParams(temperature=0.0,
max_tokens=400, n=1,
top_p=1, stop=None)
responses = self.llm.generate(queries, sampling_params)
for query, response in zip(queries, responses):
reasoning = response.outputs[0].text
all_reasoning.append(reasoning)
new_query = query + reasoning + " \n" + get_answer_cleansing_prompt(self.args)
new_queries.append(new_query)
        # second call: answer extraction (word-sorting answers are significantly longer)
        max_tokens = 100 if self.args.subtask == "word_sorting" else 10
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens, n=1,
                                         top_p=1, stop=None)
        responses = self.llm.generate(new_queries, sampling_params)
for response in responses:
final_answer = response.outputs[0].text
all_raw_output.append(final_answer)
return all_reasoning, all_raw_output
def _request(self, queries):
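        """Dispatch a batch of queries to the backend that matches `self.args.task_model`."""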
if self.args.task_model in ["openai_td003", "openai_gpt35_turbo_instruct"]:
return self._request_instruct(queries) # legacy openai completion endpoints
elif self.args.task_model in ["openai_gpt35_turbo", "openai_gpt4_turbo"]:
            return self._request_chat(queries) # chat endpoints, which may need additional formatting
elif self.args.task_model in ["mistralai/Mistral-7B-Instruct-v0.2", "mosaicml/mpt-7b-instruct", "01-ai/Yi-6B"]:
return self._request_vllm(queries)
else:
raise NotImplementedError
def _run_batch(self, instruction, batch):
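        """Fill the prompt template with the instruction and each input, then request reasoning and raw answers."""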
all_inputs = batch["input"].tolist()
all_queries = []
prompt = self.full_prompt.replace("{{instruction}}", instruction)
for i in range(len(all_inputs)):
query = prompt.replace("{{input}}", all_inputs[i])
all_queries.append(query)
# response = self._request_substrate(all_queries)
# all_raw_output = [choice["text"] for choice in response["choices"]]
all_reasoning, all_raw_output = self._request(all_queries)
return all_reasoning, all_raw_output
def run(self, instruction, data):
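        """Run the instruction over the full dataset; returns a copy with `reasoning` and `raw_output` columns."""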
# if self.args.task_model in ["mistralai/Mistral-7B-Instruct-v0.2", "mosaicml/mpt-7b-instruct", "01-ai/Yi-6B"]:
# openai.default_headers = {"x-foo": "true"}
# openai.api_key = "EMPTY"
# openai.api_base = "http://localhost:8001/v1"
data = data.copy()
batch_size = 10 if self.args.task_model in ["openai_td003", "openai_gpt35_turbo_instruct"] else len(data)
batches = create_batches(data, batch_size=batch_size)
all_raw_output = []
all_reasoning = []
for batch in tqdm.tqdm(batches, desc="Batches"):
batch_reasoning, batch_raw_output = self._run_batch(instruction, batch)
all_reasoning += batch_reasoning
all_raw_output += batch_raw_output
data["reasoning"] = all_reasoning
data["raw_output"] = all_raw_output
return data
def run_till_enough_errors(self, instruction, data, task, target_n_errors):
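        """Run over shuffled data batch by batch, stopping early once `target_n_errors` mistakes are found.

        Returns only the rows processed so far, with `reasoning`, `raw_output`,
        and evaluation scores attached.
        """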
data = data.copy().sample(frac=1.0, random_state=42) # random shuffle
batch_size = 10 if self.args.task_model in ["openai_td003", "openai_gpt35_turbo_instruct"] else 20
batches = create_batches(data, batch_size=batch_size)
all_raw_output = []
all_reasoning = []
n_done = 0
for batch in tqdm.tqdm(batches, desc="Batches"):
batch_reasoning, batch_raw_output = self._run_batch(instruction, batch)
n_done += len(batch)
all_reasoning += batch_reasoning
all_raw_output += batch_raw_output
data_done = data.head(n_done).copy()
data_done["reasoning"] = all_reasoning
data_done["raw_output"] = all_raw_output
data_done, _ = task.evaluate(data_done)
n_errors = (data_done['score'] == 0.0).sum()
if n_errors >= target_n_errors:
self.logger.info("Found {} model mistakes by running {} examples".format(n_errors, n_done))
break
return data_done
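
# Usage sketch (hypothetical, for illustration only): assumes an
# argparse-style `args` carrying the `task_model` and `subtask` fields read
# above, and a pandas DataFrame with an "input" column as `_run_batch` expects.
#
#     import argparse, logging
#     import pandas as pd
#
#     args = argparse.Namespace(task_model="openai_gpt35_turbo", subtask="word_sorting")
#     model = ZeroshotCoTModel(args, logging.getLogger(__name__))
#     model.load_prompt()
#     data = pd.DataFrame({"input": ["Sort the words: pear apple mango"]})
#     results = model.run("Let's think step by step.", data)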